mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-04-09 03:02:26 -04:00
fix(embeddingModel): add embedding model to MCP handler (#2310)
- Add embedding model to mcp handlers - Add integration tests
This commit is contained in:
@@ -87,7 +87,7 @@ steps:
|
|||||||
- "CLOUD_SQL_POSTGRES_REGION=$_REGION"
|
- "CLOUD_SQL_POSTGRES_REGION=$_REGION"
|
||||||
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
|
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
|
||||||
secretEnv:
|
secretEnv:
|
||||||
["CLOUD_SQL_POSTGRES_USER", "CLOUD_SQL_POSTGRES_PASS", "CLIENT_ID"]
|
["CLOUD_SQL_POSTGRES_USER", "CLOUD_SQL_POSTGRES_PASS", "CLIENT_ID", "API_KEY"]
|
||||||
volumes:
|
volumes:
|
||||||
- name: "go"
|
- name: "go"
|
||||||
path: "/gopath"
|
path: "/gopath"
|
||||||
@@ -134,7 +134,7 @@ steps:
|
|||||||
- "ALLOYDB_POSTGRES_DATABASE=$_DATABASE_NAME"
|
- "ALLOYDB_POSTGRES_DATABASE=$_DATABASE_NAME"
|
||||||
- "ALLOYDB_POSTGRES_REGION=$_REGION"
|
- "ALLOYDB_POSTGRES_REGION=$_REGION"
|
||||||
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
|
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
|
||||||
secretEnv: ["ALLOYDB_POSTGRES_USER", "ALLOYDB_POSTGRES_PASS", "CLIENT_ID"]
|
secretEnv: ["ALLOYDB_POSTGRES_USER", "ALLOYDB_POSTGRES_PASS", "CLIENT_ID", "API_KEY"]
|
||||||
volumes:
|
volumes:
|
||||||
- name: "go"
|
- name: "go"
|
||||||
path: "/gopath"
|
path: "/gopath"
|
||||||
@@ -305,7 +305,7 @@ steps:
|
|||||||
- "POSTGRES_HOST=$_POSTGRES_HOST"
|
- "POSTGRES_HOST=$_POSTGRES_HOST"
|
||||||
- "POSTGRES_PORT=$_POSTGRES_PORT"
|
- "POSTGRES_PORT=$_POSTGRES_PORT"
|
||||||
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
|
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
|
||||||
secretEnv: ["POSTGRES_USER", "POSTGRES_PASS", "CLIENT_ID"]
|
secretEnv: ["POSTGRES_USER", "POSTGRES_PASS", "CLIENT_ID", "API_KEY"]
|
||||||
volumes:
|
volumes:
|
||||||
- name: "go"
|
- name: "go"
|
||||||
path: "/gopath"
|
path: "/gopath"
|
||||||
@@ -964,6 +964,13 @@ steps:
|
|||||||
|
|
||||||
availableSecrets:
|
availableSecrets:
|
||||||
secretManager:
|
secretManager:
|
||||||
|
# Common secrets
|
||||||
|
- versionName: projects/$PROJECT_ID/secrets/client_id/versions/latest
|
||||||
|
env: CLIENT_ID
|
||||||
|
- versionName: projects/$PROJECT_ID/secrets/api_key/versions/latest
|
||||||
|
env: API_KEY
|
||||||
|
|
||||||
|
# Resource-specific secrets
|
||||||
- versionName: projects/$PROJECT_ID/secrets/cloud_sql_pg_user/versions/latest
|
- versionName: projects/$PROJECT_ID/secrets/cloud_sql_pg_user/versions/latest
|
||||||
env: CLOUD_SQL_POSTGRES_USER
|
env: CLOUD_SQL_POSTGRES_USER
|
||||||
- versionName: projects/$PROJECT_ID/secrets/cloud_sql_pg_pass/versions/latest
|
- versionName: projects/$PROJECT_ID/secrets/cloud_sql_pg_pass/versions/latest
|
||||||
@@ -980,8 +987,6 @@ availableSecrets:
|
|||||||
env: POSTGRES_USER
|
env: POSTGRES_USER
|
||||||
- versionName: projects/$PROJECT_ID/secrets/postgres_pass/versions/latest
|
- versionName: projects/$PROJECT_ID/secrets/postgres_pass/versions/latest
|
||||||
env: POSTGRES_PASS
|
env: POSTGRES_PASS
|
||||||
- versionName: projects/$PROJECT_ID/secrets/client_id/versions/latest
|
|
||||||
env: CLIENT_ID
|
|
||||||
- versionName: projects/$PROJECT_ID/secrets/neo4j_user/versions/latest
|
- versionName: projects/$PROJECT_ID/secrets/neo4j_user/versions/latest
|
||||||
env: NEO4J_USER
|
env: NEO4J_USER
|
||||||
- versionName: projects/$PROJECT_ID/secrets/neo4j_pass/versions/latest
|
- versionName: projects/$PROJECT_ID/secrets/neo4j_pass/versions/latest
|
||||||
|
|||||||
@@ -183,6 +183,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
}
|
}
|
||||||
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
|
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
|
||||||
|
|
||||||
|
embeddingModels := resourceMgr.GetEmbeddingModelMap()
|
||||||
|
params, err = tool.EmbedParams(ctx, params, embeddingModels)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("error embedding parameters: %w", err)
|
||||||
|
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
|
||||||
|
}
|
||||||
|
|
||||||
// run tool invocation and generate response.
|
// run tool invocation and generate response.
|
||||||
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -183,6 +183,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
}
|
}
|
||||||
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
|
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
|
||||||
|
|
||||||
|
embeddingModels := resourceMgr.GetEmbeddingModelMap()
|
||||||
|
params, err = tool.EmbedParams(ctx, params, embeddingModels)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("error embedding parameters: %w", err)
|
||||||
|
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
|
||||||
|
}
|
||||||
|
|
||||||
// run tool invocation and generate response.
|
// run tool invocation and generate response.
|
||||||
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -176,6 +176,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
}
|
}
|
||||||
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
|
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
|
||||||
|
|
||||||
|
embeddingModels := resourceMgr.GetEmbeddingModelMap()
|
||||||
|
params, err = tool.EmbedParams(ctx, params, embeddingModels)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("error embedding parameters: %w", err)
|
||||||
|
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
|
||||||
|
}
|
||||||
|
|
||||||
// run tool invocation and generate response.
|
// run tool invocation and generate response.
|
||||||
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -176,6 +176,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
}
|
}
|
||||||
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
|
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
|
||||||
|
|
||||||
|
embeddingModels := resourceMgr.GetEmbeddingModelMap()
|
||||||
|
params, err = tool.EmbedParams(ctx, params, embeddingModels)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("error embedding parameters: %w", err)
|
||||||
|
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
|
||||||
|
}
|
||||||
|
|
||||||
// run tool invocation and generate response.
|
// run tool invocation and generate response.
|
||||||
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -147,12 +147,20 @@ func TestAlloyDBPgToolEndpoints(t *testing.T) {
|
|||||||
teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
|
teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
|
||||||
defer teardownTable2(t)
|
defer teardownTable2(t)
|
||||||
|
|
||||||
|
// Set up table for semanti search
|
||||||
|
vectorTableName, tearDownVectorTable := tests.SetupPostgresVectorTable(t, ctx, pool)
|
||||||
|
defer tearDownVectorTable(t)
|
||||||
|
|
||||||
// Write config into a file and pass it to command
|
// Write config into a file and pass it to command
|
||||||
toolsFile := tests.GetToolsConfig(sourceConfig, AlloyDBPostgresToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
|
toolsFile := tests.GetToolsConfig(sourceConfig, AlloyDBPostgresToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
|
||||||
toolsFile = tests.AddExecuteSqlConfig(t, toolsFile, "postgres-execute-sql")
|
toolsFile = tests.AddExecuteSqlConfig(t, toolsFile, "postgres-execute-sql")
|
||||||
tmplSelectCombined, tmplSelectFilterCombined := tests.GetPostgresSQLTmplToolStatement()
|
tmplSelectCombined, tmplSelectFilterCombined := tests.GetPostgresSQLTmplToolStatement()
|
||||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, AlloyDBPostgresToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
|
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, AlloyDBPostgresToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||||
|
|
||||||
|
// Add semantic search tool config
|
||||||
|
insertStmt, searchStmt := tests.GetPostgresVectorSearchStmts(vectorTableName)
|
||||||
|
toolsFile = tests.AddSemanticSearchConfig(t, toolsFile, AlloyDBPostgresToolKind, insertStmt, searchStmt)
|
||||||
|
|
||||||
toolsFile = tests.AddPostgresPrebuiltConfig(t, toolsFile)
|
toolsFile = tests.AddPostgresPrebuiltConfig(t, toolsFile)
|
||||||
|
|
||||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||||
|
|||||||
@@ -132,12 +132,20 @@ func TestCloudSQLPgSimpleToolEndpoints(t *testing.T) {
|
|||||||
teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
|
teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
|
||||||
defer teardownTable2(t)
|
defer teardownTable2(t)
|
||||||
|
|
||||||
|
// Set up table for semantic search
|
||||||
|
vectorTableName, tearDownVectorTable := tests.SetupPostgresVectorTable(t, ctx, pool)
|
||||||
|
defer tearDownVectorTable(t)
|
||||||
|
|
||||||
// Write config into a file and pass it to command
|
// Write config into a file and pass it to command
|
||||||
toolsFile := tests.GetToolsConfig(sourceConfig, CloudSQLPostgresToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
|
toolsFile := tests.GetToolsConfig(sourceConfig, CloudSQLPostgresToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
|
||||||
toolsFile = tests.AddExecuteSqlConfig(t, toolsFile, "postgres-execute-sql")
|
toolsFile = tests.AddExecuteSqlConfig(t, toolsFile, "postgres-execute-sql")
|
||||||
tmplSelectCombined, tmplSelectFilterCombined := tests.GetPostgresSQLTmplToolStatement()
|
tmplSelectCombined, tmplSelectFilterCombined := tests.GetPostgresSQLTmplToolStatement()
|
||||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, CloudSQLPostgresToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
|
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, CloudSQLPostgresToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||||
|
|
||||||
|
// Add semantic search tool config
|
||||||
|
insertStmt, searchStmt := tests.GetPostgresVectorSearchStmts(vectorTableName)
|
||||||
|
toolsFile = tests.AddSemanticSearchConfig(t, toolsFile, CloudSQLPostgresToolKind, insertStmt, searchStmt)
|
||||||
|
|
||||||
toolsFile = tests.AddPostgresPrebuiltConfig(t, toolsFile)
|
toolsFile = tests.AddPostgresPrebuiltConfig(t, toolsFile)
|
||||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -186,6 +194,7 @@ func TestCloudSQLPgSimpleToolEndpoints(t *testing.T) {
|
|||||||
tests.RunPostgresListDatabaseStatsTest(t, ctx, pool)
|
tests.RunPostgresListDatabaseStatsTest(t, ctx, pool)
|
||||||
tests.RunPostgresListRolesTest(t, ctx, pool)
|
tests.RunPostgresListRolesTest(t, ctx, pool)
|
||||||
tests.RunPostgresListStoredProcedureTest(t, ctx, pool)
|
tests.RunPostgresListStoredProcedureTest(t, ctx, pool)
|
||||||
|
tests.RunSemanticSearchToolInvokeTest(t, "null", "", "The quick brown fox")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test connection with different IP type
|
// Test connection with different IP type
|
||||||
|
|||||||
251
tests/embedding.go
Normal file
251
tests/embedding.go
Normal file
@@ -0,0 +1,251 @@
|
|||||||
|
// Copyright 2026 Google LLC
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
// Package tests contains end to end tests meant to verify the Toolbox Server
|
||||||
|
// works as expected when executed as a binary.
|
||||||
|
|
||||||
|
package tests
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||||
|
"github.com/jackc/pgx/v5/pgxpool"
|
||||||
|
)
|
||||||
|
|
||||||
|
var apiKey = os.Getenv("API_KEY")
|
||||||
|
|
||||||
|
// AddSemanticSearchConfig adds embedding models and semantic search tools to the config
|
||||||
|
// with configurable tool kind and SQL statements.
|
||||||
|
func AddSemanticSearchConfig(t *testing.T, config map[string]any, toolKind, insertStmt, searchStmt string) map[string]any {
|
||||||
|
config["embeddingModels"] = map[string]any{
|
||||||
|
"gemini_model": map[string]any{
|
||||||
|
"kind": "gemini",
|
||||||
|
"model": "gemini-embedding-001",
|
||||||
|
"apiKey": apiKey,
|
||||||
|
"dimension": 768,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tools, ok := config["tools"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("unable to get tools from config")
|
||||||
|
}
|
||||||
|
|
||||||
|
tools["insert_docs"] = map[string]any{
|
||||||
|
"kind": toolKind,
|
||||||
|
"source": "my-instance",
|
||||||
|
"description": "Stores content and its vector embedding into the documents table.",
|
||||||
|
"statement": insertStmt,
|
||||||
|
"parameters": []any{
|
||||||
|
map[string]any{
|
||||||
|
"name": "content",
|
||||||
|
"type": "string",
|
||||||
|
"description": "The text content associated with the vector.",
|
||||||
|
},
|
||||||
|
map[string]any{
|
||||||
|
"name": "text_to_embed",
|
||||||
|
"type": "string",
|
||||||
|
"description": "The text content used to generate the vector.",
|
||||||
|
"embeddedBy": "gemini_model",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tools["search_docs"] = map[string]any{
|
||||||
|
"kind": toolKind,
|
||||||
|
"source": "my-instance",
|
||||||
|
"description": "Finds the most semantically similar document to the query vector.",
|
||||||
|
"statement": searchStmt,
|
||||||
|
"parameters": []any{
|
||||||
|
map[string]any{
|
||||||
|
"name": "query",
|
||||||
|
"type": "string",
|
||||||
|
"description": "The text content to search for.",
|
||||||
|
"embeddedBy": "gemini_model",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
config["tools"] = tools
|
||||||
|
return config
|
||||||
|
}
|
||||||
|
|
||||||
|
// RunSemanticSearchToolInvokeTest runs the insert_docs and search_docs tools
|
||||||
|
// via both HTTP and MCP endpoints and verifies the output.
|
||||||
|
func RunSemanticSearchToolInvokeTest(t *testing.T, insertWant, mcpInsertWant, searchWant string) {
|
||||||
|
// Initialize MCP session once for the MCP test cases
|
||||||
|
sessionId := RunInitialize(t, "2024-11-05")
|
||||||
|
|
||||||
|
tcs := []struct {
|
||||||
|
name string
|
||||||
|
api string
|
||||||
|
isMcp bool
|
||||||
|
requestBody interface{}
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "HTTP invoke insert_docs",
|
||||||
|
api: "http://127.0.0.1:5000/api/tool/insert_docs/invoke",
|
||||||
|
isMcp: false,
|
||||||
|
requestBody: `{"content": "The quick brown fox jumps over the lazy dog", "text_to_embed": "The quick brown fox jumps over the lazy dog"}`,
|
||||||
|
want: insertWant,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "HTTP invoke search_docs",
|
||||||
|
api: "http://127.0.0.1:5000/api/tool/search_docs/invoke",
|
||||||
|
isMcp: false,
|
||||||
|
requestBody: `{"query": "fast fox jumping"}`,
|
||||||
|
want: searchWant,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "MCP invoke insert_docs",
|
||||||
|
api: "http://127.0.0.1:5000/mcp",
|
||||||
|
isMcp: true,
|
||||||
|
requestBody: jsonrpc.JSONRPCRequest{
|
||||||
|
Jsonrpc: "2.0",
|
||||||
|
Id: "mcp-insert-docs",
|
||||||
|
Request: jsonrpc.Request{
|
||||||
|
Method: "tools/call",
|
||||||
|
},
|
||||||
|
Params: map[string]any{
|
||||||
|
"name": "insert_docs",
|
||||||
|
"arguments": map[string]any{
|
||||||
|
"content": "The quick brown fox jumps over the lazy dog",
|
||||||
|
"text_to_embed": "The quick brown fox jumps over the lazy dog",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: mcpInsertWant,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "MCP invoke search_docs",
|
||||||
|
api: "http://127.0.0.1:5000/mcp",
|
||||||
|
isMcp: true,
|
||||||
|
requestBody: jsonrpc.JSONRPCRequest{
|
||||||
|
Jsonrpc: "2.0",
|
||||||
|
Id: "mcp-search-docs",
|
||||||
|
Request: jsonrpc.Request{
|
||||||
|
Method: "tools/call",
|
||||||
|
},
|
||||||
|
Params: map[string]any{
|
||||||
|
"name": "search_docs",
|
||||||
|
"arguments": map[string]any{
|
||||||
|
"query": "fast fox jumping",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: searchWant,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tcs {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
var bodyReader io.Reader
|
||||||
|
headers := map[string]string{}
|
||||||
|
|
||||||
|
// Prepare Request Body and Headers
|
||||||
|
if tc.isMcp {
|
||||||
|
reqBytes, err := json.Marshal(tc.requestBody)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to marshal mcp request: %v", err)
|
||||||
|
}
|
||||||
|
bodyReader = bytes.NewBuffer(reqBytes)
|
||||||
|
if sessionId != "" {
|
||||||
|
headers["Mcp-Session-Id"] = sessionId
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
bodyReader = bytes.NewBufferString(tc.requestBody.(string))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send Request
|
||||||
|
resp, respBody := RunRequest(t, http.MethodPost, tc.api, bodyReader, headers)
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Normalize Response to get the actual tool result string
|
||||||
|
var got string
|
||||||
|
if tc.isMcp {
|
||||||
|
var mcpResp struct {
|
||||||
|
Result struct {
|
||||||
|
Content []struct {
|
||||||
|
Text string `json:"text"`
|
||||||
|
} `json:"content"`
|
||||||
|
} `json:"result"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(respBody, &mcpResp); err != nil {
|
||||||
|
t.Fatalf("error parsing mcp response: %s", err)
|
||||||
|
}
|
||||||
|
if len(mcpResp.Result.Content) > 0 {
|
||||||
|
got = mcpResp.Result.Content[0].Text
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
var httpResp map[string]interface{}
|
||||||
|
if err := json.Unmarshal(respBody, &httpResp); err != nil {
|
||||||
|
t.Fatalf("error parsing http response: %s", err)
|
||||||
|
}
|
||||||
|
if res, ok := httpResp["result"].(string); ok {
|
||||||
|
got = res
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(got, tc.want) {
|
||||||
|
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetupPostgresVectorTable sets up the vector extension and a vector table
|
||||||
|
func SetupPostgresVectorTable(t *testing.T, ctx context.Context, pool *pgxpool.Pool) (string, func(*testing.T)) {
|
||||||
|
t.Helper()
|
||||||
|
if _, err := pool.Exec(ctx, "CREATE EXTENSION IF NOT EXISTS vector"); err != nil {
|
||||||
|
t.Fatalf("failed to create vector extension: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tableName := "vector_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||||
|
|
||||||
|
createTableStmt := fmt.Sprintf(`CREATE TABLE %s (
|
||||||
|
id SERIAL PRIMARY KEY,
|
||||||
|
content TEXT,
|
||||||
|
embedding vector(768)
|
||||||
|
)`, tableName)
|
||||||
|
|
||||||
|
if _, err := pool.Exec(ctx, createTableStmt); err != nil {
|
||||||
|
t.Fatalf("failed to create table %s: %v", tableName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tableName, func(t *testing.T) {
|
||||||
|
if _, err := pool.Exec(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", tableName)); err != nil {
|
||||||
|
t.Errorf("failed to drop table %s: %v", tableName, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetPostgresVectorSearchStmts(vectorTableName string) (string, string) {
|
||||||
|
insertStmt := fmt.Sprintf("INSERT INTO %s (content, embedding) VALUES ($1, $2)", vectorTableName)
|
||||||
|
searchStmt := fmt.Sprintf("SELECT id, content, embedding <-> $1 AS distance FROM %s ORDER BY distance LIMIT 1", vectorTableName)
|
||||||
|
return insertStmt, searchStmt
|
||||||
|
}
|
||||||
@@ -111,6 +111,10 @@ func TestPostgres(t *testing.T) {
|
|||||||
teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
|
teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
|
||||||
defer teardownTable2(t)
|
defer teardownTable2(t)
|
||||||
|
|
||||||
|
// Set up table for semantic search
|
||||||
|
vectorTableName, tearDownVectorTable := tests.SetupPostgresVectorTable(t, ctx, pool)
|
||||||
|
defer tearDownVectorTable(t)
|
||||||
|
|
||||||
// Write config into a file and pass it to command
|
// Write config into a file and pass it to command
|
||||||
toolsFile := tests.GetToolsConfig(sourceConfig, PostgresToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
|
toolsFile := tests.GetToolsConfig(sourceConfig, PostgresToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
|
||||||
toolsFile = tests.AddExecuteSqlConfig(t, toolsFile, "postgres-execute-sql")
|
toolsFile = tests.AddExecuteSqlConfig(t, toolsFile, "postgres-execute-sql")
|
||||||
@@ -118,6 +122,10 @@ func TestPostgres(t *testing.T) {
|
|||||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, PostgresToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
|
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, PostgresToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||||
toolsFile = tests.AddPostgresPrebuiltConfig(t, toolsFile)
|
toolsFile = tests.AddPostgresPrebuiltConfig(t, toolsFile)
|
||||||
|
|
||||||
|
// Add semantic search tool config
|
||||||
|
insertStmt, searchStmt := tests.GetPostgresVectorSearchStmts(vectorTableName)
|
||||||
|
toolsFile = tests.AddSemanticSearchConfig(t, toolsFile, PostgresToolKind, insertStmt, searchStmt)
|
||||||
|
|
||||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("command initialization returned an error: %s", err)
|
t.Fatalf("command initialization returned an error: %s", err)
|
||||||
@@ -165,4 +173,5 @@ func TestPostgres(t *testing.T) {
|
|||||||
tests.RunPostgresListDatabaseStatsTest(t, ctx, pool)
|
tests.RunPostgresListDatabaseStatsTest(t, ctx, pool)
|
||||||
tests.RunPostgresListRolesTest(t, ctx, pool)
|
tests.RunPostgresListRolesTest(t, ctx, pool)
|
||||||
tests.RunPostgresListStoredProcedureTest(t, ctx, pool)
|
tests.RunPostgresListStoredProcedureTest(t, ctx, pool)
|
||||||
|
tests.RunSemanticSearchToolInvokeTest(t, "null", "", "The quick brown fox")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1240,7 +1240,10 @@ func RunPostgresListTablesTest(t *testing.T, tableNameParam, tableNameAuth, user
|
|||||||
var filteredGot []any
|
var filteredGot []any
|
||||||
for _, item := range got {
|
for _, item := range got {
|
||||||
if tableMap, ok := item.(map[string]interface{}); ok {
|
if tableMap, ok := item.(map[string]interface{}); ok {
|
||||||
if schema, ok := tableMap["schema_name"]; ok && schema == "public" {
|
name, _ := tableMap["object_name"].(string)
|
||||||
|
|
||||||
|
// Only keep the table if it matches expected test tables
|
||||||
|
if name == tableNameParam || name == tableNameAuth {
|
||||||
filteredGot = append(filteredGot, item)
|
filteredGot = append(filteredGot, item)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user