tests: add mssql prebuilt tests to cloud sql mssql integration (#1501)

## Description

---
Add mssql's prebuilt tools tests to cloud-sql-mssql integration tests.

Cloud sql mssql's integration test coverage check against the mssql
package since those tools are compatible. Hence, when we add new tools
to mssql, we will have to add those integration tests against cloud sql
mssql as well.
This commit is contained in:
Yuan Teoh
2025-09-25 15:37:09 -07:00
committed by GitHub
parent 4166bf7ab8
commit cb692b5883
4 changed files with 206 additions and 205 deletions

View File

@@ -143,6 +143,7 @@ func TestCloudSQLMSSQLToolEndpoints(t *testing.T) {
toolsFile = tests.AddMSSQLExecuteSqlConfig(t, toolsFile)
tmplSelectCombined, tmplSelectFilterCombined := tests.GetMSSQLTmplToolStatement()
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, CloudSQLMSSQLToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
toolsFile = tests.AddMSSQLPrebuiltToolConfig(t, toolsFile)
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
if err != nil {
@@ -167,6 +168,9 @@ func TestCloudSQLMSSQLToolEndpoints(t *testing.T) {
tests.RunMCPToolCallMethod(t, mcpMyFailToolWant, mcpSelect1Want)
tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want)
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam)
// Run specific MSSQL tool tests
tests.RunMSSQLListTablesTest(t, tableNameParam, tableNameAuth)
}
// Test connection with different IP type

View File

@@ -350,6 +350,21 @@ func AddMSSQLExecuteSqlConfig(t *testing.T, config map[string]any) map[string]an
return config
}
// AddMSSQLPrebuiltToolConfig gets the tools config for mssql prebuilt tools
func AddMSSQLPrebuiltToolConfig(t *testing.T, config map[string]any) map[string]any {
tools, ok := config["tools"].(map[string]any)
if !ok {
t.Fatalf("unable to get tools from config")
}
tools["list_tables"] = map[string]any{
"kind": "mssql-list-tables",
"source": "my-instance",
"description": "Lists tables in the database.",
}
config["tools"] = tools
return config
}
// GetPostgresSQLParamToolInfo returns statements and param for my-tool postgres-sql kind
func GetPostgresSQLParamToolInfo(tableName string) (string, string, string, string, string, string, []any) {
createStatement := fmt.Sprintf("CREATE TABLE %s (id SERIAL PRIMARY KEY, name TEXT);", tableName)

View File

@@ -15,17 +15,12 @@
package mssql
import (
"bytes"
"context"
"database/sql"
"encoding/json"
"fmt"
"net/http"
"net/url"
"os"
"reflect"
"regexp"
"sort"
"strings"
"testing"
"time"
@@ -37,14 +32,13 @@ import (
)
var (
MSSQLSourceKind = "mssql"
MSSQLToolKind = "mssql-sql"
MSSQLListTablesToolKind = "mssql-list-tables"
MSSQLDatabase = os.Getenv("MSSQL_DATABASE")
MSSQLHost = os.Getenv("MSSQL_HOST")
MSSQLPort = os.Getenv("MSSQL_PORT")
MSSQLUser = os.Getenv("MSSQL_USER")
MSSQLPass = os.Getenv("MSSQL_PASS")
MSSQLSourceKind = "mssql"
MSSQLToolKind = "mssql-sql"
MSSQLDatabase = os.Getenv("MSSQL_DATABASE")
MSSQLHost = os.Getenv("MSSQL_HOST")
MSSQLPort = os.Getenv("MSSQL_PORT")
MSSQLUser = os.Getenv("MSSQL_USER")
MSSQLPass = os.Getenv("MSSQL_PASS")
)
func getMsSQLVars(t *testing.T) map[string]any {
@@ -71,20 +65,6 @@ func getMsSQLVars(t *testing.T) map[string]any {
}
}
func addPrebuiltToolConfig(t *testing.T, config map[string]any) map[string]any {
tools, ok := config["tools"].(map[string]any)
if !ok {
t.Fatalf("unable to get tools from config")
}
tools["list_tables"] = map[string]any{
"kind": MSSQLListTablesToolKind,
"source": "my-instance",
"description": "Lists tables in the database.",
}
config["tools"] = tools
return config
}
// Copied over from mssql.go
func initMSSQLConnection(host, port, user, pass, dbname string) (*sql.DB, error) {
// Create dsn
@@ -137,7 +117,7 @@ func TestMSSQLToolEndpoints(t *testing.T) {
toolsFile = tests.AddMSSQLExecuteSqlConfig(t, toolsFile)
tmplSelectCombined, tmplSelectFilterCombined := tests.GetMSSQLTmplToolStatement()
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, MSSQLToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
toolsFile = addPrebuiltToolConfig(t, toolsFile)
toolsFile = tests.AddMSSQLPrebuiltToolConfig(t, toolsFile)
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
if err != nil {
@@ -164,181 +144,5 @@ func TestMSSQLToolEndpoints(t *testing.T) {
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam)
// Run specific MSSQL tool tests
runMSSQLListTablesTest(t, tableNameParam, tableNameAuth)
}
func runMSSQLListTablesTest(t *testing.T, tableNameParam, tableNameAuth string) {
// TableNameParam columns to construct want.
const paramTableColumns = `[
{"column_name": "id", "data_type": "INT", "column_ordinal_position": 1, "is_not_nullable": true},
{"column_name": "name", "data_type": "VARCHAR(255)", "column_ordinal_position": 2, "is_not_nullable": false}
]`
// TableNameAuth columns to construct want
const authTableColumns = `[
{"column_name": "id", "data_type": "INT", "column_ordinal_position": 1, "is_not_nullable": true},
{"column_name": "name", "data_type": "VARCHAR(255)", "column_ordinal_position": 2, "is_not_nullable": false},
{"column_name": "email", "data_type": "VARCHAR(255)", "column_ordinal_position": 3, "is_not_nullable": false}
]`
const (
// Template to construct detailed output want.
detailedObjectTemplate = `{
"schema_name": "dbo",
"object_name": "%[1]s",
"object_details": {
"owner": "dbo",
"triggers": [],
"columns": %[2]s,
"object_name": "%[1]s",
"object_type": "TABLE",
"schema_name": "dbo"
}
}`
// Template to construct simple output want
simpleObjectTemplate = `{"object_name":"%s", "schema_name":"dbo", "object_details":{"name":"%s"}}`
)
// Helper to build json for detailed want
getDetailedWant := func(tableName, columnJSON string) string {
return fmt.Sprintf(detailedObjectTemplate, tableName, columnJSON)
}
// Helper to build template for simple want
getSimpleWant := func(tableName string) string {
return fmt.Sprintf(simpleObjectTemplate, tableName, tableName)
}
invokeTcs := []struct {
name string
api string
requestBody string
wantStatusCode int
want string
}{
{
name: "invoke list_tables detailed output",
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
requestBody: fmt.Sprintf(`{"table_names": "%s"}`, tableNameAuth),
wantStatusCode: http.StatusOK,
want: fmt.Sprintf("[%s]", getDetailedWant(tableNameAuth, authTableColumns)),
},
{
name: "invoke list_tables simple output",
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
requestBody: fmt.Sprintf(`{"table_names": "%s", "output_format": "simple"}`, tableNameAuth),
wantStatusCode: http.StatusOK,
want: fmt.Sprintf("[%s]", getSimpleWant(tableNameAuth)),
},
{
name: "invoke list_tables with invalid output format",
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
requestBody: `{"table_names": "", "output_format": "abcd"}`,
wantStatusCode: http.StatusBadRequest,
},
{
name: "invoke list_tables with malformed table_names parameter",
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
requestBody: `{"table_names": 12345, "output_format": "detailed"}`,
wantStatusCode: http.StatusBadRequest,
},
{
name: "invoke list_tables with multiple table names",
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
requestBody: fmt.Sprintf(`{"table_names": "%s,%s"}`, tableNameParam, tableNameAuth),
wantStatusCode: http.StatusOK,
want: fmt.Sprintf("[%s,%s]", getDetailedWant(tableNameAuth, authTableColumns), getDetailedWant(tableNameParam, paramTableColumns)),
},
{
name: "invoke list_tables with non-existent table",
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
requestBody: `{"table_names": "non_existent_table"}`,
wantStatusCode: http.StatusOK,
want: `null`,
},
{
name: "invoke list_tables with one existing and one non-existent table",
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
requestBody: fmt.Sprintf(`{"table_names": "%s,non_existent_table"}`, tableNameParam),
wantStatusCode: http.StatusOK,
want: fmt.Sprintf("[%s]", getDetailedWant(tableNameParam, paramTableColumns)),
},
}
for _, tc := range invokeTcs {
t.Run(tc.name, func(t *testing.T) {
resp, respBytes := tests.RunRequest(t, http.MethodPost, tc.api, bytes.NewBuffer([]byte(tc.requestBody)), nil)
if resp.StatusCode != tc.wantStatusCode {
t.Fatalf("response status code is not %d, got %d: %s", tc.wantStatusCode, resp.StatusCode, string(respBytes))
}
if tc.wantStatusCode == http.StatusOK {
var bodyWrapper map[string]json.RawMessage
if err := json.Unmarshal(respBytes, &bodyWrapper); err != nil {
t.Fatalf("error parsing response wrapper: %s, body: %s", err, string(respBytes))
}
resultJSON, ok := bodyWrapper["result"]
if !ok {
t.Fatal("unable to find 'result' in response body")
}
var resultString string
if err := json.Unmarshal(resultJSON, &resultString); err != nil {
if string(resultJSON) == "null" {
resultString = "null"
} else {
t.Fatalf("'result' is not a JSON-encoded string: %s", err)
}
}
var got, want []any
if err := json.Unmarshal([]byte(resultString), &got); err != nil {
t.Fatalf("failed to unmarshal actual result string: %v", err)
}
if err := json.Unmarshal([]byte(tc.want), &want); err != nil {
t.Fatalf("failed to unmarshal expected want string: %v", err)
}
for _, item := range got {
itemMap, ok := item.(map[string]any)
if !ok {
continue
}
detailsStr, ok := itemMap["object_details"].(string)
if !ok {
continue
}
var detailsMap map[string]any
if err := json.Unmarshal([]byte(detailsStr), &detailsMap); err != nil {
t.Fatalf("failed to unmarshal nested object_details string: %v", err)
}
// clean unpredictable fields
delete(detailsMap, "constraints")
delete(detailsMap, "indexes")
itemMap["object_details"] = detailsMap
}
sort.SliceStable(got, func(i, j int) bool {
return fmt.Sprintf("%v", got[i]) < fmt.Sprintf("%v", got[j])
})
sort.SliceStable(want, func(i, j int) bool {
return fmt.Sprintf("%v", want[i]) < fmt.Sprintf("%v", want[j])
})
if !reflect.DeepEqual(got, want) {
gotJSON, _ := json.MarshalIndent(got, "", " ")
wantJSON, _ := json.MarshalIndent(want, "", " ")
t.Errorf("Unexpected result:\ngot:\n%s\n\nwant:\n%s", string(gotJSON), string(wantJSON))
}
}
})
}
tests.RunMSSQLListTablesTest(t, tableNameParam, tableNameAuth)
}

View File

@@ -23,6 +23,7 @@ import (
"io"
"net/http"
"reflect"
"sort"
"strings"
"sync"
"testing"
@@ -1807,6 +1808,183 @@ func RunMySQLListTableFragmentationTest(t *testing.T, databaseName, tableNamePar
}
}
// RunMSSQLListTablesTest run tests againsts the mssql-list-tables tools.
func RunMSSQLListTablesTest(t *testing.T, tableNameParam, tableNameAuth string) {
// TableNameParam columns to construct want.
const paramTableColumns = `[
{"column_name": "id", "data_type": "INT", "column_ordinal_position": 1, "is_not_nullable": true},
{"column_name": "name", "data_type": "VARCHAR(255)", "column_ordinal_position": 2, "is_not_nullable": false}
]`
// TableNameAuth columns to construct want
const authTableColumns = `[
{"column_name": "id", "data_type": "INT", "column_ordinal_position": 1, "is_not_nullable": true},
{"column_name": "name", "data_type": "VARCHAR(255)", "column_ordinal_position": 2, "is_not_nullable": false},
{"column_name": "email", "data_type": "VARCHAR(255)", "column_ordinal_position": 3, "is_not_nullable": false}
]`
const (
// Template to construct detailed output want.
detailedObjectTemplate = `{
"schema_name": "dbo",
"object_name": "%[1]s",
"object_details": {
"owner": "dbo",
"triggers": [],
"columns": %[2]s,
"object_name": "%[1]s",
"object_type": "TABLE",
"schema_name": "dbo"
}
}`
// Template to construct simple output want
simpleObjectTemplate = `{"object_name":"%s", "schema_name":"dbo", "object_details":{"name":"%s"}}`
)
// Helper to build json for detailed want
getDetailedWant := func(tableName, columnJSON string) string {
return fmt.Sprintf(detailedObjectTemplate, tableName, columnJSON)
}
// Helper to build template for simple want
getSimpleWant := func(tableName string) string {
return fmt.Sprintf(simpleObjectTemplate, tableName, tableName)
}
invokeTcs := []struct {
name string
api string
requestBody string
wantStatusCode int
want string
}{
{
name: "invoke list_tables detailed output",
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
requestBody: fmt.Sprintf(`{"table_names": "%s"}`, tableNameAuth),
wantStatusCode: http.StatusOK,
want: fmt.Sprintf("[%s]", getDetailedWant(tableNameAuth, authTableColumns)),
},
{
name: "invoke list_tables simple output",
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
requestBody: fmt.Sprintf(`{"table_names": "%s", "output_format": "simple"}`, tableNameAuth),
wantStatusCode: http.StatusOK,
want: fmt.Sprintf("[%s]", getSimpleWant(tableNameAuth)),
},
{
name: "invoke list_tables with invalid output format",
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
requestBody: `{"table_names": "", "output_format": "abcd"}`,
wantStatusCode: http.StatusBadRequest,
},
{
name: "invoke list_tables with malformed table_names parameter",
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
requestBody: `{"table_names": 12345, "output_format": "detailed"}`,
wantStatusCode: http.StatusBadRequest,
},
{
name: "invoke list_tables with multiple table names",
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
requestBody: fmt.Sprintf(`{"table_names": "%s,%s"}`, tableNameParam, tableNameAuth),
wantStatusCode: http.StatusOK,
want: fmt.Sprintf("[%s,%s]", getDetailedWant(tableNameAuth, authTableColumns), getDetailedWant(tableNameParam, paramTableColumns)),
},
{
name: "invoke list_tables with non-existent table",
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
requestBody: `{"table_names": "non_existent_table"}`,
wantStatusCode: http.StatusOK,
want: `null`,
},
{
name: "invoke list_tables with one existing and one non-existent table",
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
requestBody: fmt.Sprintf(`{"table_names": "%s,non_existent_table"}`, tableNameParam),
wantStatusCode: http.StatusOK,
want: fmt.Sprintf("[%s]", getDetailedWant(tableNameParam, paramTableColumns)),
},
}
for _, tc := range invokeTcs {
t.Run(tc.name, func(t *testing.T) {
resp, respBytes := RunRequest(t, http.MethodPost, tc.api, bytes.NewBuffer([]byte(tc.requestBody)), nil)
if resp.StatusCode != tc.wantStatusCode {
t.Fatalf("response status code is not %d, got %d: %s", tc.wantStatusCode, resp.StatusCode, string(respBytes))
}
if tc.wantStatusCode == http.StatusOK {
var bodyWrapper map[string]json.RawMessage
if err := json.Unmarshal(respBytes, &bodyWrapper); err != nil {
t.Fatalf("error parsing response wrapper: %s, body: %s", err, string(respBytes))
}
resultJSON, ok := bodyWrapper["result"]
if !ok {
t.Fatal("unable to find 'result' in response body")
}
var resultString string
if err := json.Unmarshal(resultJSON, &resultString); err != nil {
if string(resultJSON) == "null" {
resultString = "null"
} else {
t.Fatalf("'result' is not a JSON-encoded string: %s", err)
}
}
var got, want []any
if err := json.Unmarshal([]byte(resultString), &got); err != nil {
t.Fatalf("failed to unmarshal actual result string: %v", err)
}
if err := json.Unmarshal([]byte(tc.want), &want); err != nil {
t.Fatalf("failed to unmarshal expected want string: %v", err)
}
for _, item := range got {
itemMap, ok := item.(map[string]any)
if !ok {
continue
}
detailsStr, ok := itemMap["object_details"].(string)
if !ok {
continue
}
var detailsMap map[string]any
if err := json.Unmarshal([]byte(detailsStr), &detailsMap); err != nil {
t.Fatalf("failed to unmarshal nested object_details string: %v", err)
}
// clean unpredictable fields
delete(detailsMap, "constraints")
delete(detailsMap, "indexes")
itemMap["object_details"] = detailsMap
}
sort.SliceStable(got, func(i, j int) bool {
return fmt.Sprintf("%v", got[i]) < fmt.Sprintf("%v", got[j])
})
sort.SliceStable(want, func(i, j int) bool {
return fmt.Sprintf("%v", want[i]) < fmt.Sprintf("%v", want[j])
})
if !reflect.DeepEqual(got, want) {
gotJSON, _ := json.MarshalIndent(got, "", " ")
wantJSON, _ := json.MarshalIndent(want, "", " ")
t.Errorf("Unexpected result:\ngot:\n%s\n\nwant:\n%s", string(gotJSON), string(wantJSON))
}
}
})
}
}
// RunRequest is a helper function to send HTTP requests and return the response
func RunRequest(t *testing.T, method, url string, body io.Reader, headers map[string]string) (*http.Response, []byte) {
// Send request