diff --git a/.ci/integration.cloudbuild.yaml b/.ci/integration.cloudbuild.yaml index fab10259bbf..604be42499f 100644 --- a/.ci/integration.cloudbuild.yaml +++ b/.ci/integration.cloudbuild.yaml @@ -87,7 +87,7 @@ steps: - "CLOUD_SQL_POSTGRES_REGION=$_REGION" - "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL" secretEnv: - ["CLOUD_SQL_POSTGRES_USER", "CLOUD_SQL_POSTGRES_PASS", "CLIENT_ID"] + ["CLOUD_SQL_POSTGRES_USER", "CLOUD_SQL_POSTGRES_PASS", "CLIENT_ID", "API_KEY"] volumes: - name: "go" path: "/gopath" @@ -134,7 +134,7 @@ steps: - "ALLOYDB_POSTGRES_DATABASE=$_DATABASE_NAME" - "ALLOYDB_POSTGRES_REGION=$_REGION" - "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL" - secretEnv: ["ALLOYDB_POSTGRES_USER", "ALLOYDB_POSTGRES_PASS", "CLIENT_ID"] + secretEnv: ["ALLOYDB_POSTGRES_USER", "ALLOYDB_POSTGRES_PASS", "CLIENT_ID", "API_KEY"] volumes: - name: "go" path: "/gopath" @@ -305,7 +305,7 @@ steps: - "POSTGRES_HOST=$_POSTGRES_HOST" - "POSTGRES_PORT=$_POSTGRES_PORT" - "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL" - secretEnv: ["POSTGRES_USER", "POSTGRES_PASS", "CLIENT_ID"] + secretEnv: ["POSTGRES_USER", "POSTGRES_PASS", "CLIENT_ID", "API_KEY"] volumes: - name: "go" path: "/gopath" @@ -964,6 +964,13 @@ steps: availableSecrets: secretManager: + # Common secrets + - versionName: projects/$PROJECT_ID/secrets/client_id/versions/latest + env: CLIENT_ID + - versionName: projects/$PROJECT_ID/secrets/api_key/versions/latest + env: API_KEY + + # Resource-specific secrets - versionName: projects/$PROJECT_ID/secrets/cloud_sql_pg_user/versions/latest env: CLOUD_SQL_POSTGRES_USER - versionName: projects/$PROJECT_ID/secrets/cloud_sql_pg_pass/versions/latest @@ -980,8 +987,6 @@ availableSecrets: env: POSTGRES_USER - versionName: projects/$PROJECT_ID/secrets/postgres_pass/versions/latest env: POSTGRES_PASS - - versionName: projects/$PROJECT_ID/secrets/client_id/versions/latest - env: CLIENT_ID - versionName: projects/$PROJECT_ID/secrets/neo4j_user/versions/latest env: NEO4J_USER - versionName: projects/$PROJECT_ID/secrets/neo4j_pass/versions/latest diff --git a/internal/server/mcp/v20241105/method.go b/internal/server/mcp/v20241105/method.go index 0cbec0d1d26..d34d0074a4d 100644 --- a/internal/server/mcp/v20241105/method.go +++ b/internal/server/mcp/v20241105/method.go @@ -183,6 +183,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re } logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params)) + embeddingModels := resourceMgr.GetEmbeddingModelMap() + params, err = tool.EmbedParams(ctx, params, embeddingModels) + if err != nil { + err = fmt.Errorf("error embedding parameters: %w", err) + return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err + } + // run tool invocation and generate response. results, err := tool.Invoke(ctx, resourceMgr, params, accessToken) if err != nil { diff --git a/internal/server/mcp/v20250326/method.go b/internal/server/mcp/v20250326/method.go index a51bb161eb6..86aa5d9e0b9 100644 --- a/internal/server/mcp/v20250326/method.go +++ b/internal/server/mcp/v20250326/method.go @@ -183,6 +183,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re } logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params)) + embeddingModels := resourceMgr.GetEmbeddingModelMap() + params, err = tool.EmbedParams(ctx, params, embeddingModels) + if err != nil { + err = fmt.Errorf("error embedding parameters: %w", err) + return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err + } + // run tool invocation and generate response. results, err := tool.Invoke(ctx, resourceMgr, params, accessToken) if err != nil { diff --git a/internal/server/mcp/v20250618/method.go b/internal/server/mcp/v20250618/method.go index ccfa5f102f2..f8746d9d9d6 100644 --- a/internal/server/mcp/v20250618/method.go +++ b/internal/server/mcp/v20250618/method.go @@ -176,6 +176,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re } logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params)) + embeddingModels := resourceMgr.GetEmbeddingModelMap() + params, err = tool.EmbedParams(ctx, params, embeddingModels) + if err != nil { + err = fmt.Errorf("error embedding parameters: %w", err) + return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err + } + // run tool invocation and generate response. results, err := tool.Invoke(ctx, resourceMgr, params, accessToken) if err != nil { diff --git a/internal/server/mcp/v20251125/method.go b/internal/server/mcp/v20251125/method.go index 8d2ae77587f..f67bfb5468e 100644 --- a/internal/server/mcp/v20251125/method.go +++ b/internal/server/mcp/v20251125/method.go @@ -176,6 +176,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re } logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params)) + embeddingModels := resourceMgr.GetEmbeddingModelMap() + params, err = tool.EmbedParams(ctx, params, embeddingModels) + if err != nil { + err = fmt.Errorf("error embedding parameters: %w", err) + return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err + } + // run tool invocation and generate response. results, err := tool.Invoke(ctx, resourceMgr, params, accessToken) if err != nil { diff --git a/tests/alloydbpg/alloydb_pg_integration_test.go b/tests/alloydbpg/alloydb_pg_integration_test.go index 4e43f64dc93..32463f1ea09 100644 --- a/tests/alloydbpg/alloydb_pg_integration_test.go +++ b/tests/alloydbpg/alloydb_pg_integration_test.go @@ -147,12 +147,20 @@ func TestAlloyDBPgToolEndpoints(t *testing.T) { teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams) defer teardownTable2(t) + // Set up table for semanti search + vectorTableName, tearDownVectorTable := tests.SetupPostgresVectorTable(t, ctx, pool) + defer tearDownVectorTable(t) + // Write config into a file and pass it to command toolsFile := tests.GetToolsConfig(sourceConfig, AlloyDBPostgresToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt) toolsFile = tests.AddExecuteSqlConfig(t, toolsFile, "postgres-execute-sql") tmplSelectCombined, tmplSelectFilterCombined := tests.GetPostgresSQLTmplToolStatement() toolsFile = tests.AddTemplateParamConfig(t, toolsFile, AlloyDBPostgresToolKind, tmplSelectCombined, tmplSelectFilterCombined, "") + // Add semantic search tool config + insertStmt, searchStmt := tests.GetPostgresVectorSearchStmts(vectorTableName) + toolsFile = tests.AddSemanticSearchConfig(t, toolsFile, AlloyDBPostgresToolKind, insertStmt, searchStmt) + toolsFile = tests.AddPostgresPrebuiltConfig(t, toolsFile) cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...) diff --git a/tests/cloudsqlpg/cloud_sql_pg_integration_test.go b/tests/cloudsqlpg/cloud_sql_pg_integration_test.go index 66ec7e9865c..dc8ecb27bf9 100644 --- a/tests/cloudsqlpg/cloud_sql_pg_integration_test.go +++ b/tests/cloudsqlpg/cloud_sql_pg_integration_test.go @@ -132,12 +132,20 @@ func TestCloudSQLPgSimpleToolEndpoints(t *testing.T) { teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams) defer teardownTable2(t) + // Set up table for semantic search + vectorTableName, tearDownVectorTable := tests.SetupPostgresVectorTable(t, ctx, pool) + defer tearDownVectorTable(t) + // Write config into a file and pass it to command toolsFile := tests.GetToolsConfig(sourceConfig, CloudSQLPostgresToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt) toolsFile = tests.AddExecuteSqlConfig(t, toolsFile, "postgres-execute-sql") tmplSelectCombined, tmplSelectFilterCombined := tests.GetPostgresSQLTmplToolStatement() toolsFile = tests.AddTemplateParamConfig(t, toolsFile, CloudSQLPostgresToolKind, tmplSelectCombined, tmplSelectFilterCombined, "") + // Add semantic search tool config + insertStmt, searchStmt := tests.GetPostgresVectorSearchStmts(vectorTableName) + toolsFile = tests.AddSemanticSearchConfig(t, toolsFile, CloudSQLPostgresToolKind, insertStmt, searchStmt) + toolsFile = tests.AddPostgresPrebuiltConfig(t, toolsFile) cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...) if err != nil { @@ -186,6 +194,7 @@ func TestCloudSQLPgSimpleToolEndpoints(t *testing.T) { tests.RunPostgresListDatabaseStatsTest(t, ctx, pool) tests.RunPostgresListRolesTest(t, ctx, pool) tests.RunPostgresListStoredProcedureTest(t, ctx, pool) + tests.RunSemanticSearchToolInvokeTest(t, "null", "", "The quick brown fox") } // Test connection with different IP type diff --git a/tests/embedding.go b/tests/embedding.go new file mode 100644 index 00000000000..a370ae84d27 --- /dev/null +++ b/tests/embedding.go @@ -0,0 +1,251 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package tests contains end to end tests meant to verify the Toolbox Server +// works as expected when executed as a binary. + +package tests + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "strings" + "testing" + + "github.com/google/uuid" + "github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc" + "github.com/jackc/pgx/v5/pgxpool" +) + +var apiKey = os.Getenv("API_KEY") + +// AddSemanticSearchConfig adds embedding models and semantic search tools to the config +// with configurable tool kind and SQL statements. +func AddSemanticSearchConfig(t *testing.T, config map[string]any, toolKind, insertStmt, searchStmt string) map[string]any { + config["embeddingModels"] = map[string]any{ + "gemini_model": map[string]any{ + "kind": "gemini", + "model": "gemini-embedding-001", + "apiKey": apiKey, + "dimension": 768, + }, + } + + tools, ok := config["tools"].(map[string]any) + if !ok { + t.Fatalf("unable to get tools from config") + } + + tools["insert_docs"] = map[string]any{ + "kind": toolKind, + "source": "my-instance", + "description": "Stores content and its vector embedding into the documents table.", + "statement": insertStmt, + "parameters": []any{ + map[string]any{ + "name": "content", + "type": "string", + "description": "The text content associated with the vector.", + }, + map[string]any{ + "name": "text_to_embed", + "type": "string", + "description": "The text content used to generate the vector.", + "embeddedBy": "gemini_model", + }, + }, + } + + tools["search_docs"] = map[string]any{ + "kind": toolKind, + "source": "my-instance", + "description": "Finds the most semantically similar document to the query vector.", + "statement": searchStmt, + "parameters": []any{ + map[string]any{ + "name": "query", + "type": "string", + "description": "The text content to search for.", + "embeddedBy": "gemini_model", + }, + }, + } + + config["tools"] = tools + return config +} + +// RunSemanticSearchToolInvokeTest runs the insert_docs and search_docs tools +// via both HTTP and MCP endpoints and verifies the output. +func RunSemanticSearchToolInvokeTest(t *testing.T, insertWant, mcpInsertWant, searchWant string) { + // Initialize MCP session once for the MCP test cases + sessionId := RunInitialize(t, "2024-11-05") + + tcs := []struct { + name string + api string + isMcp bool + requestBody interface{} + want string + }{ + { + name: "HTTP invoke insert_docs", + api: "http://127.0.0.1:5000/api/tool/insert_docs/invoke", + isMcp: false, + requestBody: `{"content": "The quick brown fox jumps over the lazy dog", "text_to_embed": "The quick brown fox jumps over the lazy dog"}`, + want: insertWant, + }, + { + name: "HTTP invoke search_docs", + api: "http://127.0.0.1:5000/api/tool/search_docs/invoke", + isMcp: false, + requestBody: `{"query": "fast fox jumping"}`, + want: searchWant, + }, + { + name: "MCP invoke insert_docs", + api: "http://127.0.0.1:5000/mcp", + isMcp: true, + requestBody: jsonrpc.JSONRPCRequest{ + Jsonrpc: "2.0", + Id: "mcp-insert-docs", + Request: jsonrpc.Request{ + Method: "tools/call", + }, + Params: map[string]any{ + "name": "insert_docs", + "arguments": map[string]any{ + "content": "The quick brown fox jumps over the lazy dog", + "text_to_embed": "The quick brown fox jumps over the lazy dog", + }, + }, + }, + want: mcpInsertWant, + }, + { + name: "MCP invoke search_docs", + api: "http://127.0.0.1:5000/mcp", + isMcp: true, + requestBody: jsonrpc.JSONRPCRequest{ + Jsonrpc: "2.0", + Id: "mcp-search-docs", + Request: jsonrpc.Request{ + Method: "tools/call", + }, + Params: map[string]any{ + "name": "search_docs", + "arguments": map[string]any{ + "query": "fast fox jumping", + }, + }, + }, + want: searchWant, + }, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + var bodyReader io.Reader + headers := map[string]string{} + + // Prepare Request Body and Headers + if tc.isMcp { + reqBytes, err := json.Marshal(tc.requestBody) + if err != nil { + t.Fatalf("failed to marshal mcp request: %v", err) + } + bodyReader = bytes.NewBuffer(reqBytes) + if sessionId != "" { + headers["Mcp-Session-Id"] = sessionId + } + } else { + bodyReader = bytes.NewBufferString(tc.requestBody.(string)) + } + + // Send Request + resp, respBody := RunRequest(t, http.MethodPost, tc.api, bodyReader, headers) + + if resp.StatusCode != http.StatusOK { + t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody)) + } + + // Normalize Response to get the actual tool result string + var got string + if tc.isMcp { + var mcpResp struct { + Result struct { + Content []struct { + Text string `json:"text"` + } `json:"content"` + } `json:"result"` + } + if err := json.Unmarshal(respBody, &mcpResp); err != nil { + t.Fatalf("error parsing mcp response: %s", err) + } + if len(mcpResp.Result.Content) > 0 { + got = mcpResp.Result.Content[0].Text + } + } else { + var httpResp map[string]interface{} + if err := json.Unmarshal(respBody, &httpResp); err != nil { + t.Fatalf("error parsing http response: %s", err) + } + if res, ok := httpResp["result"].(string); ok { + got = res + } + } + + if !strings.Contains(got, tc.want) { + t.Fatalf("unexpected value: got %q, want %q", got, tc.want) + } + }) + } +} + +// SetupPostgresVectorTable sets up the vector extension and a vector table +func SetupPostgresVectorTable(t *testing.T, ctx context.Context, pool *pgxpool.Pool) (string, func(*testing.T)) { + t.Helper() + if _, err := pool.Exec(ctx, "CREATE EXTENSION IF NOT EXISTS vector"); err != nil { + t.Fatalf("failed to create vector extension: %v", err) + } + + tableName := "vector_table_" + strings.ReplaceAll(uuid.New().String(), "-", "") + + createTableStmt := fmt.Sprintf(`CREATE TABLE %s ( + id SERIAL PRIMARY KEY, + content TEXT, + embedding vector(768) + )`, tableName) + + if _, err := pool.Exec(ctx, createTableStmt); err != nil { + t.Fatalf("failed to create table %s: %v", tableName, err) + } + + return tableName, func(t *testing.T) { + if _, err := pool.Exec(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", tableName)); err != nil { + t.Errorf("failed to drop table %s: %v", tableName, err) + } + } +} + +func GetPostgresVectorSearchStmts(vectorTableName string) (string, string) { + insertStmt := fmt.Sprintf("INSERT INTO %s (content, embedding) VALUES ($1, $2)", vectorTableName) + searchStmt := fmt.Sprintf("SELECT id, content, embedding <-> $1 AS distance FROM %s ORDER BY distance LIMIT 1", vectorTableName) + return insertStmt, searchStmt +} diff --git a/tests/postgres/postgres_integration_test.go b/tests/postgres/postgres_integration_test.go index 39c96507ad8..ea34a4a8bcd 100644 --- a/tests/postgres/postgres_integration_test.go +++ b/tests/postgres/postgres_integration_test.go @@ -111,6 +111,10 @@ func TestPostgres(t *testing.T) { teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams) defer teardownTable2(t) + // Set up table for semantic search + vectorTableName, tearDownVectorTable := tests.SetupPostgresVectorTable(t, ctx, pool) + defer tearDownVectorTable(t) + // Write config into a file and pass it to command toolsFile := tests.GetToolsConfig(sourceConfig, PostgresToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt) toolsFile = tests.AddExecuteSqlConfig(t, toolsFile, "postgres-execute-sql") @@ -118,6 +122,10 @@ func TestPostgres(t *testing.T) { toolsFile = tests.AddTemplateParamConfig(t, toolsFile, PostgresToolKind, tmplSelectCombined, tmplSelectFilterCombined, "") toolsFile = tests.AddPostgresPrebuiltConfig(t, toolsFile) + // Add semantic search tool config + insertStmt, searchStmt := tests.GetPostgresVectorSearchStmts(vectorTableName) + toolsFile = tests.AddSemanticSearchConfig(t, toolsFile, PostgresToolKind, insertStmt, searchStmt) + cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...) if err != nil { t.Fatalf("command initialization returned an error: %s", err) @@ -165,4 +173,5 @@ func TestPostgres(t *testing.T) { tests.RunPostgresListDatabaseStatsTest(t, ctx, pool) tests.RunPostgresListRolesTest(t, ctx, pool) tests.RunPostgresListStoredProcedureTest(t, ctx, pool) + tests.RunSemanticSearchToolInvokeTest(t, "null", "", "The quick brown fox") } diff --git a/tests/tool.go b/tests/tool.go index 488108fc624..6d839d0bf4e 100644 --- a/tests/tool.go +++ b/tests/tool.go @@ -1240,7 +1240,10 @@ func RunPostgresListTablesTest(t *testing.T, tableNameParam, tableNameAuth, user var filteredGot []any for _, item := range got { if tableMap, ok := item.(map[string]interface{}); ok { - if schema, ok := tableMap["schema_name"]; ok && schema == "public" { + name, _ := tableMap["object_name"].(string) + + // Only keep the table if it matches expected test tables + if name == tableNameParam || name == tableNameAuth { filteredGot = append(filteredGot, item) } }