Compare commits

..

9 Commits

Author SHA1 Message Date
rahulpinto19
1dfb9a5378 resolved the comments 2026-02-02 06:36:06 +00:00
rahulpinto19
7908c2256e test 2026-02-02 05:19:54 +00:00
rahulpinto19
a18ab29f4a teardown improvement 2026-01-30 10:07:45 +00:00
rahulpinto19
feee91a18d removed extra spaces 2026-01-30 09:48:06 +00:00
rahulpinto19
468470098e resolve lint failures 2026-01-30 09:25:40 +00:00
rahulpinto19
394c3b78bc patch1 2026-01-30 09:02:40 +00:00
rahulpinto19
4874c12d3f resolved the codesassist comments 2026-01-30 07:21:34 +00:00
rahulpinto19
d0d9b78b3b check before teardown 2026-01-30 06:54:26 +00:00
rahulpinto19
f6587cfaf8 add check before teardown 2026-01-30 06:26:12 +00:00
20 changed files with 295 additions and 356 deletions

View File

@@ -37,6 +37,7 @@ https://dev.mysql.com/doc/refman/8.4/en/sql-prepared-statements.html
https://dev.mysql.com/doc/refman/8.4/en/user-names.html
# npmjs links can occasionally trigger rate limiting during high-frequency CI builds
https://www.npmjs.com/package/@toolbox-sdk/server
https://www.npmjs.com/package/@toolbox-sdk/core
https://www.npmjs.com/package/@toolbox-sdk/adk
https://www.oceanbase.com/

View File

@@ -414,10 +414,10 @@ See [Usage Examples](../reference/cli.md#examples).
entries.
* **Dataplex Editor** (`roles/dataplex.editor`) to modify entries.
* **Tools:**
* `search_entries`: Searches for entries in Dataplex Catalog.
* `lookup_entry`: Retrieves a specific entry from Dataplex
* `dataplex_search_entries`: Searches for entries in Dataplex Catalog.
* `dataplex_lookup_entry`: Retrieves a specific entry from Dataplex
Catalog.
* `search_aspect_types`: Finds aspect types relevant to the
* `dataplex_search_aspect_types`: Finds aspect types relevant to the
query.
## Firestore

View File

@@ -19,6 +19,7 @@ import (
"errors"
"fmt"
"net/http"
"strings"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
@@ -234,10 +235,8 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
params, err := parameters.ParseParams(tool.GetParameters(), data, claimsFromAuth)
if err != nil {
// If auth error, return 401
errMsg := fmt.Sprintf("error parsing authenticated parameters from ID token: %w", err)
var clientServerErr *util.ClientServerError
if errors.As(err, &clientServerErr) && clientServerErr.Code == http.StatusUnauthorized {
s.logger.DebugContext(ctx, errMsg)
if errors.Is(err, util.ErrUnauthorized) {
s.logger.DebugContext(ctx, fmt.Sprintf("error parsing authenticated parameters from ID token: %s", err))
_ = render.Render(w, r, newErrResponse(err, http.StatusUnauthorized))
return
}
@@ -260,49 +259,34 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
// Determine what error to return to the users.
if err != nil {
var tbErr util.ToolboxError
errStr := err.Error()
var statusCode int
if errors.As(err, &tbErr) {
switch tbErr.Category() {
case util.CategoryAgent:
// Agent Errors -> 200 OK
s.logger.DebugContext(ctx, fmt.Sprintf("Tool invocation agent error: %v", err))
_ = render.Render(w, r, newErrResponse(err, http.StatusOK))
return
// Upstream API auth error propagation
switch {
case strings.Contains(errStr, "Error 401"):
statusCode = http.StatusUnauthorized
case strings.Contains(errStr, "Error 403"):
statusCode = http.StatusForbidden
}
case util.CategoryServer:
// Server Errors -> Check the specific code inside
var clientServerErr *util.ClientServerError
statusCode := http.StatusInternalServerError // Default to 500
if errors.As(err, &clientServerErr) {
if clientServerErr.Code != 0 {
statusCode = clientServerErr.Code
}
}
// Process auth error
if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden {
if clientAuth {
// Token error, pass through 401/403
s.logger.DebugContext(ctx, fmt.Sprintf("Client credentials lack authorization: %v", err))
_ = render.Render(w, r, newErrResponse(err, statusCode))
return
}
// ADC/Config error, return 500
statusCode = http.StatusInternalServerError
}
s.logger.ErrorContext(ctx, fmt.Sprintf("Tool invocation server error: %v", err))
if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden {
if clientAuth {
// Propagate the original 401/403 error.
s.logger.DebugContext(ctx, fmt.Sprintf("error invoking tool. Client credentials lack authorization to the source: %v", err))
_ = render.Render(w, r, newErrResponse(err, statusCode))
return
}
} else {
// Unknown error -> 500
s.logger.ErrorContext(ctx, fmt.Sprintf("Tool invocation unknown error: %v", err))
_ = render.Render(w, r, newErrResponse(err, http.StatusInternalServerError))
// ADC lacking permission or credentials configuration error.
internalErr := fmt.Errorf("unexpected auth error occured during Tool invocation: %w", err)
s.logger.ErrorContext(ctx, internalErr.Error())
_ = render.Render(w, r, newErrResponse(internalErr, http.StatusInternalServerError))
return
}
err = fmt.Errorf("error while invoking tool: %w", err)
s.logger.DebugContext(ctx, err.Error())
_ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest))
return
}
resMarshal, err := json.Marshal(res)

View File

@@ -23,6 +23,7 @@ import (
"fmt"
"io"
"net/http"
"strings"
"sync"
"time"
@@ -443,17 +444,15 @@ func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
code := rpcResponse.Error.Code
switch code {
case jsonrpc.INTERNAL_ERROR:
// Map Internal RPC Error (-32603) to HTTP 500
w.WriteHeader(http.StatusInternalServerError)
case jsonrpc.INVALID_REQUEST:
var clientServerErr *util.ClientServerError
if errors.As(err, &clientServerErr) {
switch clientServerErr.Code {
case http.StatusUnauthorized:
w.WriteHeader(http.StatusUnauthorized)
case http.StatusForbidden:
w.WriteHeader(http.StatusForbidden)
}
errStr := err.Error()
if errors.Is(err, util.ErrUnauthorized) {
w.WriteHeader(http.StatusUnauthorized)
} else if strings.Contains(errStr, "Error 401") {
w.WriteHeader(http.StatusUnauthorized)
} else if strings.Contains(errStr, "Error 403") {
w.WriteHeader(http.StatusForbidden)
}
}
}

View File

@@ -21,6 +21,7 @@ import (
"errors"
"fmt"
"net/http"
"strings"
"github.com/googleapis/genai-toolbox/internal/prompts"
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
@@ -123,11 +124,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
}
if clientAuth {
if accessToken == "" {
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.NewClientServerError(
"missing access token in the 'Authorization' header",
http.StatusUnauthorized,
nil,
)
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
}
}
@@ -175,11 +172,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
// Check if any of the specified auth services is verified
isAuthorized := tool.Authorized(verifiedAuthServices)
if !isAuthorized {
err = util.NewClientServerError(
"unauthorized Tool call: Please make sure you specify correct auth headers",
http.StatusUnauthorized,
nil,
)
err = fmt.Errorf("unauthorized Tool call: Please make sure your specify correct auth headers: %w", util.ErrUnauthorized)
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}
logger.DebugContext(ctx, "tool invocation authorized")
@@ -201,44 +194,30 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
// run tool invocation and generate response.
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
if err != nil {
var tbErr util.ToolboxError
if errors.As(err, &tbErr) {
switch tbErr.Category() {
case util.CategoryAgent:
// MCP - Tool execution error
// Return SUCCESS but with IsError: true
text := TextContent{
Type: "text",
Text: err.Error(),
}
return jsonrpc.JSONRPCResponse{
Jsonrpc: jsonrpc.JSONRPC_VERSION,
Id: id,
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
}, nil
case util.CategoryServer:
// MCP Spec - Protocol error
// Return JSON-RPC ERROR
var clientServerErr *util.ClientServerError
rpcCode := jsonrpc.INTERNAL_ERROR // Default to Internal Error (-32603)
if errors.As(err, &clientServerErr) {
if clientServerErr.Code == http.StatusUnauthorized || clientServerErr.Code == http.StatusForbidden {
if clientAuth {
rpcCode = jsonrpc.INVALID_REQUEST
} else {
rpcCode = jsonrpc.INTERNAL_ERROR
}
}
}
return jsonrpc.NewError(id, rpcCode, err.Error(), nil), err
errStr := err.Error()
// Missing authService tokens.
if errors.Is(err, util.ErrUnauthorized) {
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}
// Upstream auth error
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
if clientAuth {
// Error with client credentials should pass down to the client
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}
} else {
// Unknown error -> 500
// Auth error with ADC should raise internal 500 error
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
}
text := TextContent{
Type: "text",
Text: err.Error(),
}
return jsonrpc.JSONRPCResponse{
Jsonrpc: jsonrpc.JSONRPC_VERSION,
Id: id,
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
}, nil
}
content := make([]TextContent, 0)

View File

@@ -21,6 +21,7 @@ import (
"errors"
"fmt"
"net/http"
"strings"
"github.com/googleapis/genai-toolbox/internal/prompts"
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
@@ -123,11 +124,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
}
if clientAuth {
if accessToken == "" {
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.NewClientServerError(
"missing access token in the 'Authorization' header",
http.StatusUnauthorized,
nil,
)
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
}
}
@@ -175,11 +172,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
// Check if any of the specified auth services is verified
isAuthorized := tool.Authorized(verifiedAuthServices)
if !isAuthorized {
err = util.NewClientServerError(
"unauthorized Tool call: Please make sure you specify correct auth headers",
http.StatusUnauthorized,
nil,
)
err = fmt.Errorf("unauthorized Tool call: Please make sure your specify correct auth headers: %w", util.ErrUnauthorized)
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}
logger.DebugContext(ctx, "tool invocation authorized")
@@ -201,45 +194,31 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
// run tool invocation and generate response.
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
if err != nil {
var tbErr util.ToolboxError
if errors.As(err, &tbErr) {
switch tbErr.Category() {
case util.CategoryAgent:
// MCP - Tool execution error
// Return SUCCESS but with IsError: true
text := TextContent{
Type: "text",
Text: err.Error(),
}
return jsonrpc.JSONRPCResponse{
Jsonrpc: jsonrpc.JSONRPC_VERSION,
Id: id,
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
}, nil
case util.CategoryServer:
// MCP Spec - Protocol error
// Return JSON-RPC ERROR
var clientServerErr *util.ClientServerError
rpcCode := jsonrpc.INTERNAL_ERROR // Default to Internal Error (-32603)
if errors.As(err, &clientServerErr) {
if clientServerErr.Code == http.StatusUnauthorized || clientServerErr.Code == http.StatusForbidden {
if clientAuth {
rpcCode = jsonrpc.INVALID_REQUEST
} else {
rpcCode = jsonrpc.INTERNAL_ERROR
}
}
}
return jsonrpc.NewError(id, rpcCode, err.Error(), nil), err
errStr := err.Error()
// Missing authService tokens.
if errors.Is(err, util.ErrUnauthorized) {
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}
// Upstream auth error
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
if clientAuth {
// Error with client credentials should pass down to the client
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}
} else {
// Unknown error -> 500
// Auth error with ADC should raise internal 500 error
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
}
text := TextContent{
Type: "text",
Text: err.Error(),
}
return jsonrpc.JSONRPCResponse{
Jsonrpc: jsonrpc.JSONRPC_VERSION,
Id: id,
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
}, nil
}
content := make([]TextContent, 0)
sliceRes, ok := results.([]any)

View File

@@ -21,6 +21,7 @@ import (
"errors"
"fmt"
"net/http"
"strings"
"github.com/googleapis/genai-toolbox/internal/prompts"
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
@@ -116,12 +117,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
}
if clientAuth {
if accessToken == "" {
errMsg := "missing access token in the 'Authorization' header"
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, errMsg, nil), util.NewClientServerError(
errMsg,
http.StatusUnauthorized,
nil,
)
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
}
}
@@ -169,11 +165,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
// Check if any of the specified auth services is verified
isAuthorized := tool.Authorized(verifiedAuthServices)
if !isAuthorized {
err = util.NewClientServerError(
"unauthorized Tool call: Please make sure you specify correct auth headers",
http.StatusUnauthorized,
nil,
)
err = fmt.Errorf("unauthorized Tool call: Please make sure your specify correct auth headers: %w", util.ErrUnauthorized)
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}
logger.DebugContext(ctx, "tool invocation authorized")
@@ -195,44 +187,29 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
// run tool invocation and generate response.
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
if err != nil {
var tbErr util.ToolboxError
if errors.As(err, &tbErr) {
switch tbErr.Category() {
case util.CategoryAgent:
// MCP - Tool execution error
// Return SUCCESS but with IsError: true
text := TextContent{
Type: "text",
Text: err.Error(),
}
return jsonrpc.JSONRPCResponse{
Jsonrpc: jsonrpc.JSONRPC_VERSION,
Id: id,
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
}, nil
case util.CategoryServer:
// MCP Spec - Protocol error
// Return JSON-RPC ERROR
var clientServerErr *util.ClientServerError
rpcCode := jsonrpc.INTERNAL_ERROR // Default to Internal Error (-32603)
if errors.As(err, &clientServerErr) {
if clientServerErr.Code == http.StatusUnauthorized || clientServerErr.Code == http.StatusForbidden {
if clientAuth {
rpcCode = jsonrpc.INVALID_REQUEST
} else {
rpcCode = jsonrpc.INTERNAL_ERROR
}
}
}
return jsonrpc.NewError(id, rpcCode, err.Error(), nil), err
errStr := err.Error()
// Missing authService tokens.
if errors.Is(err, util.ErrUnauthorized) {
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}
// Upstream auth error
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
if clientAuth {
// Error with client credentials should pass down to the client
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}
} else {
// Unknown error -> 500
// Auth error with ADC should raise internal 500 error
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
}
text := TextContent{
Type: "text",
Text: err.Error(),
}
return jsonrpc.JSONRPCResponse{
Jsonrpc: jsonrpc.JSONRPC_VERSION,
Id: id,
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
}, nil
}
content := make([]TextContent, 0)

View File

@@ -21,6 +21,7 @@ import (
"errors"
"fmt"
"net/http"
"strings"
"github.com/googleapis/genai-toolbox/internal/prompts"
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
@@ -116,11 +117,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
}
if clientAuth {
if accessToken == "" {
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.NewClientServerError(
"missing access token in the 'Authorization' header",
http.StatusUnauthorized,
nil,
)
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
}
}
@@ -168,11 +165,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
// Check if any of the specified auth services is verified
isAuthorized := tool.Authorized(verifiedAuthServices)
if !isAuthorized {
err = util.NewClientServerError(
"unauthorized Tool call: Please make sure you specify correct auth headers",
http.StatusUnauthorized,
nil,
)
err = fmt.Errorf("unauthorized Tool call: Please make sure your specify correct auth headers: %w", util.ErrUnauthorized)
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}
logger.DebugContext(ctx, "tool invocation authorized")
@@ -194,44 +187,29 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
// run tool invocation and generate response.
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
if err != nil {
var tbErr util.ToolboxError
if errors.As(err, &tbErr) {
switch tbErr.Category() {
case util.CategoryAgent:
// MCP - Tool execution error
// Return SUCCESS but with IsError: true
text := TextContent{
Type: "text",
Text: err.Error(),
}
return jsonrpc.JSONRPCResponse{
Jsonrpc: jsonrpc.JSONRPC_VERSION,
Id: id,
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
}, nil
case util.CategoryServer:
// MCP Spec - Protocol error
// Return JSON-RPC ERROR
var clientServerErr *util.ClientServerError
rpcCode := jsonrpc.INTERNAL_ERROR // Default to Internal Error (-32603)
if errors.As(err, &clientServerErr) {
if clientServerErr.Code == http.StatusUnauthorized || clientServerErr.Code == http.StatusForbidden {
if clientAuth {
rpcCode = jsonrpc.INVALID_REQUEST
} else {
rpcCode = jsonrpc.INTERNAL_ERROR
}
}
}
return jsonrpc.NewError(id, rpcCode, err.Error(), nil), err
errStr := err.Error()
// Missing authService tokens.
if errors.Is(err, util.ErrUnauthorized) {
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}
// Upstream auth error
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
if clientAuth {
// Error with client credentials should pass down to the client
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}
} else {
// Unknown error -> 500
// Auth error with ADC should raise internal 500 error
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
}
text := TextContent{
Type: "text",
Text: err.Error(),
}
return jsonrpc.JSONRPCResponse{
Jsonrpc: jsonrpc.JSONRPC_VERSION,
Id: id,
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
}, nil
}
content := make([]TextContent, 0)

View File

@@ -184,7 +184,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if source.UseClientAuthorization() {
// Use client-side access token
if accessToken == "" {
return nil, util.NewClientServerError("tool is configured for client OAuth but no token was provided in the request header", http.StatusUnauthorized, nil)
return nil, fmt.Errorf("tool is configured for client OAuth but no token was provided in the request header: %w", util.ErrUnauthorized)
}
tokenStr, err = accessToken.ParseBearerToken()
if err != nil {

View File

@@ -17,7 +17,6 @@ package tools
import (
"context"
"fmt"
"net/http"
"slices"
"strings"
@@ -81,7 +80,7 @@ type AccessToken string
func (token AccessToken) ParseBearerToken() (string, error) {
headerParts := strings.Split(string(token), " ")
if len(headerParts) != 2 || strings.ToLower(headerParts[0]) != "bearer" {
return "", util.NewClientServerError("authorization header must be in the format 'Bearer <token>'", http.StatusUnauthorized, nil)
return "", fmt.Errorf("authorization header must be in the format 'Bearer <token>': %w", util.ErrUnauthorized)
}
return headerParts[1], nil
}

View File

@@ -1,61 +0,0 @@
// Copyright 2026 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package util
import "fmt"
type ErrorCategory string
const (
CategoryAgent ErrorCategory = "AGENT_ERROR"
CategoryServer ErrorCategory = "SERVER_ERROR"
)
// ToolboxError is the interface all custom errors must satisfy
type ToolboxError interface {
error
Category() ErrorCategory
}
// Agent Errors return 200 to the sender
type AgentError struct {
Msg string
Cause error
}
func (e *AgentError) Error() string { return e.Msg }
func (e *AgentError) Category() ErrorCategory { return CategoryAgent }
func (e *AgentError) Unwrap() error { return e.Cause }
func NewAgentError(msg string, cause error) *AgentError {
return &AgentError{Msg: msg, Cause: cause}
}
// ClientServerError returns 4XX/5XX error code
type ClientServerError struct {
Msg string
Code int
Cause error
}
func (e *ClientServerError) Error() string { return fmt.Sprintf("%s: %v", e.Msg, e.Cause) }
func (e *ClientServerError) Category() ErrorCategory { return CategoryServer }
func (e *ClientServerError) Unwrap() error { return e.Cause }
func NewClientServerError(msg string, code int, cause error) *ClientServerError {
return &ClientServerError{Msg: msg, Code: code, Cause: cause}
}

View File

@@ -19,7 +19,6 @@ import (
"context"
"encoding/json"
"fmt"
"net/http"
"reflect"
"regexp"
"slices"
@@ -119,7 +118,7 @@ func parseFromAuthService(paramAuthServices []ParamAuthService, claimsMap map[st
}
return v, nil
}
return nil, util.NewClientServerError("missing or invalid authentication header", http.StatusUnauthorized, nil)
return nil, fmt.Errorf("missing or invalid authentication header: %w", util.ErrUnauthorized)
}
// CheckParamRequired checks if a parameter is required based on the required and default field.

View File

@@ -17,6 +17,7 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
@@ -187,3 +188,5 @@ func InstrumentationFromContext(ctx context.Context) (*telemetry.Instrumentation
}
return nil, fmt.Errorf("unable to retrieve instrumentation")
}
var ErrUnauthorized = errors.New("unauthorized")

View File

@@ -139,13 +139,24 @@ func TestAlloyDBPgToolEndpoints(t *testing.T) {
// set up data for param tool
createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := tests.GetPostgresSQLParamToolInfo(tableNameParam)
teardownTable1 := tests.SetupPostgresSQLTable(t, ctx, pool, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams)
defer teardownTable1(t)
teardownTable1, err := tests.SetupPostgresSQLTable(t, ctx, pool, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams)
if teardownTable1 != nil {
defer teardownTable1(t)
}
if err != nil {
t.Fatalf("Setup failed: %v", err)
}
// set up data for auth tool
createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := tests.GetPostgresSQLAuthToolInfo(tableNameAuth)
teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
defer teardownTable2(t)
teardownTable2, err := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
if teardownTable2 != nil {
defer teardownTable2(t)
}
if err != nil {
t.Fatalf("Setup failed: %v", err)
}
// Set up table for semanti search
vectorTableName, tearDownVectorTable := tests.SetupPostgresVectorTable(t, ctx, pool)

View File

@@ -84,7 +84,6 @@ func TestBigQueryToolEndpoints(t *testing.T) {
if err != nil {
t.Fatalf("unable to create Cloud SQL connection pool: %s", err)
}
// create table name with UUID
datasetName := fmt.Sprintf("temp_toolbox_test_%s", strings.ReplaceAll(uuid.New().String(), "-", ""))
tableName := fmt.Sprintf("param_table_%s", strings.ReplaceAll(uuid.New().String(), "-", ""))
@@ -122,27 +121,42 @@ func TestBigQueryToolEndpoints(t *testing.T) {
// set up data for param tool
createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := getBigQueryParamToolInfo(tableNameParam)
teardownTable1 := setupBigQueryTable(t, ctx, client, createParamTableStmt, insertParamTableStmt, datasetName, tableNameParam, paramTestParams)
teardownTable1, err := setupBigQueryTable(t, ctx, client, createParamTableStmt, insertParamTableStmt, datasetName, tableNameParam, paramTestParams)
if err != nil {
t.Fatalf("failed to setup param table: %s", err)
}
defer teardownTable1(t)
// set up data for auth tool
createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := getBigQueryAuthToolInfo(tableNameAuth)
teardownTable2 := setupBigQueryTable(t, ctx, client, createAuthTableStmt, insertAuthTableStmt, datasetName, tableNameAuth, authTestParams)
teardownTable2, err := setupBigQueryTable(t, ctx, client, createAuthTableStmt, insertAuthTableStmt, datasetName, tableNameAuth, authTestParams)
if err != nil {
t.Fatalf("failed to setup auth table: %s", err)
}
defer teardownTable2(t)
// set up data for data type test tool
createDataTypeTableStmt, insertDataTypeTableStmt, dataTypeToolStmt, arrayDataTypeToolStmt, dataTypeTestParams := getBigQueryDataTypeTestInfo(tableNameDataType)
teardownTable3 := setupBigQueryTable(t, ctx, client, createDataTypeTableStmt, insertDataTypeTableStmt, datasetName, tableNameDataType, dataTypeTestParams)
teardownTable3, err := setupBigQueryTable(t, ctx, client, createDataTypeTableStmt, insertDataTypeTableStmt, datasetName, tableNameDataType, dataTypeTestParams)
if err != nil {
t.Fatalf("failed to setup data type table: %s", err)
}
defer teardownTable3(t)
// set up data for forecast tool
createForecastTableStmt, insertForecastTableStmt, forecastTestParams := getBigQueryForecastToolInfo(tableNameForecast)
teardownTable4 := setupBigQueryTable(t, ctx, client, createForecastTableStmt, insertForecastTableStmt, datasetName, tableNameForecast, forecastTestParams)
teardownTable4, err := setupBigQueryTable(t, ctx, client, createForecastTableStmt, insertForecastTableStmt, datasetName, tableNameForecast, forecastTestParams)
if err != nil {
t.Fatalf("failed to setup forecast table: %s", err)
}
defer teardownTable4(t)
// set up data for analyze contribution tool
createAnalyzeContributionTableStmt, insertAnalyzeContributionTableStmt, analyzeContributionTestParams := getBigQueryAnalyzeContributionToolInfo(tableNameAnalyzeContribution)
teardownTable5 := setupBigQueryTable(t, ctx, client, createAnalyzeContributionTableStmt, insertAnalyzeContributionTableStmt, datasetName, tableNameAnalyzeContribution, analyzeContributionTestParams)
teardownTable5, err := setupBigQueryTable(t, ctx, client, createAnalyzeContributionTableStmt, insertAnalyzeContributionTableStmt, datasetName, tableNameAnalyzeContribution, analyzeContributionTestParams)
if err != nil {
t.Fatalf("failed to setup analyze contribution table: %s", err)
}
defer teardownTable5(t)
// Write config into a file and pass it to command
@@ -231,52 +245,79 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
// Setup allowed table
allowedTableNameParam1 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName1, allowedTableName1)
createAllowedTableStmt1 := fmt.Sprintf("CREATE TABLE %s (id INT64)", allowedTableNameParam1)
teardownAllowed1 := setupBigQueryTable(t, ctx, client, createAllowedTableStmt1, "", allowedDatasetName1, allowedTableNameParam1, nil)
teardownAllowed1, err:= setupBigQueryTable(t, ctx, client, createAllowedTableStmt1, "", allowedDatasetName1, allowedTableNameParam1, nil)
if err != nil {
t.Fatalf("failed to setup allowed table 1: %s", err)
}
defer teardownAllowed1(t)
allowedTableNameParam2 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName2, allowedTableName2)
createAllowedTableStmt2 := fmt.Sprintf("CREATE TABLE %s (id INT64)", allowedTableNameParam2)
teardownAllowed2 := setupBigQueryTable(t, ctx, client, createAllowedTableStmt2, "", allowedDatasetName2, allowedTableNameParam2, nil)
teardownAllowed2, err:= setupBigQueryTable(t, ctx, client, createAllowedTableStmt2, "", allowedDatasetName2, allowedTableNameParam2, nil)
if err != nil {
t.Fatalf("failed to setup allowed table 2: %s", err)
}
defer teardownAllowed2(t)
// Setup allowed forecast table
allowedForecastTableFullName1 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName1, allowedForecastTableName1)
createForecastStmt1, insertForecastStmt1, forecastParams1 := getBigQueryForecastToolInfo(allowedForecastTableFullName1)
teardownAllowedForecast1 := setupBigQueryTable(t, ctx, client, createForecastStmt1, insertForecastStmt1, allowedDatasetName1, allowedForecastTableFullName1, forecastParams1)
teardownAllowedForecast1, err:= setupBigQueryTable(t, ctx, client, createForecastStmt1, insertForecastStmt1, allowedDatasetName1, allowedForecastTableFullName1, forecastParams1)
if err != nil {
t.Fatalf("failed to setup allowed forecast table 1: %s", err)
}
defer teardownAllowedForecast1(t)
allowedForecastTableFullName2 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName2, allowedForecastTableName2)
createForecastStmt2, insertForecastStmt2, forecastParams2 := getBigQueryForecastToolInfo(allowedForecastTableFullName2)
teardownAllowedForecast2 := setupBigQueryTable(t, ctx, client, createForecastStmt2, insertForecastStmt2, allowedDatasetName2, allowedForecastTableFullName2, forecastParams2)
teardownAllowedForecast2, err:= setupBigQueryTable(t, ctx, client, createForecastStmt2, insertForecastStmt2, allowedDatasetName2, allowedForecastTableFullName2, forecastParams2)
if err != nil {
t.Fatalf("failed to setup allowed forecast table 2: %s", err)
}
defer teardownAllowedForecast2(t)
// Setup disallowed table
disallowedTableNameParam := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, disallowedDatasetName, disallowedTableName)
createDisallowedTableStmt := fmt.Sprintf("CREATE TABLE %s (id INT64)", disallowedTableNameParam)
teardownDisallowed := setupBigQueryTable(t, ctx, client, createDisallowedTableStmt, "", disallowedDatasetName, disallowedTableNameParam, nil)
teardownDisallowed, err:= setupBigQueryTable(t, ctx, client, createDisallowedTableStmt, "", disallowedDatasetName, disallowedTableNameParam, nil)
if err != nil {
t.Fatalf("failed to setup disallowed table: %s", err)
}
defer teardownDisallowed(t)
// Setup disallowed forecast table
disallowedForecastTableFullName := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, disallowedDatasetName, disallowedForecastTableName)
createDisallowedForecastStmt, insertDisallowedForecastStmt, disallowedForecastParams := getBigQueryForecastToolInfo(disallowedForecastTableFullName)
teardownDisallowedForecast := setupBigQueryTable(t, ctx, client, createDisallowedForecastStmt, insertDisallowedForecastStmt, disallowedDatasetName, disallowedForecastTableFullName, disallowedForecastParams)
teardownDisallowedForecast, err:= setupBigQueryTable(t, ctx, client, createDisallowedForecastStmt, insertDisallowedForecastStmt, disallowedDatasetName, disallowedForecastTableFullName, disallowedForecastParams)
if err != nil {
t.Fatalf("failed to setup disallowed forecast table: %s", err)
}
defer teardownDisallowedForecast(t)
// Setup allowed analyze contribution table
allowedAnalyzeContributionTableFullName1 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName1, allowedAnalyzeContributionTableName1)
createAnalyzeContributionStmt1, insertAnalyzeContributionStmt1, analyzeContributionParams1 := getBigQueryAnalyzeContributionToolInfo(allowedAnalyzeContributionTableFullName1)
teardownAllowedAnalyzeContribution1 := setupBigQueryTable(t, ctx, client, createAnalyzeContributionStmt1, insertAnalyzeContributionStmt1, allowedDatasetName1, allowedAnalyzeContributionTableFullName1, analyzeContributionParams1)
teardownAllowedAnalyzeContribution1, err:= setupBigQueryTable(t, ctx, client, createAnalyzeContributionStmt1, insertAnalyzeContributionStmt1, allowedDatasetName1, allowedAnalyzeContributionTableFullName1, analyzeContributionParams1)
if err != nil {
t.Fatalf("failed to setup allowed analyze contribution table 1: %s", err)
}
defer teardownAllowedAnalyzeContribution1(t)
allowedAnalyzeContributionTableFullName2 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName2, allowedAnalyzeContributionTableName2)
createAnalyzeContributionStmt2, insertAnalyzeContributionStmt2, analyzeContributionParams2 := getBigQueryAnalyzeContributionToolInfo(allowedAnalyzeContributionTableFullName2)
teardownAllowedAnalyzeContribution2 := setupBigQueryTable(t, ctx, client, createAnalyzeContributionStmt2, insertAnalyzeContributionStmt2, allowedDatasetName2, allowedAnalyzeContributionTableFullName2, analyzeContributionParams2)
teardownAllowedAnalyzeContribution2, err:= setupBigQueryTable(t, ctx, client, createAnalyzeContributionStmt2, insertAnalyzeContributionStmt2, allowedDatasetName2, allowedAnalyzeContributionTableFullName2, analyzeContributionParams2)
if err != nil {
t.Fatalf("failed to setup allowed analyze contribution table 2: %s", err)
}
defer teardownAllowedAnalyzeContribution2(t)
// Setup disallowed analyze contribution table
disallowedAnalyzeContributionTableFullName := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, disallowedDatasetName, disallowedAnalyzeContributionTableName)
createDisallowedAnalyzeContributionStmt, insertDisallowedAnalyzeContributionStmt, disallowedAnalyzeContributionParams := getBigQueryAnalyzeContributionToolInfo(disallowedAnalyzeContributionTableFullName)
teardownDisallowedAnalyzeContribution := setupBigQueryTable(t, ctx, client, createDisallowedAnalyzeContributionStmt, insertDisallowedAnalyzeContributionStmt, disallowedDatasetName, disallowedAnalyzeContributionTableFullName, disallowedAnalyzeContributionParams)
teardownDisallowedAnalyzeContribution, err:= setupBigQueryTable(t, ctx, client, createDisallowedAnalyzeContributionStmt, insertDisallowedAnalyzeContributionStmt, disallowedDatasetName, disallowedAnalyzeContributionTableFullName, disallowedAnalyzeContributionParams)
if err != nil {
t.Fatalf("failed to setup disallowed analyze contribution table: %s", err)
}
defer teardownDisallowedAnalyzeContribution(t)
// Configure source with dataset restriction.
@@ -438,7 +479,10 @@ func TestBigQueryWriteModeBlocked(t *testing.T) {
t.Fatalf("unable to create BigQuery connection: %s", err)
}
createParamTableStmt, insertParamTableStmt, _, _, _, _, paramTestParams := getBigQueryParamToolInfo(tableNameParam)
teardownTable := setupBigQueryTable(t, ctx, client, createParamTableStmt, insertParamTableStmt, datasetName, tableNameParam, paramTestParams)
teardownTable ,err:= setupBigQueryTable(t, ctx, client, createParamTableStmt, insertParamTableStmt, datasetName, tableNameParam, paramTestParams)
if err != nil {
t.Fatalf("failed to setup BigQuery table: %s", err)
}
defer teardownTable(t)
toolsFile := map[string]any{
@@ -623,7 +667,7 @@ func getBigQueryTmplToolStatement() (string, string) {
return tmplSelectCombined, tmplSelectFilterCombined
}
func setupBigQueryTable(t *testing.T, ctx context.Context, client *bigqueryapi.Client, createStatement, insertStatement, datasetName string, tableName string, params []bigqueryapi.QueryParameter) func(*testing.T) {
func setupBigQueryTable(t *testing.T, ctx context.Context, client *bigqueryapi.Client, createStatement, insertStatement, datasetName string, tableName string, params []bigqueryapi.QueryParameter) (func(*testing.T), error) {
// Create dataset
dataset := client.Dataset(datasetName)
_, err := dataset.Metadata(ctx)
@@ -699,7 +743,7 @@ func setupBigQueryTable(t *testing.T, ctx context.Context, client *bigqueryapi.C
} else if err != nil {
t.Errorf("Failed to list tables in dataset %s to check emptiness: %v.", datasetName, err)
}
}
}, nil
}
func addBigQueryPrebuiltToolsConfig(t *testing.T, config map[string]any) map[string]any {

View File

@@ -124,13 +124,23 @@ func TestCloudSQLPgSimpleToolEndpoints(t *testing.T) {
// set up data for param tool
createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := tests.GetPostgresSQLParamToolInfo(tableNameParam)
teardownTable1 := tests.SetupPostgresSQLTable(t, ctx, pool, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams)
defer teardownTable1(t)
teardownTable1, err := tests.SetupPostgresSQLTable(t, ctx, pool, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams)
if teardownTable1 != nil {
defer teardownTable1(t)
}
if err != nil {
t.Fatalf("Setup failed: %v", err)
}
// set up data for auth tool
createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := tests.GetPostgresSQLAuthToolInfo(tableNameAuth)
teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
defer teardownTable2(t)
teardownTable2, err := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
if teardownTable2 != nil {
defer teardownTable2(t)
}
if err != nil {
t.Fatalf("Setup failed: %v", err)
}
// Set up table for semantic search
vectorTableName, tearDownVectorTable := tests.SetupPostgresVectorTable(t, ctx, pool)

View File

@@ -613,31 +613,36 @@ func GetMySQLWants() (string, string, string, string) {
// SetupPostgresSQLTable creates and inserts data into a table of tool
// compatible with postgres-sql tool
func SetupPostgresSQLTable(t *testing.T, ctx context.Context, pool *pgxpool.Pool, createStatement, insertStatement, tableName string, params []any) func(*testing.T) {
func SetupPostgresSQLTable(t *testing.T, ctx context.Context, pool *pgxpool.Pool, createStatement, insertStatement, tableName string, params []any) (func(*testing.T), error) {
err := pool.Ping(ctx)
if err != nil {
t.Fatalf("unable to connect to test database: %s", err)
// Return nil for the function and the error itself
return nil, fmt.Errorf("unable to connect to test database: %w", err)
}
// Create table
_, err = pool.Query(ctx, createStatement)
_, err = pool.Exec(ctx, createStatement)
if err != nil {
t.Fatalf("unable to create test table %s: %s", tableName, err)
return nil, fmt.Errorf("unable to create test table %s: %w", tableName, err)
}
// Insert test data
_, err = pool.Query(ctx, insertStatement, params...)
_, err = pool.Exec(ctx, insertStatement, params...)
if err != nil {
t.Fatalf("unable to insert test data: %s", err)
// partially cleanup if insert fails
teardown := func(t *testing.T) {
_, _ = pool.Exec(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s;", tableName))
}
return teardown, fmt.Errorf("unable to insert test data: %w", err)
}
// Return the cleanup function and nil for error
return func(t *testing.T) {
// tear down test
_, err = pool.Exec(ctx, fmt.Sprintf("DROP TABLE %s;", tableName))
_, err = pool.Exec(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s;", tableName))
if err != nil {
t.Errorf("Teardown failed: %s", err)
}
}
}, nil
}
// SetupMsSQLTable creates and inserts data into a table of tool

View File

@@ -89,12 +89,18 @@ func TestOracleSimpleToolEndpoints(t *testing.T) {
// set up data for param tool
createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := getOracleParamToolInfo(tableNameParam)
teardownTable1 := setupOracleTable(t, ctx, db, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams)
teardownTable1, err := setupOracleTable(t, ctx, db, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams)
if err != nil {
t.Fatalf("failed to setup Oracle table %s: %v", tableNameParam, err)
}
defer teardownTable1(t)
// set up data for auth tool
createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := getOracleAuthToolInfo(tableNameAuth)
teardownTable2 := setupOracleTable(t, ctx, db, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
teardownTable2, err := setupOracleTable(t, ctx, db, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
if err != nil {
t.Fatalf("failed to setup Oracle table %s: %v", tableNameAuth, err)
}
defer teardownTable2(t)
// Write config into a file and pass it to command
@@ -135,31 +141,31 @@ func TestOracleSimpleToolEndpoints(t *testing.T) {
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam)
}
func setupOracleTable(t *testing.T, ctx context.Context, pool *sql.DB, createStatement, insertStatement, tableName string, params []any) func(*testing.T) {
func setupOracleTable(t *testing.T, ctx context.Context, pool *sql.DB, createStatement, insertStatement, tableName string, params []any) (func(*testing.T), error) {
err := pool.PingContext(ctx)
if err != nil {
t.Fatalf("unable to connect to test database: %s", err)
return nil, fmt.Errorf("unable to connect to test database: %w", err)
}
// Create table
_, err = pool.QueryContext(ctx, createStatement)
if err != nil {
t.Fatalf("unable to create test table %s: %s", tableName, err)
return nil, fmt.Errorf("unable to create test table %s: %w", tableName, err)
}
// Insert test data
_, err = pool.QueryContext(ctx, insertStatement, params...)
if err != nil {
t.Fatalf("unable to insert test data: %s", err)
return nil, fmt.Errorf("unable to insert test data: %w", err)
}
return func(t *testing.T) {
// tear down test
_, err = pool.ExecContext(ctx, fmt.Sprintf("DROP TABLE %s", tableName))
_, err = pool.ExecContext(ctx, fmt.Sprintf("DROP TABLE %s CASCADE CONSTRAINTS", tableName))
if err != nil {
t.Errorf("Teardown failed: %s", err)
}
}
}, nil
}
func getOracleParamToolInfo(tableName string) (string, string, string, string, string, string, []any) {

View File

@@ -103,13 +103,24 @@ func TestPostgres(t *testing.T) {
// set up data for param tool
createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := tests.GetPostgresSQLParamToolInfo(tableNameParam)
teardownTable1 := tests.SetupPostgresSQLTable(t, ctx, pool, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams)
defer teardownTable1(t)
// teardownTable1 := tests.SetupPostgresSQLTable(t, ctx, pool, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams)
teardownTable1, err := tests.SetupPostgresSQLTable(t, ctx, pool, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams)
if teardownTable1 != nil {
defer teardownTable1(t)
}
if err != nil {
t.Fatalf("Setup failed: %v", err)
}
// set up data for auth tool
createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := tests.GetPostgresSQLAuthToolInfo(tableNameAuth)
teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
defer teardownTable2(t)
teardownTable2, err := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
if teardownTable2 != nil {
defer teardownTable2(t)
}
if err != nil {
t.Fatalf("Setup failed: %v", err)
}
// Set up table for semantic search
vectorTableName, tearDownVectorTable := tests.SetupPostgresVectorTable(t, ctx, pool)

View File

@@ -115,23 +115,35 @@ func TestSpannerToolEndpoints(t *testing.T) {
SpannerInstance,
SpannerDatabase,
)
teardownTable1 := setupSpannerTable(t, ctx, adminClient, dataClient, createParamTableStmt, insertParamTableStmt, tableNameParam, dbString, paramTestParams)
teardownTable1, err := setupSpannerTable(t, ctx, adminClient, dataClient, createParamTableStmt, insertParamTableStmt, tableNameParam, dbString, paramTestParams)
if err != nil {
t.Fatalf("failed to setup Spanner table %s: %v", tableNameParam, err)
}
defer teardownTable1(t)
// set up data for auth tool
createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := getSpannerAuthToolInfo(tableNameAuth)
teardownTable2 := setupSpannerTable(t, ctx, adminClient, dataClient, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, dbString, authTestParams)
teardownTable2, err := setupSpannerTable(t, ctx, adminClient, dataClient, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, dbString, authTestParams)
if err != nil {
t.Fatalf("failed to setup Spanner table %s: %v", tableNameAuth, err)
}
defer teardownTable2(t)
// set up data for template param tool
createStatementTmpl := fmt.Sprintf("CREATE TABLE %s (id INT64, name STRING(MAX), age INT64) PRIMARY KEY (id)", tableNameTemplateParam)
teardownTableTmpl := setupSpannerTable(t, ctx, adminClient, dataClient, createStatementTmpl, "", tableNameTemplateParam, dbString, nil)
teardownTableTmpl, err := setupSpannerTable(t, ctx, adminClient, dataClient, createStatementTmpl, "", tableNameTemplateParam, dbString, nil)
if err != nil {
t.Fatalf("failed to setup Spanner table %s: %v", tableNameTemplateParam, err)
}
defer teardownTableTmpl(t)
// set up for graph tool
nodeTableName := "node_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
createNodeStatementTmpl := fmt.Sprintf("CREATE TABLE %s (id INT64 NOT NULL) PRIMARY KEY (id)", nodeTableName)
teardownNodeTableTmpl := setupSpannerTable(t, ctx, adminClient, dataClient, createNodeStatementTmpl, "", nodeTableName, dbString, nil)
teardownNodeTableTmpl, err := setupSpannerTable(t, ctx, adminClient, dataClient, createNodeStatementTmpl, "", nodeTableName, dbString, nil)
if err != nil {
t.Fatalf("failed to setup Spanner table %s: %v", nodeTableName, err)
}
defer teardownNodeTableTmpl(t)
edgeTableName := "edge_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
@@ -143,7 +155,10 @@ func TestSpannerToolEndpoints(t *testing.T) {
) PRIMARY KEY (id, target_id),
INTERLEAVE IN PARENT %[2]s ON DELETE CASCADE
`, edgeTableName, nodeTableName)
teardownEdgeTableTmpl := setupSpannerTable(t, ctx, adminClient, dataClient, createEdgeStatementTmpl, "", edgeTableName, dbString, nil)
teardownEdgeTableTmpl, err := setupSpannerTable(t, ctx, adminClient, dataClient, createEdgeStatementTmpl, "", edgeTableName, dbString, nil)
if err != nil {
t.Fatalf("failed to setup Spanner table %s: %v", edgeTableName, err)
}
defer teardownEdgeTableTmpl(t)
graphName := "graph_" + strings.ReplaceAll(uuid.New().String(), "-", "")
@@ -243,7 +258,7 @@ func getSpannerAuthToolInfo(tableName string) (string, string, string, map[strin
// setupSpannerTable creates and inserts data into a table of tool
// compatible with spanner-sql tool
func setupSpannerTable(t *testing.T, ctx context.Context, adminClient *database.DatabaseAdminClient, dataClient *spanner.Client, createStatement, insertStatement, tableName, dbString string, params map[string]any) func(*testing.T) {
func setupSpannerTable(t *testing.T, ctx context.Context, adminClient *database.DatabaseAdminClient, dataClient *spanner.Client, createStatement, insertStatement, tableName, dbString string, params map[string]any) (func(*testing.T), error) {
// Create table
op, err := adminClient.UpdateDatabaseDdl(ctx, &databasepb.UpdateDatabaseDdlRequest{
@@ -251,11 +266,11 @@ func setupSpannerTable(t *testing.T, ctx context.Context, adminClient *database.
Statements: []string{createStatement},
})
if err != nil {
t.Fatalf("unable to start create table operation %s: %s", tableName, err)
return nil, fmt.Errorf("unable to start create table operation %s: %w", tableName, err)
}
err = op.Wait(ctx)
if err != nil {
t.Fatalf("unable to create test table %s: %s", tableName, err)
return nil, fmt.Errorf("unable to create test table %s: %w", tableName, err)
}
// Insert test data
@@ -269,7 +284,7 @@ func setupSpannerTable(t *testing.T, ctx context.Context, adminClient *database.
return err
})
if err != nil {
t.Fatalf("unable to insert test data: %s", err)
return nil, fmt.Errorf("unable to insert test data: %w", err)
}
}
@@ -288,7 +303,7 @@ func setupSpannerTable(t *testing.T, ctx context.Context, adminClient *database.
if opErr != nil {
t.Errorf("Teardown failed: %s", opErr)
}
}
}, nil
}
// setupSpannerGraph creates a graph and inserts data into it.