Compare commits

..

4 Commits

Author SHA1 Message Date
Wenxin Du
3398b941d1 Merge branch 'main' into mcp-sem 2026-01-14 18:38:37 -05:00
duwenxin
2e3baedd1f update other mcp versions 2026-01-14 18:37:56 -05:00
duwenxin
1afb0d557b add testing config 2026-01-14 18:35:42 -05:00
duwenxin
4524c0bd37 add embedding model in mcp 2026-01-13 08:53:06 -05:00
6 changed files with 283 additions and 3 deletions

View File

@@ -305,7 +305,7 @@ steps:
- "POSTGRES_HOST=$_POSTGRES_HOST"
- "POSTGRES_PORT=$_POSTGRES_PORT"
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
secretEnv: ["POSTGRES_USER", "POSTGRES_PASS", "CLIENT_ID"]
secretEnv: ["POSTGRES_USER", "POSTGRES_PASS", "CLIENT_ID", "API_KEY"]
volumes:
- name: "go"
path: "/gopath"
@@ -964,6 +964,13 @@ steps:
availableSecrets:
secretManager:
# Common secrets
- versionName: projects/$PROJECT_ID/secrets/client_id/versions/latest
env: CLIENT_ID
- versionName: projects/$PROJECT_ID/secrets/api_key/versions/latest
env: API_KEY
# Resource-specific secrets
- versionName: projects/$PROJECT_ID/secrets/cloud_sql_pg_user/versions/latest
env: CLOUD_SQL_POSTGRES_USER
- versionName: projects/$PROJECT_ID/secrets/cloud_sql_pg_pass/versions/latest
@@ -980,8 +987,6 @@ availableSecrets:
env: POSTGRES_USER
- versionName: projects/$PROJECT_ID/secrets/postgres_pass/versions/latest
env: POSTGRES_PASS
- versionName: projects/$PROJECT_ID/secrets/client_id/versions/latest
env: CLIENT_ID
- versionName: projects/$PROJECT_ID/secrets/neo4j_user/versions/latest
env: NEO4J_USER
- versionName: projects/$PROJECT_ID/secrets/neo4j_pass/versions/latest

View File

@@ -183,6 +183,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
}
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
embeddingModels := resourceMgr.GetEmbeddingModelMap()
params, err = tool.EmbedParams(ctx, params, embeddingModels)
if err != nil {
err = fmt.Errorf("error embedding parameters: %w", err)
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
}
// run tool invocation and generate response.
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
if err != nil {

View File

@@ -183,6 +183,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
}
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
embeddingModels := resourceMgr.GetEmbeddingModelMap()
params, err = tool.EmbedParams(ctx, params, embeddingModels)
if err != nil {
err = fmt.Errorf("error embedding parameters: %w", err)
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
}
// run tool invocation and generate response.
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
if err != nil {

View File

@@ -176,6 +176,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
}
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
embeddingModels := resourceMgr.GetEmbeddingModelMap()
params, err = tool.EmbedParams(ctx, params, embeddingModels)
if err != nil {
err = fmt.Errorf("error embedding parameters: %w", err)
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
}
// run tool invocation and generate response.
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
if err != nil {

245
tests/embedding.go Normal file
View File

@@ -0,0 +1,245 @@
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package tests contains end to end tests meant to verify the Toolbox Server
// works as expected when executed as a binary.
package tests
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"strings"
"testing"
"github.com/google/uuid"
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
"github.com/jackc/pgx/v5/pgxpool"
)
var apiKey = os.Getenv("API_KEY")
// AddSemanticSearchConfig adds embedding models and semantic search tools to the config
// with configurable tool kind and SQL statements.
func AddSemanticSearchConfig(t *testing.T, config map[string]any, toolKind, insertStmt, searchStmt string) map[string]any {
config["embeddingModels"] = map[string]any{
"gemini_model": map[string]any{
"kind": "gemini",
"model": "gemini-embedding-001",
"apiKey": apiKey,
"dimension": 768,
},
}
tools, ok := config["tools"].(map[string]any)
if !ok {
t.Fatalf("unable to get tools from config")
}
tools["insert_docs"] = map[string]any{
"kind": toolKind,
"source": "my-instance",
"description": "Stores content and its vector embedding into the documents table.",
"statement": insertStmt,
"parameters": []any{
map[string]any{
"name": "content",
"type": "string",
"description": "The text content associated with the vector.",
},
map[string]any{
"name": "text_to_embed",
"type": "string",
"description": "The text content used to generate the vector.",
"embeddedBy": "gemini_model",
},
},
}
tools["search_docs"] = map[string]any{
"kind": toolKind,
"source": "my-instance",
"description": "Finds the most semantically similar document to the query vector.",
"statement": searchStmt,
"parameters": []any{
map[string]any{
"name": "query",
"type": "string",
"description": "The text content to search for.",
"embeddedBy": "gemini_model",
},
},
}
config["tools"] = tools
return config
}
// RunSemanticSearchToolInvokeTest runs the insert_docs and search_docs tools
// via both HTTP and MCP endpoints and verifies the output.
func RunSemanticSearchToolInvokeTest(t *testing.T, insertWant, searchWant string) {
// Initialize MCP session once for the MCP test cases
sessionId := RunInitialize(t, "2024-11-05")
tcs := []struct {
name string
api string
isMcp bool
requestBody interface{}
want string
}{
{
name: "HTTP invoke insert_docs",
api: "http://127.0.0.1:5000/api/tool/insert_docs/invoke",
isMcp: false,
requestBody: `{"content": "The quick brown fox jumps over the lazy dog", "text_to_embed": "The quick brown fox jumps over the lazy dog"}`,
want: insertWant,
},
{
name: "HTTP invoke search_docs",
api: "http://127.0.0.1:5000/api/tool/search_docs/invoke",
isMcp: false,
requestBody: `{"query": "fast fox jumping"}`,
want: searchWant,
},
{
name: "MCP invoke insert_docs",
api: "http://127.0.0.1:5000/mcp",
isMcp: true,
requestBody: jsonrpc.JSONRPCRequest{
Jsonrpc: "2.0",
Id: "mcp-insert-docs",
Request: jsonrpc.Request{
Method: "tools/call",
},
Params: map[string]any{
"name": "insert_docs",
"arguments": map[string]any{
"content": "The quick brown fox jumps over the lazy dog",
"text_to_embed": "The quick brown fox jumps over the lazy dog",
},
},
},
want: insertWant,
},
{
name: "MCP invoke search_docs",
api: "http://127.0.0.1:5000/mcp",
isMcp: true,
requestBody: jsonrpc.JSONRPCRequest{
Jsonrpc: "2.0",
Id: "mcp-search-docs",
Request: jsonrpc.Request{
Method: "tools/call",
},
Params: map[string]any{
"name": "search_docs",
"arguments": map[string]any{
"query": "fast fox jumping",
},
},
},
want: searchWant,
},
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
var bodyReader io.Reader
headers := map[string]string{}
// Prepare Request Body and Headers
if tc.isMcp {
reqBytes, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("failed to marshal mcp request: %v", err)
}
bodyReader = bytes.NewBuffer(reqBytes)
if sessionId != "" {
headers["Mcp-Session-Id"] = sessionId
}
} else {
bodyReader = bytes.NewBufferString(tc.requestBody.(string))
}
// Send Request
resp, respBody := RunRequest(t, http.MethodPost, tc.api, bodyReader, headers)
if resp.StatusCode != http.StatusOK {
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody))
}
// Normalize Response to get the actual tool result string
var got string
if tc.isMcp {
var mcpResp struct {
Result struct {
Content []struct {
Text string `json:"text"`
} `json:"content"`
} `json:"result"`
}
if err := json.Unmarshal(respBody, &mcpResp); err != nil {
t.Fatalf("error parsing mcp response: %s", err)
}
if len(mcpResp.Result.Content) > 0 {
got = mcpResp.Result.Content[0].Text
}
} else {
var httpResp map[string]interface{}
if err := json.Unmarshal(respBody, &httpResp); err != nil {
t.Fatalf("error parsing http response: %s", err)
}
if res, ok := httpResp["result"].(string); ok {
got = res
}
}
if !strings.Contains(got, tc.want) {
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
}
})
}
}
// SetupPostgresVectorTable sets up the vector extension and a vector table
func SetupPostgresVectorTable(t *testing.T, ctx context.Context, pool *pgxpool.Pool) func(*testing.T) {
t.Helper()
if _, err := pool.Exec(ctx, "CREATE EXTENSION IF NOT EXISTS vector"); err != nil {
t.Fatalf("failed to create vector extension: %v", err)
}
tableName := "vector_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
createTableStmt := fmt.Sprintf(`CREATE TABLE %s (
id SERIAL PRIMARY KEY,
content TEXT,
embedding vector(768)
)`, tableName)
if _, err := pool.Exec(ctx, createTableStmt); err != nil {
t.Fatalf("failed to create table %s: %v", tableName, err)
}
return func(t *testing.T) {
if _, err := pool.Exec(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", tableName)); err != nil {
t.Errorf("failed to drop table %s: %v", tableName, err)
}
}
}

View File

@@ -111,12 +111,20 @@ func TestPostgres(t *testing.T) {
teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
defer teardownTable2(t)
// Set up table for semanti search
tearDownVectorTable := tests.SetupPostgresVectorTable(t, ctx, pool)
defer tearDownVectorTable(t)
// Write config into a file and pass it to command
toolsFile := tests.GetToolsConfig(sourceConfig, PostgresToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
toolsFile = tests.AddExecuteSqlConfig(t, toolsFile, "postgres-execute-sql")
tmplSelectCombined, tmplSelectFilterCombined := tests.GetPostgresSQLTmplToolStatement()
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, PostgresToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
toolsFile = tests.AddPostgresPrebuiltConfig(t, toolsFile)
// Add semantic search tool config
insertStmt := "INSERT INTO documents (content, embedding) VALUES ($1, $2)"
searchStmt := "SELECT id, content, embedding <-> $1 AS distance FROM documents ORDER BY distance LIMIT 1"
toolsFile = tests.AddSemanticSearchConfig(t, toolsFile, PostgresToolKind, insertStmt, searchStmt)
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
if err != nil {
@@ -165,4 +173,5 @@ func TestPostgres(t *testing.T) {
tests.RunPostgresListDatabaseStatsTest(t, ctx, pool)
tests.RunPostgresListRolesTest(t, ctx, pool)
tests.RunPostgresListStoredProcedureTest(t, ctx, pool)
tests.RunSemanticSearchToolInvokeTest(t, "null", "The quick brown fox")
}