mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-11 08:28:11 -05:00
feat(sources/mariadb): add MariaDB source and MySQL tools integration (#1908)
## Description This PR 1. Adds **MariaDB** as a Source - Implementation is similar to **MySQL** source 2. Utilises pre implemented **MySQL** Tools - `mysql-execute-sql` - `mysql-list-active-queries` - `mysql-list-table-fragmentation` - `mysql-list-tables` - `mysql-list-tables-missing-unique-indexes` - `mysql-sql` **Note:** After discussion with @duwenxin99 in issue #1768, I initially assumed MariaDB required new tools due to different metadata structures and system tables. That is true for older MariaDB versions, but current MySQL tooling already works with MariaDB (verified), so a separate tool set was not needed. 3. Adds a source doc for **MariaDB** in docs 4. Adds MariaDB integration tests using the existing MySQL test flow. Note: The test file is based on the MySQL integration test, but `GetMariaDBWants()` and `RunMariDBListTablesTest()` are implemented because MariaDB returns different metadata and list-tables output, so the assertions must be MariaDB-specific. 5. Updates CI Lastly I considered adding a MariaDB-exclusive Galera cluster monitoring tool, but skipped it because it requires a multi-node Galera setup for integration testing and would significantly increase CI complexity with unclear usage demand. ## PR Checklist - [x] Make sure you reviewed [CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md) - [x] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/genai-toolbox/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea - [x] Ensure the tests and linter pass - [x] Code coverage does not decrease (if any source code was changed) - [x] Appropriate docs were updated (if necessary) - [ ] Make sure to add `!` if this involve a breaking change 🛠️ Fixes #1712 #1768 --------- Co-authored-by: Wenxin Du <117315983+duwenxin99@users.noreply.github.com>
This commit is contained in:
@@ -887,6 +887,26 @@ steps:
|
||||
singlestore \
|
||||
singlestore
|
||||
|
||||
- id: "mariadb"
|
||||
name: golang:1
|
||||
waitFor: ["compile-test-binary"]
|
||||
entrypoint: /bin/bash
|
||||
env:
|
||||
- "GOPATH=/gopath"
|
||||
- "MARIADB_DATABASE=$_MARIADB_DATABASE"
|
||||
- "MARIADB_PORT=$_MARIADB_PORT"
|
||||
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
|
||||
secretEnv: ["MARIADB_USER", "MARIADB_PASS", "MARIADB_HOST", "CLIENT_ID"]
|
||||
volumes:
|
||||
- name: "go"
|
||||
path: "/gopath"
|
||||
args:
|
||||
- -c
|
||||
- |
|
||||
# skip coverage check as it re-uses current MySQL implementation
|
||||
go test ./tests/mariadb
|
||||
|
||||
|
||||
availableSecrets:
|
||||
secretManager:
|
||||
- versionName: projects/$PROJECT_ID/secrets/cloud_sql_pg_user/versions/latest
|
||||
@@ -999,6 +1019,12 @@ availableSecrets:
|
||||
env: SINGLESTORE_PASSWORD
|
||||
- versionName: projects/$PROJECT_ID/secrets/singlestore_host/versions/latest
|
||||
env: SINGLESTORE_HOST
|
||||
- versionName: projects/$PROJECT_ID/secrets/mariadb_user/versions/latest
|
||||
env: MARIADB_USER
|
||||
- versionName: projects/$PROJECT_ID/secrets/mariadb_pass/versions/latest
|
||||
env: MARIADB_PASS
|
||||
- versionName: projects/$PROJECT_ID/secrets/mariadb_host/versions/latest
|
||||
env: MARIADB_HOST
|
||||
- versionName: projects/$PROJECT_ID/secrets/mongodb_uri/versions/latest
|
||||
env: MONGODB_URI
|
||||
|
||||
@@ -1061,3 +1087,6 @@ substitutions:
|
||||
_SINGLESTORE_PORT: "3308"
|
||||
_SINGLESTORE_DATABASE: "singlestore"
|
||||
_SINGLESTORE_USER: "root"
|
||||
_MARIADB_PORT: "3307"
|
||||
_MARIADB_DATABASE: test_database
|
||||
|
||||
|
||||
78
docs/en/resources/sources/mariadb.md
Normal file
78
docs/en/resources/sources/mariadb.md
Normal file
@@ -0,0 +1,78 @@
|
||||
---
|
||||
title: "MariaDB"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
MariaDB is an open-source relational database compatible with MySQL.
|
||||
|
||||
---
|
||||
## About
|
||||
|
||||
MariaDB is a relational database management system derived from MySQL. It
|
||||
implements the MySQL protocol and client libraries and supports modern SQL
|
||||
features with a focus on performance and reliability.
|
||||
|
||||
**Note**: MariaDB is supported using the MySQL source.
|
||||
## Available Tools
|
||||
|
||||
- [`mysql-sql`](../tools/mysql/mysql-sql.md)
|
||||
Execute pre-defined prepared SQL queries in MariaDB.
|
||||
|
||||
- [`mysql-execute-sql`](../tools/mysql/mysql-execute-sql.md)
|
||||
Run parameterized SQL queries in MariaDB.
|
||||
|
||||
- [`mysql-list-active-queries`](../tools/mysql/mysql-list-active-queries.md)
|
||||
List active queries in MariaDB.
|
||||
|
||||
- [`mysql-list-tables`](../tools/mysql/mysql-list-tables.md)
|
||||
List tables in a MariaDB database.
|
||||
|
||||
- [`mysql-list-tables-missing-unique-indexes`](../tools/mysql/mysql-list-tables-missing-unique-indexes.md)
|
||||
List tables in a MariaDB database that do not have primary or unique indices.
|
||||
|
||||
- [`mysql-list-table-fragmentation`](../tools/mysql/mysql-list-table-fragmentation.md)
|
||||
List table fragmentation in MariaDB tables.
|
||||
|
||||
## Requirements
|
||||
|
||||
### Database User
|
||||
|
||||
This source only uses standard authentication. You will need to [create a
|
||||
MariaDB user][mariadb-users] to log in to the database.
|
||||
|
||||
[mariadb-users]: https://mariadb.com/kb/en/create-user/
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
sources:
|
||||
my_mariadb_db:
|
||||
kind: mysql
|
||||
host: 127.0.0.1
|
||||
port: 3306
|
||||
database: my_db
|
||||
user: ${MARIADB_USER}
|
||||
password: ${MARIADB_PASS}
|
||||
# Optional TLS and other driver parameters. For example, enable preferred TLS:
|
||||
# queryParams:
|
||||
# tls: preferred
|
||||
queryTimeout: 30s # Optional: query timeout duration
|
||||
```
|
||||
|
||||
{{< notice tip >}}
|
||||
Use environment variables instead of committing credentials to source files.
|
||||
{{< /notice >}}
|
||||
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
| ------------ | :------: | :----------: | ----------------------------------------------------------------------------------------------- |
|
||||
| kind | string | true | Must be `mysql`. |
|
||||
| host | string | true | IP address to connect to (e.g. "127.0.0.1"). |
|
||||
| port | string | true | Port to connect to (e.g. "3307"). |
|
||||
| database | string | true | Name of the MariaDB database to connect to (e.g. "my_db"). |
|
||||
| user | string | true | Name of the MariaDB user to connect as (e.g. "my-mysql-user"). |
|
||||
| password | string | true | Password of the MariaDB user (e.g. "my-password"). |
|
||||
| queryTimeout | string | false | Maximum time to wait for query execution (e.g. "30s", "2m"). By default, no timeout is applied. |
|
||||
| queryParams | map<string,string> | false | Arbitrary DSN parameters passed to the driver (e.g. `tls: preferred`, `charset: utf8mb4`). Useful for enabling TLS or other connection options. |
|
||||
@@ -881,7 +881,7 @@ func TestCloudSQLMySQL_IPTypeParsingFromYAML(t *testing.T) {
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
ipType: private
|
||||
ipType: private
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
@@ -921,7 +921,7 @@ func TestCloudSQLMySQL_IPTypeParsingFromYAML(t *testing.T) {
|
||||
// Finds and drops all tables in a postgres database.
|
||||
func CleanupPostgresTables(t *testing.T, ctx context.Context, pool *pgxpool.Pool) {
|
||||
query := `
|
||||
SELECT table_name FROM information_schema.tables
|
||||
SELECT table_name FROM information_schema.tables
|
||||
WHERE table_schema = 'public' AND table_type = 'BASE TABLE';`
|
||||
|
||||
rows, err := pool.Query(ctx, query)
|
||||
@@ -954,7 +954,7 @@ func CleanupPostgresTables(t *testing.T, ctx context.Context, pool *pgxpool.Pool
|
||||
// Finds and drops all tables in a mysql database.
|
||||
func CleanupMySQLTables(t *testing.T, ctx context.Context, pool *sql.DB) {
|
||||
query := `
|
||||
SELECT table_name FROM information_schema.tables
|
||||
SELECT table_name FROM information_schema.tables
|
||||
WHERE table_schema = DATABASE() AND table_type = 'BASE TABLE';`
|
||||
|
||||
rows, err := pool.QueryContext(ctx, query)
|
||||
|
||||
351
tests/mariadb/mariadb_integration_test.go
Normal file
351
tests/mariadb/mariadb_integration_test.go
Normal file
@@ -0,0 +1,351 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mariadb
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"github.com/google/uuid"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/tests"
|
||||
)
|
||||
|
||||
var (
|
||||
MariaDBSourceKind = "mysql"
|
||||
MariaDBToolKind = "mysql-sql"
|
||||
MariaDBDatabase = os.Getenv("MARIADB_DATABASE")
|
||||
MariaDBHost = os.Getenv("MARIADB_HOST")
|
||||
MariaDBPort = os.Getenv("MARIADB_PORT")
|
||||
MariaDBUser = os.Getenv("MARIADB_USER")
|
||||
MariaDBPass = os.Getenv("MARIADB_PASS")
|
||||
)
|
||||
|
||||
func getMariaDBVars(t *testing.T) map[string]any {
|
||||
switch "" {
|
||||
case MariaDBDatabase:
|
||||
t.Fatal("'MARIADB_DATABASE' not set")
|
||||
case MariaDBHost:
|
||||
t.Fatal("'MARIADB_HOST' not set")
|
||||
case MariaDBPort:
|
||||
t.Fatal("'MARIADB_PORT' not set")
|
||||
case MariaDBUser:
|
||||
t.Fatal("'MARIADB_USER' not set")
|
||||
case MariaDBPass:
|
||||
t.Fatal("'MARIADB_PASS' not set")
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"kind": MariaDBSourceKind,
|
||||
"host": MariaDBHost,
|
||||
"port": MariaDBPort,
|
||||
"database": MariaDBDatabase,
|
||||
"user": MariaDBUser,
|
||||
"password": MariaDBPass,
|
||||
}
|
||||
}
|
||||
|
||||
// Copied over from mysql.go
|
||||
func initMariaDB(host, port, user, pass, dbname string) (*sql.DB, error) {
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?parseTime=true", user, pass, host, port, dbname)
|
||||
|
||||
// Interact with the driver directly as you normally would
|
||||
pool, err := sql.Open("mysql", dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sql.Open: %w", err)
|
||||
}
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
func TestMySQLToolEndpoints(t *testing.T) {
|
||||
sourceConfig := getMariaDBVars(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
|
||||
var args []string
|
||||
|
||||
pool, err := initMariaDB(MariaDBHost, MariaDBPort, MariaDBUser, MariaDBPass, MariaDBDatabase)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create MySQL connection pool: %s", err)
|
||||
}
|
||||
|
||||
// cleanup test environment
|
||||
tests.CleanupMySQLTables(t, ctx, pool)
|
||||
|
||||
// create table name with UUID
|
||||
tableNameParam := "param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
tableNameAuth := "auth_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
|
||||
// set up data for param tool
|
||||
createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := tests.GetMySQLParamToolInfo(tableNameParam)
|
||||
teardownTable1 := tests.SetupMySQLTable(t, ctx, pool, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams)
|
||||
defer teardownTable1(t)
|
||||
|
||||
// set up data for auth tool
|
||||
createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := tests.GetMySQLAuthToolInfo(tableNameAuth)
|
||||
teardownTable2 := tests.SetupMySQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
|
||||
defer teardownTable2(t)
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, MariaDBToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
|
||||
toolsFile = tests.AddMySqlExecuteSqlConfig(t, toolsFile)
|
||||
tmplSelectCombined, tmplSelectFilterCombined := tests.GetMySQLTmplToolStatement()
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, MariaDBToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||
|
||||
toolsFile = tests.AddMySQLPrebuiltToolConfig(t, toolsFile)
|
||||
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
t.Fatalf("command initialization returned an error: %s", err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||
if err != nil {
|
||||
t.Logf("toolbox command logs: \n%s", out)
|
||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||
}
|
||||
|
||||
// Get configs for tests
|
||||
select1Want, mcpMyFailToolWant, createTableStatement, mcpSelect1Want := GetMariaDBWants()
|
||||
|
||||
// Run tests
|
||||
tests.RunToolGetTest(t)
|
||||
tests.RunToolInvokeTest(t, select1Want, tests.DisableArrayTest())
|
||||
tests.RunMCPToolCallMethod(t, mcpMyFailToolWant, mcpSelect1Want)
|
||||
tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam)
|
||||
|
||||
// Run specific MySQL tool tests
|
||||
RunMariDBListTablesTest(t, MariaDBDatabase, tableNameParam, tableNameAuth)
|
||||
tests.RunMySQLListActiveQueriesTest(t, ctx, pool)
|
||||
tests.RunMySQLListTablesMissingUniqueIndexes(t, ctx, pool, MariaDBDatabase)
|
||||
tests.RunMySQLListTableFragmentationTest(t, MariaDBDatabase, tableNameParam, tableNameAuth)
|
||||
}
|
||||
|
||||
// RunMariDBListTablesTest run tests against the mysql-list-tables tool
|
||||
func RunMariDBListTablesTest(t *testing.T, databaseName, tableNameParam, tableNameAuth string) {
|
||||
type tableInfo struct {
|
||||
ObjectName string `json:"object_name"`
|
||||
SchemaName string `json:"schema_name"`
|
||||
ObjectDetails string `json:"object_details"`
|
||||
}
|
||||
|
||||
type column struct {
|
||||
DataType string `json:"data_type"`
|
||||
ColumnName string `json:"column_name"`
|
||||
ColumnComment string `json:"column_comment"`
|
||||
ColumnDefault any `json:"column_default"`
|
||||
IsNotNullable bool `json:"is_not_nullable"`
|
||||
OrdinalPosition int `json:"ordinal_position"`
|
||||
}
|
||||
|
||||
type objectDetails struct {
|
||||
Owner any `json:"owner"`
|
||||
Columns []column `json:"columns"`
|
||||
Comment string `json:"comment"`
|
||||
Indexes []any `json:"indexes"`
|
||||
Triggers []any `json:"triggers"`
|
||||
Constraints []any `json:"constraints"`
|
||||
ObjectName string `json:"object_name"`
|
||||
ObjectType string `json:"object_type"`
|
||||
SchemaName string `json:"schema_name"`
|
||||
}
|
||||
|
||||
paramTableWant := objectDetails{
|
||||
ObjectName: tableNameParam,
|
||||
SchemaName: databaseName,
|
||||
ObjectType: "TABLE",
|
||||
Columns: []column{
|
||||
{DataType: "int(11)", ColumnName: "id", IsNotNullable: true, OrdinalPosition: 1, ColumnDefault: nil},
|
||||
{DataType: "varchar(255)", ColumnName: "name", OrdinalPosition: 2, ColumnDefault: "NULL"},
|
||||
},
|
||||
Indexes: []any{map[string]any{"index_columns": []any{"id"}, "index_name": "PRIMARY", "is_primary": true, "is_unique": true}},
|
||||
Triggers: []any{},
|
||||
Constraints: []any{map[string]any{"constraint_columns": []any{"id"}, "constraint_name": "PRIMARY", "constraint_type": "PRIMARY KEY", "foreign_key_referenced_columns": any(nil), "foreign_key_referenced_table": any(nil), "constraint_definition": ""}},
|
||||
}
|
||||
|
||||
authTableWant := objectDetails{
|
||||
ObjectName: tableNameAuth,
|
||||
SchemaName: databaseName,
|
||||
ObjectType: "TABLE",
|
||||
Columns: []column{
|
||||
{DataType: "int(11)", ColumnName: "id", IsNotNullable: true, OrdinalPosition: 1, ColumnDefault: nil},
|
||||
{DataType: "varchar(255)", ColumnName: "name", OrdinalPosition: 2, ColumnDefault: "NULL"},
|
||||
{DataType: "varchar(255)", ColumnName: "email", OrdinalPosition: 3, ColumnDefault: "NULL"},
|
||||
},
|
||||
Indexes: []any{map[string]any{"index_columns": []any{"id"}, "index_name": "PRIMARY", "is_primary": true, "is_unique": true}},
|
||||
Triggers: []any{},
|
||||
Constraints: []any{map[string]any{"constraint_columns": []any{"id"}, "constraint_name": "PRIMARY", "constraint_type": "PRIMARY KEY", "foreign_key_referenced_columns": any(nil), "foreign_key_referenced_table": any(nil), "constraint_definition": ""}},
|
||||
}
|
||||
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
requestBody io.Reader
|
||||
wantStatusCode int
|
||||
want any
|
||||
isSimple bool
|
||||
isAllTables bool
|
||||
}{
|
||||
{
|
||||
name: "invoke list_tables for all tables detailed output",
|
||||
requestBody: bytes.NewBufferString(`{"table_names":""}`),
|
||||
wantStatusCode: http.StatusOK,
|
||||
want: []objectDetails{authTableWant, paramTableWant},
|
||||
isAllTables: true,
|
||||
},
|
||||
{
|
||||
name: "invoke list_tables detailed output",
|
||||
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"table_names": "%s"}`, tableNameAuth)),
|
||||
wantStatusCode: http.StatusOK,
|
||||
want: []objectDetails{authTableWant},
|
||||
},
|
||||
{
|
||||
name: "invoke list_tables simple output",
|
||||
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"table_names": "%s", "output_format": "simple"}`, tableNameAuth)),
|
||||
wantStatusCode: http.StatusOK,
|
||||
want: []map[string]any{{"name": tableNameAuth}},
|
||||
isSimple: true,
|
||||
},
|
||||
{
|
||||
name: "invoke list_tables with multiple table names",
|
||||
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"table_names": "%s,%s"}`, tableNameParam, tableNameAuth)),
|
||||
wantStatusCode: http.StatusOK,
|
||||
want: []objectDetails{authTableWant, paramTableWant},
|
||||
},
|
||||
{
|
||||
name: "invoke list_tables with one existing and one non-existent table",
|
||||
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"table_names": "%s,non_existent_table"}`, tableNameAuth)),
|
||||
wantStatusCode: http.StatusOK,
|
||||
want: []objectDetails{authTableWant},
|
||||
},
|
||||
{
|
||||
name: "invoke list_tables with non-existent table",
|
||||
requestBody: bytes.NewBufferString(`{"table_names": "non_existent_table"}`),
|
||||
wantStatusCode: http.StatusOK,
|
||||
want: nil,
|
||||
},
|
||||
}
|
||||
for _, tc := range invokeTcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
const api = "http://127.0.0.1:5000/api/tool/list_tables/invoke"
|
||||
resp, body := tests.RunRequest(t, http.MethodPost, api, tc.requestBody, nil)
|
||||
if resp.StatusCode != tc.wantStatusCode {
|
||||
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(body))
|
||||
}
|
||||
if tc.wantStatusCode != http.StatusOK {
|
||||
return
|
||||
}
|
||||
|
||||
var bodyWrapper struct {
|
||||
Result json.RawMessage `json:"result"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &bodyWrapper); err != nil {
|
||||
t.Fatalf("error decoding response wrapper: %v", err)
|
||||
}
|
||||
|
||||
var resultString string
|
||||
if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil {
|
||||
resultString = string(bodyWrapper.Result)
|
||||
}
|
||||
|
||||
var got any
|
||||
if tc.isSimple {
|
||||
var tables []tableInfo
|
||||
if err := json.Unmarshal([]byte(resultString), &tables); err != nil {
|
||||
t.Fatalf("failed to unmarshal outer JSON array into []tableInfo: %v", err)
|
||||
}
|
||||
var details []map[string]any
|
||||
for _, table := range tables {
|
||||
var d map[string]any
|
||||
if err := json.Unmarshal([]byte(table.ObjectDetails), &d); err != nil {
|
||||
t.Fatalf("failed to unmarshal nested ObjectDetails string: %v", err)
|
||||
}
|
||||
details = append(details, d)
|
||||
}
|
||||
got = details
|
||||
} else {
|
||||
if resultString == "null" {
|
||||
got = nil
|
||||
} else {
|
||||
var tables []tableInfo
|
||||
if err := json.Unmarshal([]byte(resultString), &tables); err != nil {
|
||||
t.Fatalf("failed to unmarshal outer JSON array into []tableInfo: %v", err)
|
||||
}
|
||||
var details []objectDetails
|
||||
for _, table := range tables {
|
||||
var d objectDetails
|
||||
if err := json.Unmarshal([]byte(table.ObjectDetails), &d); err != nil {
|
||||
t.Fatalf("failed to unmarshal nested ObjectDetails string: %v", err)
|
||||
}
|
||||
details = append(details, d)
|
||||
}
|
||||
got = details
|
||||
}
|
||||
}
|
||||
|
||||
opts := []cmp.Option{
|
||||
cmpopts.SortSlices(func(a, b objectDetails) bool { return a.ObjectName < b.ObjectName }),
|
||||
cmpopts.SortSlices(func(a, b column) bool { return a.ColumnName < b.ColumnName }),
|
||||
cmpopts.SortSlices(func(a, b map[string]any) bool { return a["name"].(string) < b["name"].(string) }),
|
||||
}
|
||||
|
||||
// Checking only the current database where the test tables are created to avoid brittle tests.
|
||||
if tc.isAllTables {
|
||||
var filteredGot []objectDetails
|
||||
if got != nil {
|
||||
for _, item := range got.([]objectDetails) {
|
||||
if item.SchemaName == databaseName {
|
||||
filteredGot = append(filteredGot, item)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(filteredGot) == 0 {
|
||||
got = nil
|
||||
} else {
|
||||
got = filteredGot
|
||||
}
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tc.want, got, opts...); diff != "" {
|
||||
t.Errorf("Unexpected result: got %#v, want: %#v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// GetMariaDBWants return the expected wants for mariaDB
|
||||
func GetMariaDBWants() (string, string, string, string) {
|
||||
select1Want := `[{"1":1}]`
|
||||
mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: Error 1064 (42000): You have an error in your SQL syntax; check the manual that corresponds to your MariaDB server version for the right syntax to use near 'SELEC 1' at line 1"}],"isError":true}}`
|
||||
createTableStatement := `"CREATE TABLE t (id INT AUTO_INCREMENT PRIMARY KEY, name TEXT)"`
|
||||
mcpSelect1Want := `{"jsonrpc":"2.0","id":"invoke my-auth-required-tool","result":{"content":[{"type":"text","text":"{\"1\":1}"}]}}`
|
||||
return select1Want, mcpMyFailToolWant, createTableStatement, mcpSelect1Want
|
||||
}
|
||||
Reference in New Issue
Block a user