mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-10 07:58:12 -05:00
chore: Update list tables test cases and cleanup test database (#1600)
## Description --- This change updates the list tables(`postgres`, `mysql` and `mssql`) tests with test cases for listing all tables. The test schemas are cleaned at the beginning of the test run to ensure deterministic output for the list_tables tool. ## PR Checklist --- > Thank you for opening a Pull Request! Before submitting your PR, there are a > few things you can do to make sure it goes smoothly: - [x] Make sure you reviewed [CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md) - [ ] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/genai-toolbox/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea - [x] Ensure the tests and linter pass - [x] Code coverage does not decrease (if any source code was changed) - [x] Appropriate docs were updated (if necessary) - [x] Make sure to add `!` if this involve a breaking change 🛠️ Fixes #<issue_number_goes_here>
This commit is contained in:
@@ -123,6 +123,9 @@ func TestCloudSQLMSSQLToolEndpoints(t *testing.T) {
|
||||
t.Fatalf("unable to create Cloud SQL connection pool: %s", err)
|
||||
}
|
||||
|
||||
// cleanup test environment
|
||||
tests.CleanupMSSQLTables(t, ctx, db)
|
||||
|
||||
// create table name with UUID
|
||||
tableNameParam := "param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
tableNameAuth := "auth_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
|
||||
@@ -110,6 +110,9 @@ func TestCloudSQLMySQLToolEndpoints(t *testing.T) {
|
||||
t.Fatalf("unable to create Cloud SQL connection pool: %s", err)
|
||||
}
|
||||
|
||||
// cleanup test environment
|
||||
tests.CleanupMySQLTables(t, ctx, pool)
|
||||
|
||||
// create table name with UUID
|
||||
tableNameParam := "param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
tableNameAuth := "auth_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
|
||||
@@ -97,6 +97,9 @@ func TestMSSQLToolEndpoints(t *testing.T) {
|
||||
t.Fatalf("unable to create SQL Server connection pool: %s", err)
|
||||
}
|
||||
|
||||
// cleanup test environment
|
||||
tests.CleanupMSSQLTables(t, ctx, pool)
|
||||
|
||||
// create table name with UUID
|
||||
tableNameParam := "param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
tableNameAuth := "auth_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
|
||||
@@ -87,6 +87,9 @@ func TestMySQLToolEndpoints(t *testing.T) {
|
||||
t.Fatalf("unable to create MySQL connection pool: %s", err)
|
||||
}
|
||||
|
||||
// cleanup test environment
|
||||
tests.CleanupMySQLTables(t, ctx, pool)
|
||||
|
||||
// create table name with UUID
|
||||
tableNameParam := "param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
tableNameAuth := "auth_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
|
||||
@@ -137,6 +137,9 @@ func TestPostgres(t *testing.T) {
|
||||
t.Fatalf("unable to create postgres connection pool: %s", err)
|
||||
}
|
||||
|
||||
// cleanup test environment
|
||||
tests.CleanupPostgresTables(t, ctx, pool);
|
||||
|
||||
// create table name with UUID
|
||||
tableNameParam := "param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
tableNameAuth := "auth_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
@@ -237,7 +240,24 @@ func runPostgresListTablesTest(t *testing.T, tableNameParam, tableNameAuth strin
|
||||
requestBody io.Reader
|
||||
wantStatusCode int
|
||||
want string
|
||||
isAllTables bool
|
||||
}{
|
||||
{
|
||||
name: "invoke list_tables all tables detailed output",
|
||||
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
||||
requestBody: bytes.NewBuffer([]byte(`{"table_names": ""}`)),
|
||||
wantStatusCode: http.StatusOK,
|
||||
want: fmt.Sprintf("[%s,%s]", getDetailedWant(tableNameAuth, authTableColumns), getDetailedWant(tableNameParam, paramTableColumns)),
|
||||
isAllTables: true,
|
||||
},
|
||||
{
|
||||
name: "invoke list_tables all tables simple output",
|
||||
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
||||
requestBody: bytes.NewBuffer([]byte(`{"table_names": "", "output_format": "simple"}`)),
|
||||
wantStatusCode: http.StatusOK,
|
||||
want: fmt.Sprintf("[%s,%s]", getSimpleWant(tableNameAuth), getSimpleWant(tableNameParam)),
|
||||
isAllTables: true,
|
||||
},
|
||||
{
|
||||
name: "invoke list_tables detailed output",
|
||||
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
||||
@@ -334,6 +354,19 @@ func runPostgresListTablesTest(t *testing.T, tableNameParam, tableNameAuth strin
|
||||
t.Fatalf("failed to unmarshal expected want string: %v", err)
|
||||
}
|
||||
|
||||
// Checking only the default public schema where the test tables are created to avoid brittle tests.
|
||||
if tc.isAllTables {
|
||||
var filteredGot []any
|
||||
for _, item := range got {
|
||||
if tableMap, ok := item.(map[string]interface{}); ok {
|
||||
if schema, ok := tableMap["schema_name"]; ok && schema == "public" {
|
||||
filteredGot = append(filteredGot, item)
|
||||
}
|
||||
}
|
||||
}
|
||||
got = filteredGot
|
||||
}
|
||||
|
||||
sort.SliceStable(got, func(i, j int) bool {
|
||||
return fmt.Sprintf("%v", got[i]) < fmt.Sprintf("%v", got[j])
|
||||
})
|
||||
|
||||
@@ -1184,7 +1184,15 @@ func RunMySQLListTablesTest(t *testing.T, databaseName, tableNameParam, tableNam
|
||||
wantStatusCode int
|
||||
want any
|
||||
isSimple bool
|
||||
isAllTables bool
|
||||
}{
|
||||
{
|
||||
name: "invoke list_tables for all tables detailed output",
|
||||
requestBody: bytes.NewBufferString(`{"table_names":""}`),
|
||||
wantStatusCode: http.StatusOK,
|
||||
want: []objectDetails{authTableWant, paramTableWant},
|
||||
isAllTables: true,
|
||||
},
|
||||
{
|
||||
name: "invoke list_tables detailed output",
|
||||
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"table_names": "%s"}`, tableNameAuth)),
|
||||
@@ -1293,6 +1301,23 @@ func RunMySQLListTablesTest(t *testing.T, databaseName, tableNameParam, tableNam
|
||||
cmpopts.SortSlices(func(a, b map[string]any) bool { return a["name"].(string) < b["name"].(string) }),
|
||||
}
|
||||
|
||||
// Checking only the current database where the test tables are created to avoid brittle tests.
|
||||
if tc.isAllTables {
|
||||
var filteredGot []objectDetails
|
||||
if got != nil {
|
||||
for _, item := range got.([]objectDetails) {
|
||||
if item.SchemaName == databaseName {
|
||||
filteredGot = append(filteredGot, item)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(filteredGot) == 0 {
|
||||
got = nil
|
||||
} else {
|
||||
got = filteredGot
|
||||
}
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tc.want, got, opts...); diff != "" {
|
||||
t.Errorf("Unexpected result: got %#v, want: %#v", got, tc.want)
|
||||
}
|
||||
@@ -1858,7 +1883,24 @@ func RunMSSQLListTablesTest(t *testing.T, tableNameParam, tableNameAuth string)
|
||||
requestBody string
|
||||
wantStatusCode int
|
||||
want string
|
||||
isAllTables bool
|
||||
}{
|
||||
{
|
||||
name: "invoke list_tables for all tables detailed output",
|
||||
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
||||
requestBody: `{"table_names": ""}`,
|
||||
wantStatusCode: http.StatusOK,
|
||||
want: fmt.Sprintf("[%s,%s]", getDetailedWant(tableNameAuth, authTableColumns), getDetailedWant(tableNameParam, paramTableColumns)),
|
||||
isAllTables: true,
|
||||
},
|
||||
{
|
||||
name: "invoke list_tables for all tables simple output",
|
||||
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
||||
requestBody: `{"table_names": "", "output_format": "simple"}`,
|
||||
wantStatusCode: http.StatusOK,
|
||||
want: fmt.Sprintf("[%s,%s]", getSimpleWant(tableNameAuth), getSimpleWant(tableNameParam)),
|
||||
isAllTables: true,
|
||||
},
|
||||
{
|
||||
name: "invoke list_tables detailed output",
|
||||
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
||||
@@ -1968,6 +2010,19 @@ func RunMSSQLListTablesTest(t *testing.T, tableNameParam, tableNameAuth string)
|
||||
itemMap["object_details"] = detailsMap
|
||||
}
|
||||
|
||||
// Checking only the default dbo schema where the test tables are created to avoid brittle tests.
|
||||
if tc.isAllTables {
|
||||
var filteredGot []any
|
||||
for _, item := range got {
|
||||
if tableMap, ok := item.(map[string]interface{}); ok {
|
||||
if schema, ok := tableMap["schema_name"]; ok && schema == "dbo" {
|
||||
filteredGot = append(filteredGot, item)
|
||||
}
|
||||
}
|
||||
}
|
||||
got = filteredGot
|
||||
}
|
||||
|
||||
sort.SliceStable(got, func(i, j int) bool {
|
||||
return fmt.Sprintf("%v", got[i]) < fmt.Sprintf("%v", got[j])
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user