diff --git a/tests/alloydbpg/alloydb_pg_integration_test.go b/tests/alloydbpg/alloydb_pg_integration_test.go index 8162fc2920..3c2df29102 100644 --- a/tests/alloydbpg/alloydb_pg_integration_test.go +++ b/tests/alloydbpg/alloydb_pg_integration_test.go @@ -129,6 +129,9 @@ func TestAlloyDBPgToolEndpoints(t *testing.T) { t.Fatalf("unable to create AlloyDB connection pool: %s", err) } + // cleanup test environment + tests.CleanupPostgresTables(t, ctx, pool) + // create table name with UUID tableNameParam := "param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "") tableNameAuth := "auth_table_" + strings.ReplaceAll(uuid.New().String(), "-", "") @@ -175,7 +178,14 @@ func TestAlloyDBPgToolEndpoints(t *testing.T) { tests.RunMCPToolCallMethod(t, failInvocationWant, mcpSelect1Want) tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want) tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam) + + // Run Postgres prebuilt tool tests + tests.RunPostgresListTablesTest(t, tableNameParam, tableNameAuth, AlloyDBPostgresUser) + tests.RunPostgresListViewsTest(t, ctx, pool, tableNameParam) tests.RunPostgresListSchemasTest(t, ctx, pool) + tests.RunPostgresListActiveQueriesTest(t, ctx, pool) + tests.RunPostgresListAvailableExtensionsTest(t) + tests.RunPostgresListInstalledExtensionsTest(t) } // Test connection with different IP type diff --git a/tests/cloudsqlpg/cloud_sql_pg_integration_test.go b/tests/cloudsqlpg/cloud_sql_pg_integration_test.go index 3551958ae9..9b1ab94b10 100644 --- a/tests/cloudsqlpg/cloud_sql_pg_integration_test.go +++ b/tests/cloudsqlpg/cloud_sql_pg_integration_test.go @@ -114,6 +114,9 @@ func TestCloudSQLPgSimpleToolEndpoints(t *testing.T) { t.Fatalf("unable to create Cloud SQL connection pool: %s", err) } + // cleanup test environment + tests.CleanupPostgresTables(t, ctx, pool) + // create table name with UUID tableNameParam := "param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "") tableNameAuth := "auth_table_" + strings.ReplaceAll(uuid.New().String(), "-", "") @@ -159,7 +162,14 @@ func TestCloudSQLPgSimpleToolEndpoints(t *testing.T) { tests.RunMCPToolCallMethod(t, mcpMyFailToolWant, mcpSelect1Want) tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want) tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam) + + // Run Postgres prebuilt tool tests + tests.RunPostgresListTablesTest(t, tableNameParam, tableNameAuth, CloudSQLPostgresUser) + tests.RunPostgresListViewsTest(t, ctx, pool, tableNameParam) tests.RunPostgresListSchemasTest(t, ctx, pool) + tests.RunPostgresListActiveQueriesTest(t, ctx, pool) + tests.RunPostgresListAvailableExtensionsTest(t) + tests.RunPostgresListInstalledExtensionsTest(t) } // Test connection with different IP type diff --git a/tests/common.go b/tests/common.go index 498218e5c0..01208aa192 100644 --- a/tests/common.go +++ b/tests/common.go @@ -33,10 +33,6 @@ import ( "github.com/jackc/pgx/v5/pgxpool" ) -var ( - PostgresListSchemasToolKind = "postgres-list-schemas" -) - // GetToolsConfig returns a mock tools config file func GetToolsConfig(sourceConfig map[string]any, toolKind, paramToolStatement, idParamToolStmt, nameParamToolStmt, arrayToolStatement, authToolStatement string) map[string]any { // Write config into a file and pass it to command @@ -195,10 +191,46 @@ func AddExecuteSqlConfig(t *testing.T, config map[string]any, toolKind string) m } func AddPostgresPrebuiltConfig(t *testing.T, config map[string]any) map[string]any { + var ( + PostgresListSchemasToolKind = "postgres-list-schemas" + PostgresListTablesToolKind = "postgres-list-tables" + PostgresListActiveQueriesToolKind = "postgres-list-active-queries" + PostgresListInstalledExtensionsToolKind = "postgres-list-installed-extensions" + PostgresListAvailableExtensionsToolKind = "postgres-list-available-extensions" + PostgresListViewsToolKind = "postgres-list-views" + ) + tools, ok := config["tools"].(map[string]any) if !ok { t.Fatalf("unable to get tools from config") } + tools["list_tables"] = map[string]any{ + "kind": PostgresListTablesToolKind, + "source": "my-instance", + "description": "Lists tables in the database.", + } + tools["list_active_queries"] = map[string]any{ + "kind": PostgresListActiveQueriesToolKind, + "source": "my-instance", + "description": "Lists active queries in the database.", + } + + tools["list_installed_extensions"] = map[string]any{ + "kind": PostgresListInstalledExtensionsToolKind, + "source": "my-instance", + "description": "Lists installed extensions in the database.", + } + + tools["list_available_extensions"] = map[string]any{ + "kind": PostgresListAvailableExtensionsToolKind, + "source": "my-instance", + "description": "Lists available extensions in the database.", + } + + tools["list_views"] = map[string]any{ + "kind": PostgresListViewsToolKind, + "source": "my-instance", + } tools["list_schemas"] = map[string]any{ "kind": PostgresListSchemasToolKind, "source": "my-instance", diff --git a/tests/postgres/postgres_integration_test.go b/tests/postgres/postgres_integration_test.go index 3c6f79308e..db7de9d27e 100644 --- a/tests/postgres/postgres_integration_test.go +++ b/tests/postgres/postgres_integration_test.go @@ -15,23 +15,15 @@ package postgres import ( - "bytes" "context" - "encoding/json" "fmt" - "io" - "net/http" "net/url" "os" - "reflect" "regexp" - "sort" "strings" - "sync" "testing" "time" - "github.com/google/go-cmp/cmp" "github.com/google/uuid" "github.com/googleapis/genai-toolbox/internal/testutils" "github.com/googleapis/genai-toolbox/tests" @@ -39,18 +31,13 @@ import ( ) var ( - PostgresSourceKind = "postgres" - PostgresToolKind = "postgres-sql" - PostgresListTablesToolKind = "postgres-list-tables" - PostgresListActiveQueriesToolKind = "postgres-list-active-queries" - PostgresListInstalledExtensionsToolKind = "postgres-list-installed-extensions" - PostgresListAvailableExtensionsToolKind = "postgres-list-available-extensions" - PostgresListViewsToolKind = "postgres-list-views" - PostgresDatabase = os.Getenv("POSTGRES_DATABASE") - PostgresHost = os.Getenv("POSTGRES_HOST") - PostgresPort = os.Getenv("POSTGRES_PORT") - PostgresUser = os.Getenv("POSTGRES_USER") - PostgresPass = os.Getenv("POSTGRES_PASS") + PostgresSourceKind = "postgres" + PostgresToolKind = "postgres-sql" + PostgresDatabase = os.Getenv("POSTGRES_DATABASE") + PostgresHost = os.Getenv("POSTGRES_HOST") + PostgresPort = os.Getenv("POSTGRES_PORT") + PostgresUser = os.Getenv("POSTGRES_USER") + PostgresPass = os.Getenv("POSTGRES_PASS") ) func getPostgresVars(t *testing.T) map[string]any { @@ -77,43 +64,6 @@ func getPostgresVars(t *testing.T) map[string]any { } } -func addPrebuiltToolConfig(t *testing.T, config map[string]any) map[string]any { - tools, ok := config["tools"].(map[string]any) - if !ok { - t.Fatalf("unable to get tools from config") - } - tools["list_tables"] = map[string]any{ - "kind": PostgresListTablesToolKind, - "source": "my-instance", - "description": "Lists tables in the database.", - } - tools["list_active_queries"] = map[string]any{ - "kind": PostgresListActiveQueriesToolKind, - "source": "my-instance", - "description": "Lists active queries in the database.", - } - - tools["list_installed_extensions"] = map[string]any{ - "kind": PostgresListInstalledExtensionsToolKind, - "source": "my-instance", - "description": "Lists installed extensions in the database.", - } - - tools["list_available_extensions"] = map[string]any{ - "kind": PostgresListAvailableExtensionsToolKind, - "source": "my-instance", - "description": "Lists available extensions in the database.", - } - - tools["list_views"] = map[string]any{ - "kind": PostgresListViewsToolKind, - "source": "my-instance", - } - - config["tools"] = tools - return config -} - // Copied over from postgres.go func initPostgresConnectionPool(host, port, user, pass, dbname string) (*pgxpool.Pool, error) { // urlExample := "postgres:dd//username:password@localhost:5432/database_name" @@ -166,8 +116,6 @@ func TestPostgres(t *testing.T) { toolsFile = tests.AddExecuteSqlConfig(t, toolsFile, "postgres-execute-sql") tmplSelectCombined, tmplSelectFilterCombined := tests.GetPostgresSQLTmplToolStatement() toolsFile = tests.AddTemplateParamConfig(t, toolsFile, PostgresToolKind, tmplSelectCombined, tmplSelectFilterCombined, "") - - toolsFile = addPrebuiltToolConfig(t, toolsFile) toolsFile = tests.AddPostgresPrebuiltConfig(t, toolsFile) cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...) @@ -194,506 +142,11 @@ func TestPostgres(t *testing.T) { tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want) tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam) - // Run specific Postgres tool tests - runPostgresListTablesTest(t, tableNameParam, tableNameAuth) - runPostgresListViewsTest(t, ctx, pool, tableNameParam) + // Run Postgres prebuilt tool tests + tests.RunPostgresListTablesTest(t, tableNameParam, tableNameAuth, PostgresUser) + tests.RunPostgresListViewsTest(t, ctx, pool, tableNameParam) tests.RunPostgresListSchemasTest(t, ctx, pool) - runPostgresListActiveQueriesTest(t, ctx, pool) - runPostgresListAvailableExtensionsTest(t) - runPostgresListInstalledExtensionsTest(t) -} - -func runPostgresListTablesTest(t *testing.T, tableNameParam, tableNameAuth string) { - // TableNameParam columns to construct want - paramTableColumns := fmt.Sprintf(`[ - {"data_type": "integer", "column_name": "id", "column_default": "nextval('%s_id_seq'::regclass)", "is_not_nullable": true, "ordinal_position": 1, "column_comment": null}, - {"data_type": "text", "column_name": "name", "column_default": null, "is_not_nullable": false, "ordinal_position": 2, "column_comment": null} - ]`, tableNameParam) - - // TableNameAuth columns to construct want - authTableColumns := fmt.Sprintf(`[ - {"data_type": "integer", "column_name": "id", "column_default": "nextval('%s_id_seq'::regclass)", "is_not_nullable": true, "ordinal_position": 1, "column_comment": null}, - {"data_type": "text", "column_name": "name", "column_default": null, "is_not_nullable": false, "ordinal_position": 2, "column_comment": null}, - {"data_type": "text", "column_name": "email", "column_default": null, "is_not_nullable": false, "ordinal_position": 3, "column_comment": null} - ]`, tableNameAuth) - - const ( - // Template to construct detailed output want - detailedObjectTemplate = `{ - "object_name": "%[1]s", "schema_name": "public", - "object_details": { - "owner": "%[3]s", "comment": null, - "indexes": [{"is_primary": true, "is_unique": true, "index_name": "%[1]s_pkey", "index_method": "btree", "index_columns": ["id"], "index_definition": "CREATE UNIQUE INDEX %[1]s_pkey ON public.%[1]s USING btree (id)"}], - "triggers": [], "columns": %[2]s, "object_name": "%[1]s", "object_type": "TABLE", "schema_name": "public", - "constraints": [{"constraint_name": "%[1]s_pkey", "constraint_type": "PRIMARY KEY", "constraint_columns": ["id"], "constraint_definition": "PRIMARY KEY (id)", "foreign_key_referenced_table": null, "foreign_key_referenced_columns": null}] - } - }` - - // Template to construct simple output want - simpleObjectTemplate = `{"object_name":"%s", "schema_name":"public", "object_details":{"name":"%s"}}` - ) - - // Helper to build json for detailed want - getDetailedWant := func(tableName, columnJSON string) string { - return fmt.Sprintf(detailedObjectTemplate, tableName, columnJSON, PostgresUser) - } - - // Helper to build template for simple want - getSimpleWant := func(tableName string) string { - return fmt.Sprintf(simpleObjectTemplate, tableName, tableName) - } - - invokeTcs := []struct { - name string - api string - requestBody io.Reader - wantStatusCode int - want string - isAllTables bool - }{ - { - name: "invoke list_tables all tables detailed output", - api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", - requestBody: bytes.NewBuffer([]byte(`{"table_names": ""}`)), - wantStatusCode: http.StatusOK, - want: fmt.Sprintf("[%s,%s]", getDetailedWant(tableNameAuth, authTableColumns), getDetailedWant(tableNameParam, paramTableColumns)), - isAllTables: true, - }, - { - name: "invoke list_tables all tables simple output", - api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", - requestBody: bytes.NewBuffer([]byte(`{"table_names": "", "output_format": "simple"}`)), - wantStatusCode: http.StatusOK, - want: fmt.Sprintf("[%s,%s]", getSimpleWant(tableNameAuth), getSimpleWant(tableNameParam)), - isAllTables: true, - }, - { - name: "invoke list_tables detailed output", - api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", - requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"table_names": "%s"}`, tableNameAuth))), - wantStatusCode: http.StatusOK, - want: fmt.Sprintf("[%s]", getDetailedWant(tableNameAuth, authTableColumns)), - }, - { - name: "invoke list_tables simple output", - api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", - requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"table_names": "%s", "output_format": "simple"}`, tableNameAuth))), - wantStatusCode: http.StatusOK, - want: fmt.Sprintf("[%s]", getSimpleWant(tableNameAuth)), - }, - { - name: "invoke list_tables with invalid output format", - api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", - requestBody: bytes.NewBuffer([]byte(`{"table_names": "", "output_format": "abcd"}`)), - wantStatusCode: http.StatusBadRequest, - }, - { - name: "invoke list_tables with malformed table_names parameter", - api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", - requestBody: bytes.NewBuffer([]byte(`{"table_names": 12345, "output_format": "detailed"}`)), - wantStatusCode: http.StatusBadRequest, - }, - { - name: "invoke list_tables with multiple table names", - api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", - requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"table_names": "%s,%s"}`, tableNameParam, tableNameAuth))), - wantStatusCode: http.StatusOK, - want: fmt.Sprintf("[%s,%s]", getDetailedWant(tableNameAuth, authTableColumns), getDetailedWant(tableNameParam, paramTableColumns)), - }, - { - name: "invoke list_tables with non-existent table", - api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", - requestBody: bytes.NewBuffer([]byte(`{"table_names": "non_existent_table"}`)), - wantStatusCode: http.StatusOK, - want: `null`, - }, - { - name: "invoke list_tables with one existing and one non-existent table", - api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", - requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"table_names": "%s,non_existent_table"}`, tableNameParam))), - wantStatusCode: http.StatusOK, - want: fmt.Sprintf("[%s]", getDetailedWant(tableNameParam, paramTableColumns)), - }, - } - for _, tc := range invokeTcs { - t.Run(tc.name, func(t *testing.T) { - req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody) - if err != nil { - t.Fatalf("unable to create request: %s", err) - } - req.Header.Add("Content-type", "application/json") - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatalf("unable to send request: %s", err) - } - defer resp.Body.Close() - - if resp.StatusCode != tc.wantStatusCode { - bodyBytes, _ := io.ReadAll(resp.Body) - t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes)) - } - - if tc.wantStatusCode == http.StatusOK { - var bodyWrapper map[string]json.RawMessage - respBytes, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("error reading response body: %s", err) - } - - if err := json.Unmarshal(respBytes, &bodyWrapper); err != nil { - t.Fatalf("error parsing response wrapper: %s, body: %s", err, string(respBytes)) - } - - resultJSON, ok := bodyWrapper["result"] - if !ok { - t.Fatal("unable to find 'result' in response body") - } - - var resultString string - if err := json.Unmarshal(resultJSON, &resultString); err != nil { - t.Fatalf("'result' is not a JSON-encoded string: %s", err) - } - - var got, want []any - - if err := json.Unmarshal([]byte(resultString), &got); err != nil { - t.Fatalf("failed to unmarshal actual result string: %v", err) - } - if err := json.Unmarshal([]byte(tc.want), &want); err != nil { - t.Fatalf("failed to unmarshal expected want string: %v", err) - } - - // Checking only the default public schema where the test tables are created to avoid brittle tests. - if tc.isAllTables { - var filteredGot []any - for _, item := range got { - if tableMap, ok := item.(map[string]interface{}); ok { - if schema, ok := tableMap["schema_name"]; ok && schema == "public" { - filteredGot = append(filteredGot, item) - } - } - } - got = filteredGot - } - - sort.SliceStable(got, func(i, j int) bool { - return fmt.Sprintf("%v", got[i]) < fmt.Sprintf("%v", got[j]) - }) - sort.SliceStable(want, func(i, j int) bool { - return fmt.Sprintf("%v", want[i]) < fmt.Sprintf("%v", want[j]) - }) - - if !reflect.DeepEqual(got, want) { - t.Errorf("Unexpected result: got %#v, want: %#v", got, want) - } - } - }) - } -} - -func runPostgresListActiveQueriesTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) { - type queryListDetails struct { - ProcessId any `json:"pid"` - User string `json:"user"` - Datname string `json:"datname"` - ApplicationName string `json:"application_name"` - ClientAddress string `json:"client_addr"` - State string `json:"state"` - WaitEventType string `json:"wait_event_type"` - WaitEvent string `json:"wait_event"` - BackendStart any `json:"backend_start"` - TransactionStart any `json:"xact_start"` - QueryStart any `json:"query_start"` - QueryDuration any `json:"query_duration"` - Query string `json:"query"` - } - - singleQueryWanted := queryListDetails{ - ProcessId: any(nil), - User: "", - Datname: "", - ApplicationName: "", - ClientAddress: "", - State: "", - WaitEventType: "", - WaitEvent: "", - BackendStart: any(nil), - TransactionStart: any(nil), - QueryStart: any(nil), - QueryDuration: any(nil), - Query: "SELECT pg_sleep(10);", - } - - invokeTcs := []struct { - name string - requestBody io.Reader - clientSleepSecs int - waitSecsBeforeCheck int - wantStatusCode int - want any - }{ - // exclude background monitoring apps such as "wal_uploader" - { - name: "invoke list_active_queries when the system is idle", - requestBody: bytes.NewBufferString(`{"exclude_application_names": "wal_uploader"}`), - clientSleepSecs: 0, - waitSecsBeforeCheck: 0, - wantStatusCode: http.StatusOK, - want: []queryListDetails(nil), - }, - { - name: "invoke list_active_queries when there is 1 ongoing but lower than the threshold", - requestBody: bytes.NewBufferString(`{"min_duration": "100 seconds", "exclude_application_names": "wal_uploader"}`), - clientSleepSecs: 1, - waitSecsBeforeCheck: 1, - wantStatusCode: http.StatusOK, - want: []queryListDetails(nil), - }, - { - name: "invoke list_active_queries when 1 ongoing query should show up", - requestBody: bytes.NewBufferString(`{"min_duration": "1 seconds", "exclude_application_names": "wal_uploader"}`), - clientSleepSecs: 10, - waitSecsBeforeCheck: 5, - wantStatusCode: http.StatusOK, - want: []queryListDetails{singleQueryWanted}, - }, - } - - var wg sync.WaitGroup - for _, tc := range invokeTcs { - t.Run(tc.name, func(t *testing.T) { - if tc.clientSleepSecs > 0 { - wg.Add(1) - - go func() { - defer wg.Done() - - err := pool.Ping(ctx) - if err != nil { - t.Errorf("unable to connect to test database: %s", err) - return - } - _, err = pool.Exec(ctx, fmt.Sprintf("SELECT pg_sleep(%d);", tc.clientSleepSecs)) - if err != nil { - t.Errorf("Executing 'SELECT pg_sleep' failed: %s", err) - } - }() - } - - if tc.waitSecsBeforeCheck > 0 { - time.Sleep(time.Duration(tc.waitSecsBeforeCheck) * time.Second) - } - - const api = "http://127.0.0.1:5000/api/tool/list_active_queries/invoke" - req, err := http.NewRequest(http.MethodPost, api, tc.requestBody) - if err != nil { - t.Fatalf("unable to create request: %v", err) - } - req.Header.Add("Content-type", "application/json") - - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatalf("unable to send request: %v", err) - } - defer resp.Body.Close() - - if resp.StatusCode != tc.wantStatusCode { - body, _ := io.ReadAll(resp.Body) - t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(body)) - } - if tc.wantStatusCode != http.StatusOK { - return - } - - var bodyWrapper struct { - Result json.RawMessage `json:"result"` - } - if err := json.NewDecoder(resp.Body).Decode(&bodyWrapper); err != nil { - t.Fatalf("error decoding response wrapper: %v", err) - } - - var resultString string - if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil { - resultString = string(bodyWrapper.Result) - } - - var got any - var details []queryListDetails - if err := json.Unmarshal([]byte(resultString), &details); err != nil { - t.Fatalf("failed to unmarshal nested ObjectDetails string: %v", err) - } - got = details - - if diff := cmp.Diff(tc.want, got, cmp.Comparer(func(a, b queryListDetails) bool { - return a.Query == b.Query - })); diff != "" { - t.Errorf("Unexpected result: got %#v, want: %#v", got, tc.want) - } - }) - } - wg.Wait() -} - -func setUpPostgresViews(t *testing.T, ctx context.Context, pool *pgxpool.Pool, viewName, tableName string) func() { - createView := fmt.Sprintf("CREATE VIEW %s AS SELECT name FROM %s", viewName, tableName) - _, err := pool.Exec(ctx, createView) - if err != nil { - t.Fatalf("failed to create view: %v", err) - } - return func() { - dropView := fmt.Sprintf("DROP VIEW %s", viewName) - _, err := pool.Exec(ctx, dropView) - if err != nil { - t.Fatalf("failed to drop view: %v", err) - } - } -} - -func runPostgresListViewsTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool, tableName string) { - viewName1 := "test_view_1" + strings.ReplaceAll(uuid.New().String(), "-", "") - dropViewfunc1 := setUpPostgresViews(t, ctx, pool, viewName1, tableName) - defer dropViewfunc1() - - invokeTcs := []struct { - name string - requestBody io.Reader - wantStatusCode int - want string - }{ - { - name: "invoke list_views with newly created view", - requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"viewname": "%s"}`, viewName1))), - wantStatusCode: http.StatusOK, - want: fmt.Sprintf(`[{"schemaname":"public","viewname":"%s","viewowner":"postgres"}]`, viewName1), - }, - { - name: "invoke list_views with non-existent_view", - requestBody: bytes.NewBuffer([]byte(`{"viewname": "non_existent_view"}`)), - wantStatusCode: http.StatusOK, - want: `null`, - }, - } - for _, tc := range invokeTcs { - t.Run(tc.name, func(t *testing.T) { - const api = "http://127.0.0.1:5000/api/tool/list_views/invoke" - req, err := http.NewRequest(http.MethodPost, api, tc.requestBody) - if err != nil { - t.Fatalf("unable to create request: %v", err) - } - req.Header.Add("Content-type", "application/json") - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatalf("unable to send request: %v", err) - } - defer resp.Body.Close() - - if resp.StatusCode != tc.wantStatusCode { - body, _ := io.ReadAll(resp.Body) - t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(body)) - } - if tc.wantStatusCode != http.StatusOK { - return - } - - var bodyWrapper struct { - Result json.RawMessage `json:"result"` - } - if err := json.NewDecoder(resp.Body).Decode(&bodyWrapper); err != nil { - t.Fatalf("error decoding response wrapper: %v", err) - } - - var resultString string - if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil { - resultString = string(bodyWrapper.Result) - } - - var got, want any - if err := json.Unmarshal([]byte(resultString), &got); err != nil { - t.Fatalf("failed to unmarshal nested result string: %v", err) - } - if err := json.Unmarshal([]byte(tc.want), &want); err != nil { - t.Fatalf("failed to unmarshal want string: %v", err) - } - - if diff := cmp.Diff(want, got); diff != "" { - t.Errorf("Unexpected result (-want +got):\n%s", diff) - } - }) - } -} - -func runPostgresListAvailableExtensionsTest(t *testing.T) { - invokeTcs := []struct { - name string - api string - requestBody io.Reader - wantStatusCode int - }{ - { - name: "invoke list_available_extensions output", - api: "http://127.0.0.1:5000/api/tool/list_available_extensions/invoke", - wantStatusCode: http.StatusOK, - requestBody: bytes.NewBuffer([]byte(`{}`)), - }, - } - for _, tc := range invokeTcs { - t.Run(tc.name, func(t *testing.T) { - req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody) - if err != nil { - t.Fatalf("unable to create request: %s", err) - } - req.Header.Add("Content-type", "application/json") - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatalf("unable to send request: %s", err) - } - defer resp.Body.Close() - - if resp.StatusCode != tc.wantStatusCode { - bodyBytes, _ := io.ReadAll(resp.Body) - t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes)) - } - - // Intentionally not adding the output check as output depends on the postgres instance used where the the functional test runs. - // Adding the check will make the test flaky. - }) - } -} - -func runPostgresListInstalledExtensionsTest(t *testing.T) { - invokeTcs := []struct { - name string - api string - requestBody io.Reader - wantStatusCode int - }{ - { - name: "invoke list_installed_extensions output", - api: "http://127.0.0.1:5000/api/tool/list_installed_extensions/invoke", - wantStatusCode: http.StatusOK, - requestBody: bytes.NewBuffer([]byte(`{}`)), - }, - } - for _, tc := range invokeTcs { - t.Run(tc.name, func(t *testing.T) { - req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody) - if err != nil { - t.Fatalf("unable to create request: %s", err) - } - req.Header.Add("Content-type", "application/json") - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatalf("unable to send request: %s", err) - } - defer resp.Body.Close() - - if resp.StatusCode != tc.wantStatusCode { - bodyBytes, _ := io.ReadAll(resp.Body) - t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes)) - } - - // Intentionally not adding the output check as output depends on the postgres instance used where the the functional test runs. - // Adding the check will make the test flaky. - }) - } + tests.RunPostgresListActiveQueriesTest(t, ctx, pool) + tests.RunPostgresListAvailableExtensionsTest(t) + tests.RunPostgresListInstalledExtensionsTest(t) } diff --git a/tests/tool.go b/tests/tool.go index d8d0d42e82..0511fd5208 100644 --- a/tests/tool.go +++ b/tests/tool.go @@ -149,31 +149,17 @@ func RunToolInvokeSimpleTest(t *testing.T, name string, simpleWant string) { for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { // Send Tool invocation request - req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody) - if err != nil { - t.Fatalf("unable to create request: %s", err) - } - req.Header.Add("Content-type", "application/json") - for k, v := range tc.requestHeader { - req.Header.Add(k, v) - } - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatalf("unable to send request: %s", err) - } - defer resp.Body.Close() - + resp, respBody := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, tc.requestHeader) if resp.StatusCode != http.StatusOK { if tc.isErr { return } - bodyBytes, _ := io.ReadAll(resp.Body) - t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes)) + t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody)) } // Check response body var body map[string]interface{} - err = json.NewDecoder(resp.Body).Decode(&body) + err := json.Unmarshal(respBody, &body) if err != nil { t.Fatalf("error parsing response body") } @@ -212,31 +198,17 @@ func RunToolInvokeParametersTest(t *testing.T, name string, params []byte, simpl for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { // Send Tool invocation request - req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody) - if err != nil { - t.Fatalf("unable to create request: %s", err) - } - req.Header.Add("Content-type", "application/json") - for k, v := range tc.requestHeader { - req.Header.Add(k, v) - } - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatalf("unable to send request: %s", err) - } - defer resp.Body.Close() - + resp, respBody := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, tc.requestHeader) if resp.StatusCode != http.StatusOK { if tc.isErr { return } - bodyBytes, _ := io.ReadAll(resp.Body) - t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes)) + t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody)) } // Check response body var body map[string]interface{} - err = json.NewDecoder(resp.Body).Decode(&body) + err := json.Unmarshal(respBody, &body) if err != nil { t.Fatalf("error parsing response body") } @@ -447,25 +419,11 @@ func RunToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTestOp return } // Send Tool invocation request - req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody) - if err != nil { - t.Fatalf("unable to create request: %s", err) - } - req.Header.Add("Content-type", "application/json") - // Add headers - for k, v := range tc.requestHeader { - req.Header.Add(k, v) - } - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatalf("unable to send request: %s", err) - } - defer resp.Body.Close() + resp, respBody := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, tc.requestHeader) // Check status code if resp.StatusCode != tc.wantStatusCode { - body, _ := io.ReadAll(resp.Body) - t.Errorf("StatusCode mismatch: got %d, want %d. Response body: %s", resp.StatusCode, tc.wantStatusCode, string(body)) + t.Errorf("StatusCode mismatch: got %d, want %d. Response body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody)) } // skip response body check @@ -475,7 +433,7 @@ func RunToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTestOp // Check response body var body map[string]interface{} - err = json.NewDecoder(resp.Body).Decode(&body) + err = json.Unmarshal(respBody, &body) if err != nil { t.Fatalf("error parsing response body: %s", err) } @@ -620,32 +578,17 @@ func RunToolInvokeWithTemplateParameters(t *testing.T, tableName string, options insertAllow := !tc.insert || (tc.insert && configs.supportInsert) if ddlAllow && insertAllow { // Send Tool invocation request - req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody) - if err != nil { - t.Fatalf("unable to create request: %s", err) - } - req.Header.Add("Content-type", "application/json") - for k, v := range tc.requestHeader { - req.Header.Add(k, v) - } - - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatalf("unable to send request: %s", err) - } - defer resp.Body.Close() - + resp, respBody := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, tc.requestHeader) if resp.StatusCode != http.StatusOK { if tc.isErr { return } - bodyBytes, _ := io.ReadAll(resp.Body) - t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes)) + t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody)) } // Check response body var body map[string]interface{} - err = json.NewDecoder(resp.Body).Decode(&body) + err := json.Unmarshal(respBody, &body) if err != nil { t.Fatalf("error parsing response body") } @@ -769,31 +712,17 @@ func RunExecuteSqlToolInvokeTest(t *testing.T, createTableStatement, select1Want for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { // Send Tool invocation request - req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody) - if err != nil { - t.Fatalf("unable to create request: %s", err) - } - req.Header.Add("Content-type", "application/json") - for k, v := range tc.requestHeader { - req.Header.Add(k, v) - } - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatalf("unable to send request: %s", err) - } - defer resp.Body.Close() - + resp, respBody := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, tc.requestHeader) if resp.StatusCode != http.StatusOK { if tc.isErr { return } - bodyBytes, _ := io.ReadAll(resp.Body) - t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes)) + t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody)) } // Check response body var body map[string]interface{} - err = json.NewDecoder(resp.Body).Decode(&body) + err = json.Unmarshal(respBody, &body) if err != nil { t.Fatalf("error parsing response body") } @@ -1157,6 +1086,257 @@ func setupPostgresSchemas(t *testing.T, ctx context.Context, pool *pgxpool.Pool, } } +func RunPostgresListTablesTest(t *testing.T, tableNameParam, tableNameAuth, user string) { + // TableNameParam columns to construct want + paramTableColumns := fmt.Sprintf(`[ + {"data_type": "integer", "column_name": "id", "column_default": "nextval('%s_id_seq'::regclass)", "is_not_nullable": true, "ordinal_position": 1, "column_comment": null}, + {"data_type": "text", "column_name": "name", "column_default": null, "is_not_nullable": false, "ordinal_position": 2, "column_comment": null} + ]`, tableNameParam) + + // TableNameAuth columns to construct want + authTableColumns := fmt.Sprintf(`[ + {"data_type": "integer", "column_name": "id", "column_default": "nextval('%s_id_seq'::regclass)", "is_not_nullable": true, "ordinal_position": 1, "column_comment": null}, + {"data_type": "text", "column_name": "name", "column_default": null, "is_not_nullable": false, "ordinal_position": 2, "column_comment": null}, + {"data_type": "text", "column_name": "email", "column_default": null, "is_not_nullable": false, "ordinal_position": 3, "column_comment": null} + ]`, tableNameAuth) + + const ( + // Template to construct detailed output want + detailedObjectTemplate = `{ + "object_name": "%[1]s", "schema_name": "public", + "object_details": { + "owner": "%[3]s", "comment": null, + "indexes": [{"is_primary": true, "is_unique": true, "index_name": "%[1]s_pkey", "index_method": "btree", "index_columns": ["id"], "index_definition": "CREATE UNIQUE INDEX %[1]s_pkey ON public.%[1]s USING btree (id)"}], + "triggers": [], "columns": %[2]s, "object_name": "%[1]s", "object_type": "TABLE", "schema_name": "public", + "constraints": [{"constraint_name": "%[1]s_pkey", "constraint_type": "PRIMARY KEY", "constraint_columns": ["id"], "constraint_definition": "PRIMARY KEY (id)", "foreign_key_referenced_table": null, "foreign_key_referenced_columns": null}] + } + }` + + // Template to construct simple output want + simpleObjectTemplate = `{"object_name":"%s", "schema_name":"public", "object_details":{"name":"%s"}}` + ) + + // Helper to build json for detailed want + getDetailedWant := func(tableName, columnJSON string) string { + return fmt.Sprintf(detailedObjectTemplate, tableName, columnJSON, user) + } + + // Helper to build template for simple want + getSimpleWant := func(tableName string) string { + return fmt.Sprintf(simpleObjectTemplate, tableName, tableName) + } + + invokeTcs := []struct { + name string + api string + requestBody io.Reader + wantStatusCode int + want string + isAllTables bool + }{ + { + name: "invoke list_tables all tables detailed output", + api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", + requestBody: bytes.NewBuffer([]byte(`{"table_names": ""}`)), + wantStatusCode: http.StatusOK, + want: fmt.Sprintf("[%s,%s]", getDetailedWant(tableNameAuth, authTableColumns), getDetailedWant(tableNameParam, paramTableColumns)), + isAllTables: true, + }, + { + name: "invoke list_tables all tables simple output", + api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", + requestBody: bytes.NewBuffer([]byte(`{"table_names": "", "output_format": "simple"}`)), + wantStatusCode: http.StatusOK, + want: fmt.Sprintf("[%s,%s]", getSimpleWant(tableNameAuth), getSimpleWant(tableNameParam)), + isAllTables: true, + }, + { + name: "invoke list_tables detailed output", + api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", + requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"table_names": "%s"}`, tableNameAuth))), + wantStatusCode: http.StatusOK, + want: fmt.Sprintf("[%s]", getDetailedWant(tableNameAuth, authTableColumns)), + }, + { + name: "invoke list_tables simple output", + api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", + requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"table_names": "%s", "output_format": "simple"}`, tableNameAuth))), + wantStatusCode: http.StatusOK, + want: fmt.Sprintf("[%s]", getSimpleWant(tableNameAuth)), + }, + { + name: "invoke list_tables with invalid output format", + api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", + requestBody: bytes.NewBuffer([]byte(`{"table_names": "", "output_format": "abcd"}`)), + wantStatusCode: http.StatusBadRequest, + }, + { + name: "invoke list_tables with malformed table_names parameter", + api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", + requestBody: bytes.NewBuffer([]byte(`{"table_names": 12345, "output_format": "detailed"}`)), + wantStatusCode: http.StatusBadRequest, + }, + { + name: "invoke list_tables with multiple table names", + api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", + requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"table_names": "%s,%s"}`, tableNameParam, tableNameAuth))), + wantStatusCode: http.StatusOK, + want: fmt.Sprintf("[%s,%s]", getDetailedWant(tableNameAuth, authTableColumns), getDetailedWant(tableNameParam, paramTableColumns)), + }, + { + name: "invoke list_tables with non-existent table", + api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", + requestBody: bytes.NewBuffer([]byte(`{"table_names": "non_existent_table"}`)), + wantStatusCode: http.StatusOK, + want: `null`, + }, + { + name: "invoke list_tables with one existing and one non-existent table", + api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", + requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"table_names": "%s,non_existent_table"}`, tableNameParam))), + wantStatusCode: http.StatusOK, + want: fmt.Sprintf("[%s]", getDetailedWant(tableNameParam, paramTableColumns)), + }, + } + for _, tc := range invokeTcs { + t.Run(tc.name, func(t *testing.T) { + resp, respBytes := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, nil) + if resp.StatusCode != tc.wantStatusCode { + t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBytes)) + } + + if tc.wantStatusCode == http.StatusOK { + var bodyWrapper map[string]json.RawMessage + + if err := json.Unmarshal(respBytes, &bodyWrapper); err != nil { + t.Fatalf("error parsing response wrapper: %s, body: %s", err, string(respBytes)) + } + + resultJSON, ok := bodyWrapper["result"] + if !ok { + t.Fatal("unable to find 'result' in response body") + } + + var resultString string + if err := json.Unmarshal(resultJSON, &resultString); err != nil { + t.Fatalf("'result' is not a JSON-encoded string: %s", err) + } + + var got, want []any + + if err := json.Unmarshal([]byte(resultString), &got); err != nil { + t.Fatalf("failed to unmarshal actual result string: %v", err) + } + if err := json.Unmarshal([]byte(tc.want), &want); err != nil { + t.Fatalf("failed to unmarshal expected want string: %v", err) + } + + // Checking only the default public schema where the test tables are created to avoid brittle tests. + if tc.isAllTables { + var filteredGot []any + for _, item := range got { + if tableMap, ok := item.(map[string]interface{}); ok { + if schema, ok := tableMap["schema_name"]; ok && schema == "public" { + filteredGot = append(filteredGot, item) + } + } + } + got = filteredGot + } + + sort.SliceStable(got, func(i, j int) bool { + return fmt.Sprintf("%v", got[i]) < fmt.Sprintf("%v", got[j]) + }) + sort.SliceStable(want, func(i, j int) bool { + return fmt.Sprintf("%v", want[i]) < fmt.Sprintf("%v", want[j]) + }) + + if !reflect.DeepEqual(got, want) { + t.Errorf("Unexpected result: got %#v, want: %#v", got, want) + } + } + }) + } +} + +func setUpPostgresViews(t *testing.T, ctx context.Context, pool *pgxpool.Pool, viewName, tableName string) func() { + createView := fmt.Sprintf("CREATE VIEW %s AS SELECT name FROM %s", viewName, tableName) + _, err := pool.Exec(ctx, createView) + if err != nil { + t.Fatalf("failed to create view: %v", err) + } + return func() { + dropView := fmt.Sprintf("DROP VIEW %s", viewName) + _, err := pool.Exec(ctx, dropView) + if err != nil { + t.Fatalf("failed to drop view: %v", err) + } + } +} + +func RunPostgresListViewsTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool, tableName string) { + viewName1 := "test_view_1" + strings.ReplaceAll(uuid.New().String(), "-", "") + dropViewfunc1 := setUpPostgresViews(t, ctx, pool, viewName1, tableName) + defer dropViewfunc1() + + invokeTcs := []struct { + name string + requestBody io.Reader + wantStatusCode int + want string + }{ + { + name: "invoke list_views with newly created view", + requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"viewname": "%s"}`, viewName1))), + wantStatusCode: http.StatusOK, + want: fmt.Sprintf(`[{"schemaname":"public","viewname":"%s","viewowner":"postgres"}]`, viewName1), + }, + { + name: "invoke list_views with non-existent_view", + requestBody: bytes.NewBuffer([]byte(`{"viewname": "non_existent_view"}`)), + wantStatusCode: http.StatusOK, + want: `null`, + }, + } + for _, tc := range invokeTcs { + t.Run(tc.name, func(t *testing.T) { + const api = "http://127.0.0.1:5000/api/tool/list_views/invoke" + resp, body := RunRequest(t, http.MethodPost, api, tc.requestBody, nil) + + if resp.StatusCode != tc.wantStatusCode { + t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(body)) + } + if tc.wantStatusCode != http.StatusOK { + return + } + + var bodyWrapper struct { + Result json.RawMessage `json:"result"` + } + if err := json.Unmarshal(body, &bodyWrapper); err != nil { + t.Fatalf("error decoding response wrapper: %v", err) + } + + var resultString string + if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil { + resultString = string(bodyWrapper.Result) + } + + var got, want any + if err := json.Unmarshal([]byte(resultString), &got); err != nil { + t.Fatalf("failed to unmarshal nested result string: %v", err) + } + if err := json.Unmarshal([]byte(tc.want), &want); err != nil { + t.Fatalf("failed to unmarshal want string: %v", err) + } + + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("Unexpected result (-want +got):\n%s", diff) + } + }) + } +} + func RunPostgresListSchemasTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) { schemaName := "test_schema_" + strings.ReplaceAll(uuid.New().String(), "-", "") cleanup := setupPostgresSchemas(t, ctx, pool, schemaName) @@ -1186,20 +1366,9 @@ func RunPostgresListSchemasTest(t *testing.T, ctx context.Context, pool *pgxpool for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { const api = "http://127.0.0.1:5000/api/tool/list_schemas/invoke" - req, err := http.NewRequest(http.MethodPost, api, tc.requestBody) - if err != nil { - t.Fatalf("unable to create request: %v", err) - } - req.Header.Add("Content-type", "application/json") - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatalf("unable to send request: %v", err) - } - defer resp.Body.Close() - + resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil) if resp.StatusCode != tc.wantStatusCode { - body, _ := io.ReadAll(resp.Body) - t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(body)) + t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody)) } if tc.wantStatusCode != http.StatusOK { return @@ -1208,7 +1377,7 @@ func RunPostgresListSchemasTest(t *testing.T, ctx context.Context, pool *pgxpool var bodyWrapper struct { Result json.RawMessage `json:"result"` } - if err := json.NewDecoder(resp.Body).Decode(&bodyWrapper); err != nil { + if err := json.Unmarshal(respBody, &bodyWrapper); err != nil { t.Fatalf("error decoding response wrapper: %v", err) } @@ -1229,6 +1398,191 @@ func RunPostgresListSchemasTest(t *testing.T, ctx context.Context, pool *pgxpool } } +func RunPostgresListActiveQueriesTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) { + type queryListDetails struct { + ProcessId any `json:"pid"` + User string `json:"user"` + Datname string `json:"datname"` + ApplicationName string `json:"application_name"` + ClientAddress string `json:"client_addr"` + State string `json:"state"` + WaitEventType string `json:"wait_event_type"` + WaitEvent string `json:"wait_event"` + BackendStart any `json:"backend_start"` + TransactionStart any `json:"xact_start"` + QueryStart any `json:"query_start"` + QueryDuration any `json:"query_duration"` + Query string `json:"query"` + } + + singleQueryWanted := queryListDetails{ + ProcessId: any(nil), + User: "", + Datname: "", + ApplicationName: "", + ClientAddress: "", + State: "", + WaitEventType: "", + WaitEvent: "", + BackendStart: any(nil), + TransactionStart: any(nil), + QueryStart: any(nil), + QueryDuration: any(nil), + Query: "SELECT pg_sleep(10);", + } + + invokeTcs := []struct { + name string + requestBody io.Reader + clientSleepSecs int + waitSecsBeforeCheck int + wantStatusCode int + want any + }{ + // exclude background monitoring apps such as "wal_uploader" + { + name: "invoke list_active_queries when the system is idle", + requestBody: bytes.NewBufferString(`{"exclude_application_names": "wal_uploader"}`), + clientSleepSecs: 0, + waitSecsBeforeCheck: 0, + wantStatusCode: http.StatusOK, + want: []queryListDetails(nil), + }, + { + name: "invoke list_active_queries when there is 1 ongoing but lower than the threshold", + requestBody: bytes.NewBufferString(`{"min_duration": "100 seconds", "exclude_application_names": "wal_uploader"}`), + clientSleepSecs: 1, + waitSecsBeforeCheck: 1, + wantStatusCode: http.StatusOK, + want: []queryListDetails(nil), + }, + { + name: "invoke list_active_queries when 1 ongoing query should show up", + requestBody: bytes.NewBufferString(`{"min_duration": "1 seconds", "exclude_application_names": "wal_uploader"}`), + clientSleepSecs: 10, + waitSecsBeforeCheck: 5, + wantStatusCode: http.StatusOK, + want: []queryListDetails{singleQueryWanted}, + }, + } + + var wg sync.WaitGroup + for _, tc := range invokeTcs { + t.Run(tc.name, func(t *testing.T) { + if tc.clientSleepSecs > 0 { + wg.Add(1) + + go func() { + defer wg.Done() + + err := pool.Ping(ctx) + if err != nil { + t.Errorf("unable to connect to test database: %s", err) + return + } + _, err = pool.Exec(ctx, fmt.Sprintf("SELECT pg_sleep(%d);", tc.clientSleepSecs)) + if err != nil { + t.Errorf("Executing 'SELECT pg_sleep' failed: %s", err) + } + }() + } + + if tc.waitSecsBeforeCheck > 0 { + time.Sleep(time.Duration(tc.waitSecsBeforeCheck) * time.Second) + } + + const api = "http://127.0.0.1:5000/api/tool/list_active_queries/invoke" + resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil) + if resp.StatusCode != tc.wantStatusCode { + t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody)) + } + if tc.wantStatusCode != http.StatusOK { + return + } + + var bodyWrapper struct { + Result json.RawMessage `json:"result"` + } + if err := json.Unmarshal(respBody, &bodyWrapper); err != nil { + t.Fatalf("error decoding response wrapper: %v", err) + } + + var resultString string + if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil { + resultString = string(bodyWrapper.Result) + } + + var got any + var details []queryListDetails + if err := json.Unmarshal([]byte(resultString), &details); err != nil { + t.Fatalf("failed to unmarshal nested ObjectDetails string: %v", err) + } + got = details + + if diff := cmp.Diff(tc.want, got, cmp.Comparer(func(a, b queryListDetails) bool { + return a.Query == b.Query + })); diff != "" { + t.Errorf("Unexpected result: got %#v, want: %#v", got, tc.want) + } + }) + } + wg.Wait() +} + +func RunPostgresListAvailableExtensionsTest(t *testing.T) { + invokeTcs := []struct { + name string + api string + requestBody io.Reader + wantStatusCode int + }{ + { + name: "invoke list_available_extensions output", + api: "http://127.0.0.1:5000/api/tool/list_available_extensions/invoke", + wantStatusCode: http.StatusOK, + requestBody: bytes.NewBuffer([]byte(`{}`)), + }, + } + for _, tc := range invokeTcs { + t.Run(tc.name, func(t *testing.T) { + resp, respBody := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, nil) + if resp.StatusCode != tc.wantStatusCode { + t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody)) + } + + // Intentionally not adding the output check as output depends on the postgres instance used where the the functional test runs. + // Adding the check will make the test flaky. + }) + } +} + +func RunPostgresListInstalledExtensionsTest(t *testing.T) { + invokeTcs := []struct { + name string + api string + requestBody io.Reader + wantStatusCode int + }{ + { + name: "invoke list_installed_extensions output", + api: "http://127.0.0.1:5000/api/tool/list_installed_extensions/invoke", + wantStatusCode: http.StatusOK, + requestBody: bytes.NewBuffer([]byte(`{}`)), + }, + } + for _, tc := range invokeTcs { + t.Run(tc.name, func(t *testing.T) { + resp, bodyBytes := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, nil) + if resp.StatusCode != tc.wantStatusCode { + t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes)) + } + + // Intentionally not adding the output check as output depends on the postgres instance used where the the functional test runs. + // Adding the check will make the test flaky. + }) + } +} + // RunMySQLListTablesTest run tests against the mysql-list-tables tool func RunMySQLListTablesTest(t *testing.T, databaseName, tableNameParam, tableNameAuth string) { type tableInfo struct { @@ -1335,20 +1689,8 @@ func RunMySQLListTablesTest(t *testing.T, databaseName, tableNameParam, tableNam for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { const api = "http://127.0.0.1:5000/api/tool/list_tables/invoke" - req, err := http.NewRequest(http.MethodPost, api, tc.requestBody) - if err != nil { - t.Fatalf("unable to create request: %v", err) - } - req.Header.Add("Content-type", "application/json") - - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatalf("unable to send request: %v", err) - } - defer resp.Body.Close() - + resp, body := RunRequest(t, http.MethodPost, api, tc.requestBody, nil) if resp.StatusCode != tc.wantStatusCode { - body, _ := io.ReadAll(resp.Body) t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(body)) } if tc.wantStatusCode != http.StatusOK { @@ -1358,7 +1700,7 @@ func RunMySQLListTablesTest(t *testing.T, databaseName, tableNameParam, tableNam var bodyWrapper struct { Result json.RawMessage `json:"result"` } - if err := json.NewDecoder(resp.Body).Decode(&bodyWrapper); err != nil { + if err := json.Unmarshal(body, &bodyWrapper); err != nil { t.Fatalf("error decoding response wrapper: %v", err) } @@ -1532,21 +1874,9 @@ func RunMySQLListActiveQueriesTest(t *testing.T, ctx context.Context, pool *sql. } const api = "http://127.0.0.1:5000/api/tool/list_active_queries/invoke" - req, err := http.NewRequest(http.MethodPost, api, tc.requestBody) - if err != nil { - t.Fatalf("unable to create request: %v", err) - } - req.Header.Add("Content-type", "application/json") - - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatalf("unable to send request: %v", err) - } - defer resp.Body.Close() - + resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil) if resp.StatusCode != tc.wantStatusCode { - body, _ := io.ReadAll(resp.Body) - t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(body)) + t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody)) } if tc.wantStatusCode != http.StatusOK { return @@ -1555,7 +1885,7 @@ func RunMySQLListActiveQueriesTest(t *testing.T, ctx context.Context, pool *sql. var bodyWrapper struct { Result json.RawMessage `json:"result"` } - if err := json.NewDecoder(resp.Body).Decode(&bodyWrapper); err != nil { + if err := json.Unmarshal(respBody, &bodyWrapper); err != nil { t.Fatalf("error decoding response wrapper: %v", err) } @@ -1765,21 +2095,9 @@ func RunMySQLListTablesMissingUniqueIndexes(t *testing.T, ctx context.Context, p } const api = "http://127.0.0.1:5000/api/tool/list_tables_missing_unique_indexes/invoke" - req, err := http.NewRequest(http.MethodPost, api, tc.requestBody) - if err != nil { - t.Fatalf("unable to create request: %v", err) - } - req.Header.Add("Content-type", "application/json") - - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatalf("unable to send request: %v", err) - } - defer resp.Body.Close() - + resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil) if resp.StatusCode != tc.wantStatusCode { - body, _ := io.ReadAll(resp.Body) - t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(body)) + t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody)) } if tc.wantStatusCode != http.StatusOK { return @@ -1788,7 +2106,7 @@ func RunMySQLListTablesMissingUniqueIndexes(t *testing.T, ctx context.Context, p var bodyWrapper struct { Result json.RawMessage `json:"result"` } - if err := json.NewDecoder(resp.Body).Decode(&bodyWrapper); err != nil { + if err := json.Unmarshal(respBody, &bodyWrapper); err != nil { t.Fatalf("error decoding response wrapper: %v", err) } @@ -1892,21 +2210,9 @@ func RunMySQLListTableFragmentationTest(t *testing.T, databaseName, tableNamePar for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { const api = "http://127.0.0.1:5000/api/tool/list_table_fragmentation/invoke" - req, err := http.NewRequest(http.MethodPost, api, tc.requestBody) - if err != nil { - t.Fatalf("unable to create request: %v", err) - } - req.Header.Add("Content-type", "application/json") - - resp, err := http.DefaultClient.Do(req) - if err != nil { - t.Fatalf("unable to send request: %v", err) - } - defer resp.Body.Close() - + resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil) if resp.StatusCode != tc.wantStatusCode { - body, _ := io.ReadAll(resp.Body) - t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(body)) + t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody)) } if tc.wantStatusCode != http.StatusOK { return @@ -1915,7 +2221,7 @@ func RunMySQLListTableFragmentationTest(t *testing.T, databaseName, tableNamePar var bodyWrapper struct { Result json.RawMessage `json:"result"` } - if err := json.NewDecoder(resp.Body).Decode(&bodyWrapper); err != nil { + if err := json.Unmarshal(respBody, &bodyWrapper); err != nil { t.Fatalf("error decoding response wrapper: %v", err) }