Compare commits

..

9 Commits

Author SHA1 Message Date
rahulpinto19
1dfb9a5378 resolved the comments 2026-02-02 06:36:06 +00:00
rahulpinto19
7908c2256e test 2026-02-02 05:19:54 +00:00
rahulpinto19
a18ab29f4a teardown improvement 2026-01-30 10:07:45 +00:00
rahulpinto19
feee91a18d removed extra spaces 2026-01-30 09:48:06 +00:00
rahulpinto19
468470098e resolve lint failures 2026-01-30 09:25:40 +00:00
rahulpinto19
394c3b78bc patch1 2026-01-30 09:02:40 +00:00
rahulpinto19
4874c12d3f resolved the codesassist comments 2026-01-30 07:21:34 +00:00
rahulpinto19
d0d9b78b3b check before teardown 2026-01-30 06:54:26 +00:00
rahulpinto19
f6587cfaf8 add check before teardown 2026-01-30 06:26:12 +00:00
21 changed files with 298 additions and 356 deletions

View File

@@ -37,6 +37,7 @@ https://dev.mysql.com/doc/refman/8.4/en/sql-prepared-statements.html
https://dev.mysql.com/doc/refman/8.4/en/user-names.html https://dev.mysql.com/doc/refman/8.4/en/user-names.html
# npmjs links can occasionally trigger rate limiting during high-frequency CI builds # npmjs links can occasionally trigger rate limiting during high-frequency CI builds
https://www.npmjs.com/package/@toolbox-sdk/server
https://www.npmjs.com/package/@toolbox-sdk/core https://www.npmjs.com/package/@toolbox-sdk/core
https://www.npmjs.com/package/@toolbox-sdk/adk https://www.npmjs.com/package/@toolbox-sdk/adk
https://www.oceanbase.com/ https://www.oceanbase.com/

View File

@@ -414,10 +414,10 @@ See [Usage Examples](../reference/cli.md#examples).
entries. entries.
* **Dataplex Editor** (`roles/dataplex.editor`) to modify entries. * **Dataplex Editor** (`roles/dataplex.editor`) to modify entries.
* **Tools:** * **Tools:**
* `search_entries`: Searches for entries in Dataplex Catalog. * `dataplex_search_entries`: Searches for entries in Dataplex Catalog.
* `lookup_entry`: Retrieves a specific entry from Dataplex * `dataplex_lookup_entry`: Retrieves a specific entry from Dataplex
Catalog. Catalog.
* `search_aspect_types`: Finds aspect types relevant to the * `dataplex_search_aspect_types`: Finds aspect types relevant to the
query. query.
## Firestore ## Firestore

View File

@@ -19,6 +19,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"strings"
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware" "github.com/go-chi/chi/v5/middleware"
@@ -233,11 +234,10 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
params, err := parameters.ParseParams(tool.GetParameters(), data, claimsFromAuth) params, err := parameters.ParseParams(tool.GetParameters(), data, claimsFromAuth)
if err != nil { if err != nil {
// If auth error, return 401 or 403 // If auth error, return 401
var clientServerErr *util.ClientServerError if errors.Is(err, util.ErrUnauthorized) {
if errors.As(err, &clientServerErr) && (clientServerErr.Code == http.StatusUnauthorized || clientServerErr.Code == http.StatusForbidden) {
s.logger.DebugContext(ctx, fmt.Sprintf("error parsing authenticated parameters from ID token: %s", err)) s.logger.DebugContext(ctx, fmt.Sprintf("error parsing authenticated parameters from ID token: %s", err))
_ = render.Render(w, r, newErrResponse(err, clientServerErr.Code)) _ = render.Render(w, r, newErrResponse(err, http.StatusUnauthorized))
return return
} }
err = fmt.Errorf("provided parameters were invalid: %w", err) err = fmt.Errorf("provided parameters were invalid: %w", err)
@@ -259,49 +259,34 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
// Determine what error to return to the users. // Determine what error to return to the users.
if err != nil { if err != nil {
var tbErr util.ToolboxError errStr := err.Error()
var statusCode int
if errors.As(err, &tbErr) { // Upstream API auth error propagation
switch tbErr.Category() { switch {
case util.CategoryAgent: case strings.Contains(errStr, "Error 401"):
// Agent Errors -> 200 OK statusCode = http.StatusUnauthorized
s.logger.DebugContext(ctx, fmt.Sprintf("Tool invocation agent error: %v", err)) case strings.Contains(errStr, "Error 403"):
_ = render.Render(w, r, newErrResponse(err, http.StatusOK)) statusCode = http.StatusForbidden
return }
case util.CategoryServer: if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden {
// Server Errors -> Check the specific code inside if clientAuth {
var clientServerErr *util.ClientServerError // Propagate the original 401/403 error.
statusCode := http.StatusInternalServerError // Default to 500 s.logger.DebugContext(ctx, fmt.Sprintf("error invoking tool. Client credentials lack authorization to the source: %v", err))
if errors.As(err, &clientServerErr) {
if clientServerErr.Code != 0 {
statusCode = clientServerErr.Code
}
}
// Process auth error
if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden {
if clientAuth {
// Token error, pass through 401/403
s.logger.DebugContext(ctx, fmt.Sprintf("Client credentials lack authorization: %v", err))
_ = render.Render(w, r, newErrResponse(err, statusCode))
return
}
// ADC/Config error, return 500
statusCode = http.StatusInternalServerError
}
s.logger.ErrorContext(ctx, fmt.Sprintf("Tool invocation server error: %v", err))
_ = render.Render(w, r, newErrResponse(err, statusCode)) _ = render.Render(w, r, newErrResponse(err, statusCode))
return return
} }
} else { // ADC lacking permission or credentials configuration error.
// Unknown error -> 500 internalErr := fmt.Errorf("unexpected auth error occured during Tool invocation: %w", err)
s.logger.ErrorContext(ctx, fmt.Sprintf("Tool invocation unknown error: %v", err)) s.logger.ErrorContext(ctx, internalErr.Error())
_ = render.Render(w, r, newErrResponse(err, http.StatusInternalServerError)) _ = render.Render(w, r, newErrResponse(internalErr, http.StatusInternalServerError))
return return
} }
err = fmt.Errorf("error while invoking tool: %w", err)
s.logger.DebugContext(ctx, err.Error())
_ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest))
return
} }
resMarshal, err := json.Marshal(res) resMarshal, err := json.Marshal(res)

View File

@@ -23,6 +23,7 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"strings"
"sync" "sync"
"time" "time"
@@ -443,12 +444,15 @@ func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
code := rpcResponse.Error.Code code := rpcResponse.Error.Code
switch code { switch code {
case jsonrpc.INTERNAL_ERROR: case jsonrpc.INTERNAL_ERROR:
// Map Internal RPC Error (-32603) to HTTP 500
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
case jsonrpc.INVALID_REQUEST: case jsonrpc.INVALID_REQUEST:
var clientServerErr *util.ClientServerError errStr := err.Error()
if errors.As(err, &clientServerErr) { if errors.Is(err, util.ErrUnauthorized) {
w.WriteHeader(clientServerErr.Code) w.WriteHeader(http.StatusUnauthorized)
} else if strings.Contains(errStr, "Error 401") {
w.WriteHeader(http.StatusUnauthorized)
} else if strings.Contains(errStr, "Error 403") {
w.WriteHeader(http.StatusForbidden)
} }
} }
} }

View File

@@ -21,6 +21,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"strings"
"github.com/googleapis/genai-toolbox/internal/prompts" "github.com/googleapis/genai-toolbox/internal/prompts"
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc" "github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
@@ -123,12 +124,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
} }
if clientAuth { if clientAuth {
if accessToken == "" { if accessToken == "" {
errMsg := "missing access token in the 'Authorization' header" return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, errMsg, nil), util.NewClientServerError(
errMsg,
http.StatusUnauthorized,
nil,
)
} }
} }
@@ -176,11 +172,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
// Check if any of the specified auth services is verified // Check if any of the specified auth services is verified
isAuthorized := tool.Authorized(verifiedAuthServices) isAuthorized := tool.Authorized(verifiedAuthServices)
if !isAuthorized { if !isAuthorized {
err = util.NewClientServerError( err = fmt.Errorf("unauthorized Tool call: Please make sure your specify correct auth headers: %w", util.ErrUnauthorized)
"unauthorized Tool call: Please make sure you specify correct auth headers",
http.StatusUnauthorized,
nil,
)
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
} }
logger.DebugContext(ctx, "tool invocation authorized") logger.DebugContext(ctx, "tool invocation authorized")
@@ -202,44 +194,30 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
// run tool invocation and generate response. // run tool invocation and generate response.
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken) results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
if err != nil { if err != nil {
var tbErr util.ToolboxError errStr := err.Error()
// Missing authService tokens.
if errors.As(err, &tbErr) { if errors.Is(err, util.ErrUnauthorized) {
switch tbErr.Category() { return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
case util.CategoryAgent: }
// MCP - Tool execution error // Upstream auth error
// Return SUCCESS but with IsError: true if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
text := TextContent{ if clientAuth {
Type: "text", // Error with client credentials should pass down to the client
Text: err.Error(), return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}
return jsonrpc.JSONRPCResponse{
Jsonrpc: jsonrpc.JSONRPC_VERSION,
Id: id,
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
}, nil
case util.CategoryServer:
// MCP Spec - Protocol error
// Return JSON-RPC ERROR
var clientServerErr *util.ClientServerError
rpcCode := jsonrpc.INTERNAL_ERROR // Default to Internal Error (-32603)
if errors.As(err, &clientServerErr) {
if clientServerErr.Code == http.StatusUnauthorized || clientServerErr.Code == http.StatusForbidden {
if clientAuth {
rpcCode = jsonrpc.INVALID_REQUEST
} else {
rpcCode = jsonrpc.INTERNAL_ERROR
}
}
}
return jsonrpc.NewError(id, rpcCode, err.Error(), nil), err
} }
} else { // Auth error with ADC should raise internal 500 error
// Unknown error -> 500
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
} }
text := TextContent{
Type: "text",
Text: err.Error(),
}
return jsonrpc.JSONRPCResponse{
Jsonrpc: jsonrpc.JSONRPC_VERSION,
Id: id,
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
}, nil
} }
content := make([]TextContent, 0) content := make([]TextContent, 0)

View File

@@ -21,6 +21,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"strings"
"github.com/googleapis/genai-toolbox/internal/prompts" "github.com/googleapis/genai-toolbox/internal/prompts"
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc" "github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
@@ -123,12 +124,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
} }
if clientAuth { if clientAuth {
if accessToken == "" { if accessToken == "" {
errMsg := "missing access token in the 'Authorization' header" return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, errMsg, nil), util.NewClientServerError(
errMsg,
http.StatusUnauthorized,
nil,
)
} }
} }
@@ -176,11 +172,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
// Check if any of the specified auth services is verified // Check if any of the specified auth services is verified
isAuthorized := tool.Authorized(verifiedAuthServices) isAuthorized := tool.Authorized(verifiedAuthServices)
if !isAuthorized { if !isAuthorized {
err = util.NewClientServerError( err = fmt.Errorf("unauthorized Tool call: Please make sure your specify correct auth headers: %w", util.ErrUnauthorized)
"unauthorized Tool call: Please make sure you specify correct auth headers",
http.StatusUnauthorized,
nil,
)
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
} }
logger.DebugContext(ctx, "tool invocation authorized") logger.DebugContext(ctx, "tool invocation authorized")
@@ -202,45 +194,31 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
// run tool invocation and generate response. // run tool invocation and generate response.
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken) results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
if err != nil { if err != nil {
var tbErr util.ToolboxError errStr := err.Error()
// Missing authService tokens.
if errors.As(err, &tbErr) { if errors.Is(err, util.ErrUnauthorized) {
switch tbErr.Category() { return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
case util.CategoryAgent: }
// MCP - Tool execution error // Upstream auth error
// Return SUCCESS but with IsError: true if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
text := TextContent{ if clientAuth {
Type: "text", // Error with client credentials should pass down to the client
Text: err.Error(), return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}
return jsonrpc.JSONRPCResponse{
Jsonrpc: jsonrpc.JSONRPC_VERSION,
Id: id,
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
}, nil
case util.CategoryServer:
// MCP Spec - Protocol error
// Return JSON-RPC ERROR
var clientServerErr *util.ClientServerError
rpcCode := jsonrpc.INTERNAL_ERROR // Default to Internal Error (-32603)
if errors.As(err, &clientServerErr) {
if clientServerErr.Code == http.StatusUnauthorized || clientServerErr.Code == http.StatusForbidden {
if clientAuth {
rpcCode = jsonrpc.INVALID_REQUEST
} else {
rpcCode = jsonrpc.INTERNAL_ERROR
}
}
}
return jsonrpc.NewError(id, rpcCode, err.Error(), nil), err
} }
} else { // Auth error with ADC should raise internal 500 error
// Unknown error -> 500
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
} }
text := TextContent{
Type: "text",
Text: err.Error(),
}
return jsonrpc.JSONRPCResponse{
Jsonrpc: jsonrpc.JSONRPC_VERSION,
Id: id,
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
}, nil
} }
content := make([]TextContent, 0) content := make([]TextContent, 0)
sliceRes, ok := results.([]any) sliceRes, ok := results.([]any)

View File

@@ -21,6 +21,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"strings"
"github.com/googleapis/genai-toolbox/internal/prompts" "github.com/googleapis/genai-toolbox/internal/prompts"
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc" "github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
@@ -116,12 +117,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
} }
if clientAuth { if clientAuth {
if accessToken == "" { if accessToken == "" {
errMsg := "missing access token in the 'Authorization' header" return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, errMsg, nil), util.NewClientServerError(
errMsg,
http.StatusUnauthorized,
nil,
)
} }
} }
@@ -169,11 +165,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
// Check if any of the specified auth services is verified // Check if any of the specified auth services is verified
isAuthorized := tool.Authorized(verifiedAuthServices) isAuthorized := tool.Authorized(verifiedAuthServices)
if !isAuthorized { if !isAuthorized {
err = util.NewClientServerError( err = fmt.Errorf("unauthorized Tool call: Please make sure your specify correct auth headers: %w", util.ErrUnauthorized)
"unauthorized Tool call: Please make sure you specify correct auth headers",
http.StatusUnauthorized,
nil,
)
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
} }
logger.DebugContext(ctx, "tool invocation authorized") logger.DebugContext(ctx, "tool invocation authorized")
@@ -195,44 +187,29 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
// run tool invocation and generate response. // run tool invocation and generate response.
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken) results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
if err != nil { if err != nil {
var tbErr util.ToolboxError errStr := err.Error()
// Missing authService tokens.
if errors.As(err, &tbErr) { if errors.Is(err, util.ErrUnauthorized) {
switch tbErr.Category() { return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
case util.CategoryAgent: }
// MCP - Tool execution error // Upstream auth error
// Return SUCCESS but with IsError: true if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
text := TextContent{ if clientAuth {
Type: "text", // Error with client credentials should pass down to the client
Text: err.Error(), return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}
return jsonrpc.JSONRPCResponse{
Jsonrpc: jsonrpc.JSONRPC_VERSION,
Id: id,
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
}, nil
case util.CategoryServer:
// MCP Spec - Protocol error
// Return JSON-RPC ERROR
var clientServerErr *util.ClientServerError
rpcCode := jsonrpc.INTERNAL_ERROR // Default to Internal Error (-32603)
if errors.As(err, &clientServerErr) {
if clientServerErr.Code == http.StatusUnauthorized || clientServerErr.Code == http.StatusForbidden {
if clientAuth {
rpcCode = jsonrpc.INVALID_REQUEST
} else {
rpcCode = jsonrpc.INTERNAL_ERROR
}
}
}
return jsonrpc.NewError(id, rpcCode, err.Error(), nil), err
} }
} else { // Auth error with ADC should raise internal 500 error
// Unknown error -> 500
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
} }
text := TextContent{
Type: "text",
Text: err.Error(),
}
return jsonrpc.JSONRPCResponse{
Jsonrpc: jsonrpc.JSONRPC_VERSION,
Id: id,
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
}, nil
} }
content := make([]TextContent, 0) content := make([]TextContent, 0)

View File

@@ -21,6 +21,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"strings"
"github.com/googleapis/genai-toolbox/internal/prompts" "github.com/googleapis/genai-toolbox/internal/prompts"
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc" "github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
@@ -116,12 +117,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
} }
if clientAuth { if clientAuth {
if accessToken == "" { if accessToken == "" {
errMsg := "missing access token in the 'Authorization' header" return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, errMsg, nil), util.NewClientServerError(
errMsg,
http.StatusUnauthorized,
nil,
)
} }
} }
@@ -169,11 +165,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
// Check if any of the specified auth services is verified // Check if any of the specified auth services is verified
isAuthorized := tool.Authorized(verifiedAuthServices) isAuthorized := tool.Authorized(verifiedAuthServices)
if !isAuthorized { if !isAuthorized {
err = util.NewClientServerError( err = fmt.Errorf("unauthorized Tool call: Please make sure your specify correct auth headers: %w", util.ErrUnauthorized)
"unauthorized Tool call: Please make sure you specify correct auth headers",
http.StatusUnauthorized,
nil,
)
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
} }
logger.DebugContext(ctx, "tool invocation authorized") logger.DebugContext(ctx, "tool invocation authorized")
@@ -195,44 +187,29 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
// run tool invocation and generate response. // run tool invocation and generate response.
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken) results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
if err != nil { if err != nil {
var tbErr util.ToolboxError errStr := err.Error()
// Missing authService tokens.
if errors.As(err, &tbErr) { if errors.Is(err, util.ErrUnauthorized) {
switch tbErr.Category() { return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
case util.CategoryAgent: }
// MCP - Tool execution error // Upstream auth error
// Return SUCCESS but with IsError: true if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
text := TextContent{ if clientAuth {
Type: "text", // Error with client credentials should pass down to the client
Text: err.Error(), return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}
return jsonrpc.JSONRPCResponse{
Jsonrpc: jsonrpc.JSONRPC_VERSION,
Id: id,
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
}, nil
case util.CategoryServer:
// MCP Spec - Protocol error
// Return JSON-RPC ERROR
var clientServerErr *util.ClientServerError
rpcCode := jsonrpc.INTERNAL_ERROR // Default to Internal Error (-32603)
if errors.As(err, &clientServerErr) {
if clientServerErr.Code == http.StatusUnauthorized || clientServerErr.Code == http.StatusForbidden {
if clientAuth {
rpcCode = jsonrpc.INVALID_REQUEST
} else {
rpcCode = jsonrpc.INTERNAL_ERROR
}
}
}
return jsonrpc.NewError(id, rpcCode, err.Error(), nil), err
} }
} else { // Auth error with ADC should raise internal 500 error
// Unknown error -> 500
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
} }
text := TextContent{
Type: "text",
Text: err.Error(),
}
return jsonrpc.JSONRPCResponse{
Jsonrpc: jsonrpc.JSONRPC_VERSION,
Id: id,
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
}, nil
} }
content := make([]TextContent, 0) content := make([]TextContent, 0)

View File

@@ -231,7 +231,7 @@ func TestMcpEndpointWithoutInitialized(t *testing.T) {
"id": "tools-call-tool4", "id": "tools-call-tool4",
"error": map[string]any{ "error": map[string]any{
"code": -32600.0, "code": -32600.0,
"message": "unauthorized Tool call: Please make sure you specify correct auth headers: <nil>", "message": "unauthorized Tool call: Please make sure your specify correct auth headers: unauthorized",
}, },
}, },
}, },
@@ -834,7 +834,7 @@ func TestMcpEndpoint(t *testing.T) {
"id": "tools-call-tool4", "id": "tools-call-tool4",
"error": map[string]any{ "error": map[string]any{
"code": -32600.0, "code": -32600.0,
"message": "unauthorized Tool call: Please make sure you specify correct auth headers: <nil>", "message": "unauthorized Tool call: Please make sure your specify correct auth headers: unauthorized",
}, },
}, },
}, },

View File

@@ -184,7 +184,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if source.UseClientAuthorization() { if source.UseClientAuthorization() {
// Use client-side access token // Use client-side access token
if accessToken == "" { if accessToken == "" {
return nil, util.NewClientServerError("tool is configured for client OAuth but no token was provided in the request header", http.StatusUnauthorized, nil) return nil, fmt.Errorf("tool is configured for client OAuth but no token was provided in the request header: %w", util.ErrUnauthorized)
} }
tokenStr, err = accessToken.ParseBearerToken() tokenStr, err = accessToken.ParseBearerToken()
if err != nil { if err != nil {

View File

@@ -17,7 +17,6 @@ package tools
import ( import (
"context" "context"
"fmt" "fmt"
"net/http"
"slices" "slices"
"strings" "strings"
@@ -81,7 +80,7 @@ type AccessToken string
func (token AccessToken) ParseBearerToken() (string, error) { func (token AccessToken) ParseBearerToken() (string, error) {
headerParts := strings.Split(string(token), " ") headerParts := strings.Split(string(token), " ")
if len(headerParts) != 2 || strings.ToLower(headerParts[0]) != "bearer" { if len(headerParts) != 2 || strings.ToLower(headerParts[0]) != "bearer" {
return "", util.NewClientServerError("authorization header must be in the format 'Bearer <token>'", http.StatusUnauthorized, nil) return "", fmt.Errorf("authorization header must be in the format 'Bearer <token>': %w", util.ErrUnauthorized)
} }
return headerParts[1], nil return headerParts[1], nil
} }

View File

@@ -1,61 +0,0 @@
// Copyright 2026 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package util
import "fmt"
type ErrorCategory string
const (
CategoryAgent ErrorCategory = "AGENT_ERROR"
CategoryServer ErrorCategory = "SERVER_ERROR"
)
// ToolboxError is the interface all custom errors must satisfy
type ToolboxError interface {
error
Category() ErrorCategory
}
// Agent Errors return 200 to the sender
type AgentError struct {
Msg string
Cause error
}
func (e *AgentError) Error() string { return e.Msg }
func (e *AgentError) Category() ErrorCategory { return CategoryAgent }
func (e *AgentError) Unwrap() error { return e.Cause }
func NewAgentError(msg string, cause error) *AgentError {
return &AgentError{Msg: msg, Cause: cause}
}
// ClientServerError returns 4XX/5XX error code
type ClientServerError struct {
Msg string
Code int
Cause error
}
func (e *ClientServerError) Error() string { return fmt.Sprintf("%s: %v", e.Msg, e.Cause) }
func (e *ClientServerError) Category() ErrorCategory { return CategoryServer }
func (e *ClientServerError) Unwrap() error { return e.Cause }
func NewClientServerError(msg string, code int, cause error) *ClientServerError {
return &ClientServerError{Msg: msg, Code: code, Cause: cause}
}

View File

@@ -19,7 +19,6 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http"
"reflect" "reflect"
"regexp" "regexp"
"slices" "slices"
@@ -119,7 +118,7 @@ func parseFromAuthService(paramAuthServices []ParamAuthService, claimsMap map[st
} }
return v, nil return v, nil
} }
return nil, util.NewClientServerError("missing or invalid authentication header", http.StatusUnauthorized, nil) return nil, fmt.Errorf("missing or invalid authentication header: %w", util.ErrUnauthorized)
} }
// CheckParamRequired checks if a parameter is required based on the required and default field. // CheckParamRequired checks if a parameter is required based on the required and default field.

View File

@@ -17,6 +17,7 @@ import (
"bytes" "bytes"
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@@ -187,3 +188,5 @@ func InstrumentationFromContext(ctx context.Context) (*telemetry.Instrumentation
} }
return nil, fmt.Errorf("unable to retrieve instrumentation") return nil, fmt.Errorf("unable to retrieve instrumentation")
} }
var ErrUnauthorized = errors.New("unauthorized")

View File

@@ -139,13 +139,24 @@ func TestAlloyDBPgToolEndpoints(t *testing.T) {
// set up data for param tool // set up data for param tool
createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := tests.GetPostgresSQLParamToolInfo(tableNameParam) createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := tests.GetPostgresSQLParamToolInfo(tableNameParam)
teardownTable1 := tests.SetupPostgresSQLTable(t, ctx, pool, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams) teardownTable1, err := tests.SetupPostgresSQLTable(t, ctx, pool, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams)
defer teardownTable1(t) if teardownTable1 != nil {
defer teardownTable1(t)
}
if err != nil {
t.Fatalf("Setup failed: %v", err)
}
// set up data for auth tool // set up data for auth tool
createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := tests.GetPostgresSQLAuthToolInfo(tableNameAuth) createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := tests.GetPostgresSQLAuthToolInfo(tableNameAuth)
teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
defer teardownTable2(t) teardownTable2, err := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
if teardownTable2 != nil {
defer teardownTable2(t)
}
if err != nil {
t.Fatalf("Setup failed: %v", err)
}
// Set up table for semanti search // Set up table for semanti search
vectorTableName, tearDownVectorTable := tests.SetupPostgresVectorTable(t, ctx, pool) vectorTableName, tearDownVectorTable := tests.SetupPostgresVectorTable(t, ctx, pool)

View File

@@ -84,7 +84,6 @@ func TestBigQueryToolEndpoints(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("unable to create Cloud SQL connection pool: %s", err) t.Fatalf("unable to create Cloud SQL connection pool: %s", err)
} }
// create table name with UUID // create table name with UUID
datasetName := fmt.Sprintf("temp_toolbox_test_%s", strings.ReplaceAll(uuid.New().String(), "-", "")) datasetName := fmt.Sprintf("temp_toolbox_test_%s", strings.ReplaceAll(uuid.New().String(), "-", ""))
tableName := fmt.Sprintf("param_table_%s", strings.ReplaceAll(uuid.New().String(), "-", "")) tableName := fmt.Sprintf("param_table_%s", strings.ReplaceAll(uuid.New().String(), "-", ""))
@@ -122,27 +121,42 @@ func TestBigQueryToolEndpoints(t *testing.T) {
// set up data for param tool // set up data for param tool
createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := getBigQueryParamToolInfo(tableNameParam) createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := getBigQueryParamToolInfo(tableNameParam)
teardownTable1 := setupBigQueryTable(t, ctx, client, createParamTableStmt, insertParamTableStmt, datasetName, tableNameParam, paramTestParams) teardownTable1, err := setupBigQueryTable(t, ctx, client, createParamTableStmt, insertParamTableStmt, datasetName, tableNameParam, paramTestParams)
if err != nil {
t.Fatalf("failed to setup param table: %s", err)
}
defer teardownTable1(t) defer teardownTable1(t)
// set up data for auth tool // set up data for auth tool
createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := getBigQueryAuthToolInfo(tableNameAuth) createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := getBigQueryAuthToolInfo(tableNameAuth)
teardownTable2 := setupBigQueryTable(t, ctx, client, createAuthTableStmt, insertAuthTableStmt, datasetName, tableNameAuth, authTestParams) teardownTable2, err := setupBigQueryTable(t, ctx, client, createAuthTableStmt, insertAuthTableStmt, datasetName, tableNameAuth, authTestParams)
if err != nil {
t.Fatalf("failed to setup auth table: %s", err)
}
defer teardownTable2(t) defer teardownTable2(t)
// set up data for data type test tool // set up data for data type test tool
createDataTypeTableStmt, insertDataTypeTableStmt, dataTypeToolStmt, arrayDataTypeToolStmt, dataTypeTestParams := getBigQueryDataTypeTestInfo(tableNameDataType) createDataTypeTableStmt, insertDataTypeTableStmt, dataTypeToolStmt, arrayDataTypeToolStmt, dataTypeTestParams := getBigQueryDataTypeTestInfo(tableNameDataType)
teardownTable3 := setupBigQueryTable(t, ctx, client, createDataTypeTableStmt, insertDataTypeTableStmt, datasetName, tableNameDataType, dataTypeTestParams) teardownTable3, err := setupBigQueryTable(t, ctx, client, createDataTypeTableStmt, insertDataTypeTableStmt, datasetName, tableNameDataType, dataTypeTestParams)
if err != nil {
t.Fatalf("failed to setup data type table: %s", err)
}
defer teardownTable3(t) defer teardownTable3(t)
// set up data for forecast tool // set up data for forecast tool
createForecastTableStmt, insertForecastTableStmt, forecastTestParams := getBigQueryForecastToolInfo(tableNameForecast) createForecastTableStmt, insertForecastTableStmt, forecastTestParams := getBigQueryForecastToolInfo(tableNameForecast)
teardownTable4 := setupBigQueryTable(t, ctx, client, createForecastTableStmt, insertForecastTableStmt, datasetName, tableNameForecast, forecastTestParams) teardownTable4, err := setupBigQueryTable(t, ctx, client, createForecastTableStmt, insertForecastTableStmt, datasetName, tableNameForecast, forecastTestParams)
if err != nil {
t.Fatalf("failed to setup forecast table: %s", err)
}
defer teardownTable4(t) defer teardownTable4(t)
// set up data for analyze contribution tool // set up data for analyze contribution tool
createAnalyzeContributionTableStmt, insertAnalyzeContributionTableStmt, analyzeContributionTestParams := getBigQueryAnalyzeContributionToolInfo(tableNameAnalyzeContribution) createAnalyzeContributionTableStmt, insertAnalyzeContributionTableStmt, analyzeContributionTestParams := getBigQueryAnalyzeContributionToolInfo(tableNameAnalyzeContribution)
teardownTable5 := setupBigQueryTable(t, ctx, client, createAnalyzeContributionTableStmt, insertAnalyzeContributionTableStmt, datasetName, tableNameAnalyzeContribution, analyzeContributionTestParams) teardownTable5, err := setupBigQueryTable(t, ctx, client, createAnalyzeContributionTableStmt, insertAnalyzeContributionTableStmt, datasetName, tableNameAnalyzeContribution, analyzeContributionTestParams)
if err != nil {
t.Fatalf("failed to setup analyze contribution table: %s", err)
}
defer teardownTable5(t) defer teardownTable5(t)
// Write config into a file and pass it to command // Write config into a file and pass it to command
@@ -231,52 +245,79 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
// Setup allowed table // Setup allowed table
allowedTableNameParam1 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName1, allowedTableName1) allowedTableNameParam1 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName1, allowedTableName1)
createAllowedTableStmt1 := fmt.Sprintf("CREATE TABLE %s (id INT64)", allowedTableNameParam1) createAllowedTableStmt1 := fmt.Sprintf("CREATE TABLE %s (id INT64)", allowedTableNameParam1)
teardownAllowed1 := setupBigQueryTable(t, ctx, client, createAllowedTableStmt1, "", allowedDatasetName1, allowedTableNameParam1, nil) teardownAllowed1, err:= setupBigQueryTable(t, ctx, client, createAllowedTableStmt1, "", allowedDatasetName1, allowedTableNameParam1, nil)
if err != nil {
t.Fatalf("failed to setup allowed table 1: %s", err)
}
defer teardownAllowed1(t) defer teardownAllowed1(t)
allowedTableNameParam2 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName2, allowedTableName2) allowedTableNameParam2 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName2, allowedTableName2)
createAllowedTableStmt2 := fmt.Sprintf("CREATE TABLE %s (id INT64)", allowedTableNameParam2) createAllowedTableStmt2 := fmt.Sprintf("CREATE TABLE %s (id INT64)", allowedTableNameParam2)
teardownAllowed2 := setupBigQueryTable(t, ctx, client, createAllowedTableStmt2, "", allowedDatasetName2, allowedTableNameParam2, nil) teardownAllowed2, err:= setupBigQueryTable(t, ctx, client, createAllowedTableStmt2, "", allowedDatasetName2, allowedTableNameParam2, nil)
if err != nil {
t.Fatalf("failed to setup allowed table 2: %s", err)
}
defer teardownAllowed2(t) defer teardownAllowed2(t)
// Setup allowed forecast table // Setup allowed forecast table
allowedForecastTableFullName1 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName1, allowedForecastTableName1) allowedForecastTableFullName1 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName1, allowedForecastTableName1)
createForecastStmt1, insertForecastStmt1, forecastParams1 := getBigQueryForecastToolInfo(allowedForecastTableFullName1) createForecastStmt1, insertForecastStmt1, forecastParams1 := getBigQueryForecastToolInfo(allowedForecastTableFullName1)
teardownAllowedForecast1 := setupBigQueryTable(t, ctx, client, createForecastStmt1, insertForecastStmt1, allowedDatasetName1, allowedForecastTableFullName1, forecastParams1) teardownAllowedForecast1, err:= setupBigQueryTable(t, ctx, client, createForecastStmt1, insertForecastStmt1, allowedDatasetName1, allowedForecastTableFullName1, forecastParams1)
if err != nil {
t.Fatalf("failed to setup allowed forecast table 1: %s", err)
}
defer teardownAllowedForecast1(t) defer teardownAllowedForecast1(t)
allowedForecastTableFullName2 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName2, allowedForecastTableName2) allowedForecastTableFullName2 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName2, allowedForecastTableName2)
createForecastStmt2, insertForecastStmt2, forecastParams2 := getBigQueryForecastToolInfo(allowedForecastTableFullName2) createForecastStmt2, insertForecastStmt2, forecastParams2 := getBigQueryForecastToolInfo(allowedForecastTableFullName2)
teardownAllowedForecast2 := setupBigQueryTable(t, ctx, client, createForecastStmt2, insertForecastStmt2, allowedDatasetName2, allowedForecastTableFullName2, forecastParams2) teardownAllowedForecast2, err:= setupBigQueryTable(t, ctx, client, createForecastStmt2, insertForecastStmt2, allowedDatasetName2, allowedForecastTableFullName2, forecastParams2)
if err != nil {
t.Fatalf("failed to setup allowed forecast table 2: %s", err)
}
defer teardownAllowedForecast2(t) defer teardownAllowedForecast2(t)
// Setup disallowed table // Setup disallowed table
disallowedTableNameParam := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, disallowedDatasetName, disallowedTableName) disallowedTableNameParam := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, disallowedDatasetName, disallowedTableName)
createDisallowedTableStmt := fmt.Sprintf("CREATE TABLE %s (id INT64)", disallowedTableNameParam) createDisallowedTableStmt := fmt.Sprintf("CREATE TABLE %s (id INT64)", disallowedTableNameParam)
teardownDisallowed := setupBigQueryTable(t, ctx, client, createDisallowedTableStmt, "", disallowedDatasetName, disallowedTableNameParam, nil) teardownDisallowed, err:= setupBigQueryTable(t, ctx, client, createDisallowedTableStmt, "", disallowedDatasetName, disallowedTableNameParam, nil)
if err != nil {
t.Fatalf("failed to setup disallowed table: %s", err)
}
defer teardownDisallowed(t) defer teardownDisallowed(t)
// Setup disallowed forecast table // Setup disallowed forecast table
disallowedForecastTableFullName := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, disallowedDatasetName, disallowedForecastTableName) disallowedForecastTableFullName := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, disallowedDatasetName, disallowedForecastTableName)
createDisallowedForecastStmt, insertDisallowedForecastStmt, disallowedForecastParams := getBigQueryForecastToolInfo(disallowedForecastTableFullName) createDisallowedForecastStmt, insertDisallowedForecastStmt, disallowedForecastParams := getBigQueryForecastToolInfo(disallowedForecastTableFullName)
teardownDisallowedForecast := setupBigQueryTable(t, ctx, client, createDisallowedForecastStmt, insertDisallowedForecastStmt, disallowedDatasetName, disallowedForecastTableFullName, disallowedForecastParams) teardownDisallowedForecast, err:= setupBigQueryTable(t, ctx, client, createDisallowedForecastStmt, insertDisallowedForecastStmt, disallowedDatasetName, disallowedForecastTableFullName, disallowedForecastParams)
if err != nil {
t.Fatalf("failed to setup disallowed forecast table: %s", err)
}
defer teardownDisallowedForecast(t) defer teardownDisallowedForecast(t)
// Setup allowed analyze contribution table // Setup allowed analyze contribution table
allowedAnalyzeContributionTableFullName1 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName1, allowedAnalyzeContributionTableName1) allowedAnalyzeContributionTableFullName1 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName1, allowedAnalyzeContributionTableName1)
createAnalyzeContributionStmt1, insertAnalyzeContributionStmt1, analyzeContributionParams1 := getBigQueryAnalyzeContributionToolInfo(allowedAnalyzeContributionTableFullName1) createAnalyzeContributionStmt1, insertAnalyzeContributionStmt1, analyzeContributionParams1 := getBigQueryAnalyzeContributionToolInfo(allowedAnalyzeContributionTableFullName1)
teardownAllowedAnalyzeContribution1 := setupBigQueryTable(t, ctx, client, createAnalyzeContributionStmt1, insertAnalyzeContributionStmt1, allowedDatasetName1, allowedAnalyzeContributionTableFullName1, analyzeContributionParams1) teardownAllowedAnalyzeContribution1, err:= setupBigQueryTable(t, ctx, client, createAnalyzeContributionStmt1, insertAnalyzeContributionStmt1, allowedDatasetName1, allowedAnalyzeContributionTableFullName1, analyzeContributionParams1)
if err != nil {
t.Fatalf("failed to setup allowed analyze contribution table 1: %s", err)
}
defer teardownAllowedAnalyzeContribution1(t) defer teardownAllowedAnalyzeContribution1(t)
allowedAnalyzeContributionTableFullName2 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName2, allowedAnalyzeContributionTableName2) allowedAnalyzeContributionTableFullName2 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName2, allowedAnalyzeContributionTableName2)
createAnalyzeContributionStmt2, insertAnalyzeContributionStmt2, analyzeContributionParams2 := getBigQueryAnalyzeContributionToolInfo(allowedAnalyzeContributionTableFullName2) createAnalyzeContributionStmt2, insertAnalyzeContributionStmt2, analyzeContributionParams2 := getBigQueryAnalyzeContributionToolInfo(allowedAnalyzeContributionTableFullName2)
teardownAllowedAnalyzeContribution2 := setupBigQueryTable(t, ctx, client, createAnalyzeContributionStmt2, insertAnalyzeContributionStmt2, allowedDatasetName2, allowedAnalyzeContributionTableFullName2, analyzeContributionParams2) teardownAllowedAnalyzeContribution2, err:= setupBigQueryTable(t, ctx, client, createAnalyzeContributionStmt2, insertAnalyzeContributionStmt2, allowedDatasetName2, allowedAnalyzeContributionTableFullName2, analyzeContributionParams2)
if err != nil {
t.Fatalf("failed to setup allowed analyze contribution table 2: %s", err)
}
defer teardownAllowedAnalyzeContribution2(t) defer teardownAllowedAnalyzeContribution2(t)
// Setup disallowed analyze contribution table // Setup disallowed analyze contribution table
disallowedAnalyzeContributionTableFullName := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, disallowedDatasetName, disallowedAnalyzeContributionTableName) disallowedAnalyzeContributionTableFullName := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, disallowedDatasetName, disallowedAnalyzeContributionTableName)
createDisallowedAnalyzeContributionStmt, insertDisallowedAnalyzeContributionStmt, disallowedAnalyzeContributionParams := getBigQueryAnalyzeContributionToolInfo(disallowedAnalyzeContributionTableFullName) createDisallowedAnalyzeContributionStmt, insertDisallowedAnalyzeContributionStmt, disallowedAnalyzeContributionParams := getBigQueryAnalyzeContributionToolInfo(disallowedAnalyzeContributionTableFullName)
teardownDisallowedAnalyzeContribution := setupBigQueryTable(t, ctx, client, createDisallowedAnalyzeContributionStmt, insertDisallowedAnalyzeContributionStmt, disallowedDatasetName, disallowedAnalyzeContributionTableFullName, disallowedAnalyzeContributionParams) teardownDisallowedAnalyzeContribution, err:= setupBigQueryTable(t, ctx, client, createDisallowedAnalyzeContributionStmt, insertDisallowedAnalyzeContributionStmt, disallowedDatasetName, disallowedAnalyzeContributionTableFullName, disallowedAnalyzeContributionParams)
if err != nil {
t.Fatalf("failed to setup disallowed analyze contribution table: %s", err)
}
defer teardownDisallowedAnalyzeContribution(t) defer teardownDisallowedAnalyzeContribution(t)
// Configure source with dataset restriction. // Configure source with dataset restriction.
@@ -438,7 +479,10 @@ func TestBigQueryWriteModeBlocked(t *testing.T) {
t.Fatalf("unable to create BigQuery connection: %s", err) t.Fatalf("unable to create BigQuery connection: %s", err)
} }
createParamTableStmt, insertParamTableStmt, _, _, _, _, paramTestParams := getBigQueryParamToolInfo(tableNameParam) createParamTableStmt, insertParamTableStmt, _, _, _, _, paramTestParams := getBigQueryParamToolInfo(tableNameParam)
teardownTable := setupBigQueryTable(t, ctx, client, createParamTableStmt, insertParamTableStmt, datasetName, tableNameParam, paramTestParams) teardownTable ,err:= setupBigQueryTable(t, ctx, client, createParamTableStmt, insertParamTableStmt, datasetName, tableNameParam, paramTestParams)
if err != nil {
t.Fatalf("failed to setup BigQuery table: %s", err)
}
defer teardownTable(t) defer teardownTable(t)
toolsFile := map[string]any{ toolsFile := map[string]any{
@@ -623,7 +667,7 @@ func getBigQueryTmplToolStatement() (string, string) {
return tmplSelectCombined, tmplSelectFilterCombined return tmplSelectCombined, tmplSelectFilterCombined
} }
func setupBigQueryTable(t *testing.T, ctx context.Context, client *bigqueryapi.Client, createStatement, insertStatement, datasetName string, tableName string, params []bigqueryapi.QueryParameter) func(*testing.T) { func setupBigQueryTable(t *testing.T, ctx context.Context, client *bigqueryapi.Client, createStatement, insertStatement, datasetName string, tableName string, params []bigqueryapi.QueryParameter) (func(*testing.T), error) {
// Create dataset // Create dataset
dataset := client.Dataset(datasetName) dataset := client.Dataset(datasetName)
_, err := dataset.Metadata(ctx) _, err := dataset.Metadata(ctx)
@@ -699,7 +743,7 @@ func setupBigQueryTable(t *testing.T, ctx context.Context, client *bigqueryapi.C
} else if err != nil { } else if err != nil {
t.Errorf("Failed to list tables in dataset %s to check emptiness: %v.", datasetName, err) t.Errorf("Failed to list tables in dataset %s to check emptiness: %v.", datasetName, err)
} }
} }, nil
} }
func addBigQueryPrebuiltToolsConfig(t *testing.T, config map[string]any) map[string]any { func addBigQueryPrebuiltToolsConfig(t *testing.T, config map[string]any) map[string]any {

View File

@@ -124,13 +124,23 @@ func TestCloudSQLPgSimpleToolEndpoints(t *testing.T) {
// set up data for param tool // set up data for param tool
createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := tests.GetPostgresSQLParamToolInfo(tableNameParam) createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := tests.GetPostgresSQLParamToolInfo(tableNameParam)
teardownTable1 := tests.SetupPostgresSQLTable(t, ctx, pool, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams) teardownTable1, err := tests.SetupPostgresSQLTable(t, ctx, pool, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams)
defer teardownTable1(t) if teardownTable1 != nil {
defer teardownTable1(t)
}
if err != nil {
t.Fatalf("Setup failed: %v", err)
}
// set up data for auth tool // set up data for auth tool
createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := tests.GetPostgresSQLAuthToolInfo(tableNameAuth) createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := tests.GetPostgresSQLAuthToolInfo(tableNameAuth)
teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams) teardownTable2, err := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
defer teardownTable2(t) if teardownTable2 != nil {
defer teardownTable2(t)
}
if err != nil {
t.Fatalf("Setup failed: %v", err)
}
// Set up table for semantic search // Set up table for semantic search
vectorTableName, tearDownVectorTable := tests.SetupPostgresVectorTable(t, ctx, pool) vectorTableName, tearDownVectorTable := tests.SetupPostgresVectorTable(t, ctx, pool)

View File

@@ -613,31 +613,36 @@ func GetMySQLWants() (string, string, string, string) {
// SetupPostgresSQLTable creates and inserts data into a table of tool // SetupPostgresSQLTable creates and inserts data into a table of tool
// compatible with postgres-sql tool // compatible with postgres-sql tool
func SetupPostgresSQLTable(t *testing.T, ctx context.Context, pool *pgxpool.Pool, createStatement, insertStatement, tableName string, params []any) func(*testing.T) { func SetupPostgresSQLTable(t *testing.T, ctx context.Context, pool *pgxpool.Pool, createStatement, insertStatement, tableName string, params []any) (func(*testing.T), error) {
err := pool.Ping(ctx) err := pool.Ping(ctx)
if err != nil { if err != nil {
t.Fatalf("unable to connect to test database: %s", err) // Return nil for the function and the error itself
return nil, fmt.Errorf("unable to connect to test database: %w", err)
} }
// Create table // Create table
_, err = pool.Query(ctx, createStatement) _, err = pool.Exec(ctx, createStatement)
if err != nil { if err != nil {
t.Fatalf("unable to create test table %s: %s", tableName, err) return nil, fmt.Errorf("unable to create test table %s: %w", tableName, err)
} }
// Insert test data // Insert test data
_, err = pool.Query(ctx, insertStatement, params...) _, err = pool.Exec(ctx, insertStatement, params...)
if err != nil { if err != nil {
t.Fatalf("unable to insert test data: %s", err) // partially cleanup if insert fails
teardown := func(t *testing.T) {
_, _ = pool.Exec(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s;", tableName))
}
return teardown, fmt.Errorf("unable to insert test data: %w", err)
} }
// Return the cleanup function and nil for error
return func(t *testing.T) { return func(t *testing.T) {
// tear down test _, err = pool.Exec(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s;", tableName))
_, err = pool.Exec(ctx, fmt.Sprintf("DROP TABLE %s;", tableName))
if err != nil { if err != nil {
t.Errorf("Teardown failed: %s", err) t.Errorf("Teardown failed: %s", err)
} }
} }, nil
} }
// SetupMsSQLTable creates and inserts data into a table of tool // SetupMsSQLTable creates and inserts data into a table of tool

View File

@@ -89,12 +89,18 @@ func TestOracleSimpleToolEndpoints(t *testing.T) {
// set up data for param tool // set up data for param tool
createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := getOracleParamToolInfo(tableNameParam) createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := getOracleParamToolInfo(tableNameParam)
teardownTable1 := setupOracleTable(t, ctx, db, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams) teardownTable1, err := setupOracleTable(t, ctx, db, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams)
if err != nil {
t.Fatalf("failed to setup Oracle table %s: %v", tableNameParam, err)
}
defer teardownTable1(t) defer teardownTable1(t)
// set up data for auth tool // set up data for auth tool
createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := getOracleAuthToolInfo(tableNameAuth) createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := getOracleAuthToolInfo(tableNameAuth)
teardownTable2 := setupOracleTable(t, ctx, db, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams) teardownTable2, err := setupOracleTable(t, ctx, db, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
if err != nil {
t.Fatalf("failed to setup Oracle table %s: %v", tableNameAuth, err)
}
defer teardownTable2(t) defer teardownTable2(t)
// Write config into a file and pass it to command // Write config into a file and pass it to command
@@ -135,31 +141,31 @@ func TestOracleSimpleToolEndpoints(t *testing.T) {
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam) tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam)
} }
func setupOracleTable(t *testing.T, ctx context.Context, pool *sql.DB, createStatement, insertStatement, tableName string, params []any) func(*testing.T) { func setupOracleTable(t *testing.T, ctx context.Context, pool *sql.DB, createStatement, insertStatement, tableName string, params []any) (func(*testing.T), error) {
err := pool.PingContext(ctx) err := pool.PingContext(ctx)
if err != nil { if err != nil {
t.Fatalf("unable to connect to test database: %s", err) return nil, fmt.Errorf("unable to connect to test database: %w", err)
} }
// Create table // Create table
_, err = pool.QueryContext(ctx, createStatement) _, err = pool.QueryContext(ctx, createStatement)
if err != nil { if err != nil {
t.Fatalf("unable to create test table %s: %s", tableName, err) return nil, fmt.Errorf("unable to create test table %s: %w", tableName, err)
} }
// Insert test data // Insert test data
_, err = pool.QueryContext(ctx, insertStatement, params...) _, err = pool.QueryContext(ctx, insertStatement, params...)
if err != nil { if err != nil {
t.Fatalf("unable to insert test data: %s", err) return nil, fmt.Errorf("unable to insert test data: %w", err)
} }
return func(t *testing.T) { return func(t *testing.T) {
// tear down test // tear down test
_, err = pool.ExecContext(ctx, fmt.Sprintf("DROP TABLE %s", tableName)) _, err = pool.ExecContext(ctx, fmt.Sprintf("DROP TABLE %s CASCADE CONSTRAINTS", tableName))
if err != nil { if err != nil {
t.Errorf("Teardown failed: %s", err) t.Errorf("Teardown failed: %s", err)
} }
} }, nil
} }
func getOracleParamToolInfo(tableName string) (string, string, string, string, string, string, []any) { func getOracleParamToolInfo(tableName string) (string, string, string, string, string, string, []any) {

View File

@@ -103,13 +103,24 @@ func TestPostgres(t *testing.T) {
// set up data for param tool // set up data for param tool
createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := tests.GetPostgresSQLParamToolInfo(tableNameParam) createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := tests.GetPostgresSQLParamToolInfo(tableNameParam)
teardownTable1 := tests.SetupPostgresSQLTable(t, ctx, pool, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams) // teardownTable1 := tests.SetupPostgresSQLTable(t, ctx, pool, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams)
defer teardownTable1(t) teardownTable1, err := tests.SetupPostgresSQLTable(t, ctx, pool, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams)
if teardownTable1 != nil {
defer teardownTable1(t)
}
if err != nil {
t.Fatalf("Setup failed: %v", err)
}
// set up data for auth tool // set up data for auth tool
createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := tests.GetPostgresSQLAuthToolInfo(tableNameAuth) createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := tests.GetPostgresSQLAuthToolInfo(tableNameAuth)
teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams) teardownTable2, err := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
defer teardownTable2(t) if teardownTable2 != nil {
defer teardownTable2(t)
}
if err != nil {
t.Fatalf("Setup failed: %v", err)
}
// Set up table for semantic search // Set up table for semantic search
vectorTableName, tearDownVectorTable := tests.SetupPostgresVectorTable(t, ctx, pool) vectorTableName, tearDownVectorTable := tests.SetupPostgresVectorTable(t, ctx, pool)

View File

@@ -115,23 +115,35 @@ func TestSpannerToolEndpoints(t *testing.T) {
SpannerInstance, SpannerInstance,
SpannerDatabase, SpannerDatabase,
) )
teardownTable1 := setupSpannerTable(t, ctx, adminClient, dataClient, createParamTableStmt, insertParamTableStmt, tableNameParam, dbString, paramTestParams) teardownTable1, err := setupSpannerTable(t, ctx, adminClient, dataClient, createParamTableStmt, insertParamTableStmt, tableNameParam, dbString, paramTestParams)
if err != nil {
t.Fatalf("failed to setup Spanner table %s: %v", tableNameParam, err)
}
defer teardownTable1(t) defer teardownTable1(t)
// set up data for auth tool // set up data for auth tool
createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := getSpannerAuthToolInfo(tableNameAuth) createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := getSpannerAuthToolInfo(tableNameAuth)
teardownTable2 := setupSpannerTable(t, ctx, adminClient, dataClient, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, dbString, authTestParams) teardownTable2, err := setupSpannerTable(t, ctx, adminClient, dataClient, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, dbString, authTestParams)
if err != nil {
t.Fatalf("failed to setup Spanner table %s: %v", tableNameAuth, err)
}
defer teardownTable2(t) defer teardownTable2(t)
// set up data for template param tool // set up data for template param tool
createStatementTmpl := fmt.Sprintf("CREATE TABLE %s (id INT64, name STRING(MAX), age INT64) PRIMARY KEY (id)", tableNameTemplateParam) createStatementTmpl := fmt.Sprintf("CREATE TABLE %s (id INT64, name STRING(MAX), age INT64) PRIMARY KEY (id)", tableNameTemplateParam)
teardownTableTmpl := setupSpannerTable(t, ctx, adminClient, dataClient, createStatementTmpl, "", tableNameTemplateParam, dbString, nil) teardownTableTmpl, err := setupSpannerTable(t, ctx, adminClient, dataClient, createStatementTmpl, "", tableNameTemplateParam, dbString, nil)
if err != nil {
t.Fatalf("failed to setup Spanner table %s: %v", tableNameTemplateParam, err)
}
defer teardownTableTmpl(t) defer teardownTableTmpl(t)
// set up for graph tool // set up for graph tool
nodeTableName := "node_table_" + strings.ReplaceAll(uuid.New().String(), "-", "") nodeTableName := "node_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
createNodeStatementTmpl := fmt.Sprintf("CREATE TABLE %s (id INT64 NOT NULL) PRIMARY KEY (id)", nodeTableName) createNodeStatementTmpl := fmt.Sprintf("CREATE TABLE %s (id INT64 NOT NULL) PRIMARY KEY (id)", nodeTableName)
teardownNodeTableTmpl := setupSpannerTable(t, ctx, adminClient, dataClient, createNodeStatementTmpl, "", nodeTableName, dbString, nil) teardownNodeTableTmpl, err := setupSpannerTable(t, ctx, adminClient, dataClient, createNodeStatementTmpl, "", nodeTableName, dbString, nil)
if err != nil {
t.Fatalf("failed to setup Spanner table %s: %v", nodeTableName, err)
}
defer teardownNodeTableTmpl(t) defer teardownNodeTableTmpl(t)
edgeTableName := "edge_table_" + strings.ReplaceAll(uuid.New().String(), "-", "") edgeTableName := "edge_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
@@ -143,7 +155,10 @@ func TestSpannerToolEndpoints(t *testing.T) {
) PRIMARY KEY (id, target_id), ) PRIMARY KEY (id, target_id),
INTERLEAVE IN PARENT %[2]s ON DELETE CASCADE INTERLEAVE IN PARENT %[2]s ON DELETE CASCADE
`, edgeTableName, nodeTableName) `, edgeTableName, nodeTableName)
teardownEdgeTableTmpl := setupSpannerTable(t, ctx, adminClient, dataClient, createEdgeStatementTmpl, "", edgeTableName, dbString, nil) teardownEdgeTableTmpl, err := setupSpannerTable(t, ctx, adminClient, dataClient, createEdgeStatementTmpl, "", edgeTableName, dbString, nil)
if err != nil {
t.Fatalf("failed to setup Spanner table %s: %v", edgeTableName, err)
}
defer teardownEdgeTableTmpl(t) defer teardownEdgeTableTmpl(t)
graphName := "graph_" + strings.ReplaceAll(uuid.New().String(), "-", "") graphName := "graph_" + strings.ReplaceAll(uuid.New().String(), "-", "")
@@ -243,7 +258,7 @@ func getSpannerAuthToolInfo(tableName string) (string, string, string, map[strin
// setupSpannerTable creates and inserts data into a table of tool // setupSpannerTable creates and inserts data into a table of tool
// compatible with spanner-sql tool // compatible with spanner-sql tool
func setupSpannerTable(t *testing.T, ctx context.Context, adminClient *database.DatabaseAdminClient, dataClient *spanner.Client, createStatement, insertStatement, tableName, dbString string, params map[string]any) func(*testing.T) { func setupSpannerTable(t *testing.T, ctx context.Context, adminClient *database.DatabaseAdminClient, dataClient *spanner.Client, createStatement, insertStatement, tableName, dbString string, params map[string]any) (func(*testing.T), error) {
// Create table // Create table
op, err := adminClient.UpdateDatabaseDdl(ctx, &databasepb.UpdateDatabaseDdlRequest{ op, err := adminClient.UpdateDatabaseDdl(ctx, &databasepb.UpdateDatabaseDdlRequest{
@@ -251,11 +266,11 @@ func setupSpannerTable(t *testing.T, ctx context.Context, adminClient *database.
Statements: []string{createStatement}, Statements: []string{createStatement},
}) })
if err != nil { if err != nil {
t.Fatalf("unable to start create table operation %s: %s", tableName, err) return nil, fmt.Errorf("unable to start create table operation %s: %w", tableName, err)
} }
err = op.Wait(ctx) err = op.Wait(ctx)
if err != nil { if err != nil {
t.Fatalf("unable to create test table %s: %s", tableName, err) return nil, fmt.Errorf("unable to create test table %s: %w", tableName, err)
} }
// Insert test data // Insert test data
@@ -269,7 +284,7 @@ func setupSpannerTable(t *testing.T, ctx context.Context, adminClient *database.
return err return err
}) })
if err != nil { if err != nil {
t.Fatalf("unable to insert test data: %s", err) return nil, fmt.Errorf("unable to insert test data: %w", err)
} }
} }
@@ -288,7 +303,7 @@ func setupSpannerTable(t *testing.T, ctx context.Context, adminClient *database.
if opErr != nil { if opErr != nil {
t.Errorf("Teardown failed: %s", opErr) t.Errorf("Teardown failed: %s", opErr)
} }
} }, nil
} }
// setupSpannerGraph creates a graph and inserts data into it. // setupSpannerGraph creates a graph and inserts data into it.