feat(server): implement Tool call auth error propagation (#1235)

For Toolbox protocol:
Before -  return 400 error for all tool invocation errors. 
After - Propagate auth-related errors (401 & 403) to the client if using
client credentials. If using ADC, raise 500 error instead.

For MCP protocol:
Before -  return 200 with error message in the response body.
After - Propagate auth-related errors (401 & 403) to the client if using
client credentials. If using ADC, raise 500 error instead.
This commit is contained in:
Wenxin Du
2025-08-26 15:27:46 -04:00
committed by GitHub
parent 81c36354cb
commit b94a021ca1
8 changed files with 278 additions and 119 deletions

View File

@@ -16,8 +16,10 @@ package server
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
@@ -210,6 +212,12 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
params, err := tool.ParseParams(data, claimsFromAuth)
if err != nil {
// If auth error, return 401
if errors.Is(err, tools.ErrUnauthorized) {
s.logger.DebugContext(ctx, fmt.Sprintf("error parsing authenticated parameters from ID token: %s", err))
_ = render.Render(w, r, newErrResponse(err, http.StatusUnauthorized))
return
}
err = fmt.Errorf("provided parameters were invalid: %w", err)
s.logger.DebugContext(ctx, err.Error())
_ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest))
@@ -222,7 +230,33 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
accessToken := tools.AccessToken(r.Header.Get("Authorization"))
res, err := tool.Invoke(ctx, params, accessToken)
// Determine what error to return to the users.
if err != nil {
errStr := err.Error()
var statusCode int
// Upstream API auth error propagation
switch {
case strings.Contains(errStr, "Error 401"):
statusCode = http.StatusUnauthorized
case strings.Contains(errStr, "Error 403"):
statusCode = http.StatusForbidden
}
if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden {
if tool.RequiresClientAuthorization() {
// Propagate the original 401/403 error.
s.logger.DebugContext(ctx, fmt.Sprintf("error invoking tool. Client credentials lack authorization to the source: %v", err))
_ = render.Render(w, r, newErrResponse(err, statusCode))
return
}
// ADC lacking permission or credentials configuration error.
internalErr := fmt.Errorf("unexpected auth error occured during Tool invocation: %w", err)
s.logger.ErrorContext(ctx, internalErr.Error())
_ = render.Render(w, r, newErrResponse(internalErr, http.StatusInternalServerError))
return
}
err = fmt.Errorf("error while invoking tool: %w", err)
s.logger.DebugContext(ctx, err.Error())
_ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest))

View File

@@ -19,9 +19,11 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"sync"
"time"
@@ -405,6 +407,11 @@ func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
accessToken := tools.AccessToken(r.Header.Get("Authorization"))
v, res, err := processMcpMessage(ctx, body, s, protocolVersion, toolsetName, accessToken)
if err != nil {
s.logger.DebugContext(ctx, fmt.Errorf("error invoking tool: %w", err).Error())
}
// notifications will return empty string
if res == nil {
// Notifications do not expect a response
@@ -412,9 +419,6 @@ func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusAccepted)
return
}
if err != nil {
s.logger.DebugContext(ctx, err.Error())
}
// for v20250326, add the `Mcp-Session-Id` header
if v == v20250326.PROTOCOL_VERSION {
@@ -434,6 +438,22 @@ func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
s.logger.DebugContext(ctx, "unable to add to event queue")
}
}
if rpcResponse, ok := res.(jsonrpc.JSONRPCError); ok {
code := rpcResponse.Error.Code
switch {
case code == jsonrpc.INTERNAL_ERROR:
w.WriteHeader(http.StatusInternalServerError)
case code == jsonrpc.INVALID_REQUEST:
errStr := err.Error()
if errors.Is(err, tools.ErrUnauthorized) {
w.WriteHeader(http.StatusUnauthorized)
} else if strings.Contains(errStr, "Error 401") {
w.WriteHeader(http.StatusUnauthorized)
} else if strings.Contains(errStr, "Error 403") {
w.WriteHeader(http.StatusForbidden)
}
}
}
// send HTTP response
render.JSON(w, r, res)

View File

@@ -18,7 +18,9 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
"github.com/googleapis/genai-toolbox/internal/tools"
@@ -67,7 +69,7 @@ func toolsListHandler(id jsonrpc.RequestId, toolset tools.Toolset, body []byte)
}
// toolsCallHandler generate a response for tools call.
func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, tools map[string]tools.Tool, body []byte, accessToken tools.AccessToken) (any, error) {
func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, toolsMap map[string]tools.Tool, body []byte, accessToken tools.AccessToken) (any, error) {
// retrieve logger from context
logger, err := util.LoggerFromContext(ctx)
if err != nil {
@@ -83,7 +85,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, tools map[strin
toolName := req.Params.Name
toolArgument := req.Params.Arguments
logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName))
tool, ok := tools[toolName]
tool, ok := toolsMap[toolName]
if !ok {
err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
@@ -114,13 +116,28 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, tools map[strin
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
if !tool.Authorized([]string{}) {
err = fmt.Errorf("unauthorized Tool call: `authRequired` is set for the target Tool")
err = fmt.Errorf("unauthorized Tool call: `authRequired` is set for the target Tool but isn't supported through MCP Tool call: %w", tools.ErrUnauthorized)
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}
// run tool invocation and generate response.
results, err := tool.Invoke(ctx, params, accessToken)
if err != nil {
errStr := err.Error()
// Missing authService tokens.
if errors.Is(err, tools.ErrUnauthorized) {
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}
// Upstream auth error
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
if tool.RequiresClientAuthorization() {
// Error with client credentials should pass down to the client
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}
// Auth error with ADC should raise internal 500 error
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
}
text := TextContent{
Type: "text",
Text: err.Error(),

View File

@@ -18,7 +18,9 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
"github.com/googleapis/genai-toolbox/internal/tools"
@@ -67,7 +69,7 @@ func toolsListHandler(id jsonrpc.RequestId, toolset tools.Toolset, body []byte)
}
// toolsCallHandler generate a response for tools call.
func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, tools map[string]tools.Tool, body []byte, accessToken tools.AccessToken) (any, error) {
func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, toolsMap map[string]tools.Tool, body []byte, accessToken tools.AccessToken) (any, error) {
// retrieve logger from context
logger, err := util.LoggerFromContext(ctx)
if err != nil {
@@ -83,7 +85,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, tools map[strin
toolName := req.Params.Name
toolArgument := req.Params.Arguments
logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName))
tool, ok := tools[toolName]
tool, ok := toolsMap[toolName]
if !ok {
err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
@@ -114,13 +116,27 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, tools map[strin
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
if !tool.Authorized([]string{}) {
err = fmt.Errorf("unauthorized Tool call: `authRequired` is set for the target Tool")
err = fmt.Errorf("unauthorized Tool call: `authRequired` is set for the target Tool but isn't supported through MCP Tool call: %w", tools.ErrUnauthorized)
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}
// run tool invocation and generate response.
results, err := tool.Invoke(ctx, params, accessToken)
if err != nil {
errStr := err.Error()
// Missing authService tokens.
if errors.Is(err, tools.ErrUnauthorized) {
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}
// Upstream auth error
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
if tool.RequiresClientAuthorization() {
// Error with client credentials should pass down to the client
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}
// Auth error with ADC should raise internal 500 error
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
}
text := TextContent{
Type: "text",
Text: err.Error(),

View File

@@ -18,7 +18,9 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
"github.com/googleapis/genai-toolbox/internal/tools"
@@ -67,7 +69,7 @@ func toolsListHandler(id jsonrpc.RequestId, toolset tools.Toolset, body []byte)
}
// toolsCallHandler generate a response for tools call.
func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, tools map[string]tools.Tool, body []byte, accessToken tools.AccessToken) (any, error) {
func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, toolsMap map[string]tools.Tool, body []byte, accessToken tools.AccessToken) (any, error) {
// retrieve logger from context
logger, err := util.LoggerFromContext(ctx)
if err != nil {
@@ -83,7 +85,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, tools map[strin
toolName := req.Params.Name
toolArgument := req.Params.Arguments
logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName))
tool, ok := tools[toolName]
tool, ok := toolsMap[toolName]
if !ok {
err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
@@ -114,13 +116,27 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, tools map[strin
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
if !tool.Authorized([]string{}) {
err = fmt.Errorf("unauthorized Tool call: `authRequired` is set for the target Tool")
err = fmt.Errorf("unauthorized Tool call: `authRequired` is set for the target Tool but isn't supported through MCP Tool call: %w", tools.ErrUnauthorized)
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}
// run tool invocation and generate response.
results, err := tool.Invoke(ctx, params, accessToken)
if err != nil {
errStr := err.Error()
// Missing authService tokens.
if errors.Is(err, tools.ErrUnauthorized) {
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}
// Upstream auth error
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
if tool.RequiresClientAuthorization() {
// Error with client credentials should pass down to the client
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}
// Auth error with ADC should raise internal 500 error
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
}
text := TextContent{
Type: "text",
Text: err.Error(),

View File

@@ -107,7 +107,7 @@ func parseFromAuthService(paramAuthServices []ParamAuthService, claimsMap map[st
}
return v, nil
}
return nil, fmt.Errorf("missing or invalid authentication header")
return nil, fmt.Errorf("missing or invalid authentication header: %w", ErrUnauthorized)
}
// CheckParamRequired checks if a parameter is required based on the required and default field.

View File

@@ -16,6 +16,7 @@ package tools
import (
"context"
"errors"
"fmt"
"slices"
@@ -91,6 +92,8 @@ type McpManifest struct {
InputSchema McpToolsSchema `json:"inputSchema,omitempty"`
}
var ErrUnauthorized = errors.New("unauthorized")
// Helper function that returns if a tool invocation request is authorized
func IsAuthorized(authRequiredSources []string, verifiedAuthServices []string) bool {
if len(authRequiredSources) == 0 {

View File

@@ -268,120 +268,141 @@ func RunToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTestOp
// Test tool invoke endpoint
invokeTcs := []struct {
name string
api string
requestHeader map[string]string
requestBody io.Reader
want string
isErr bool
name string
api string
enabled bool
requestHeader map[string]string
requestBody io.Reader
wantStatusCode int
wantBody string
}{
{
name: "invoke my-simple-tool",
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/invoke",
requestHeader: map[string]string{},
requestBody: bytes.NewBuffer([]byte(`{}`)),
want: select1Want,
isErr: false,
name: "invoke my-simple-tool",
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/invoke",
enabled: true,
requestHeader: map[string]string{},
requestBody: bytes.NewBuffer([]byte(`{}`)),
wantBody: select1Want,
wantStatusCode: http.StatusOK,
},
{
name: "invoke my-tool",
api: "http://127.0.0.1:5000/api/tool/my-tool/invoke",
requestHeader: map[string]string{},
requestBody: bytes.NewBuffer([]byte(`{"id": 3, "name": "Alice"}`)),
want: configs.myToolId3NameAliceWant,
isErr: false,
name: "invoke my-tool",
api: "http://127.0.0.1:5000/api/tool/my-tool/invoke",
enabled: true,
requestHeader: map[string]string{},
requestBody: bytes.NewBuffer([]byte(`{"id": 3, "name": "Alice"}`)),
wantBody: configs.myToolId3NameAliceWant,
wantStatusCode: http.StatusOK,
},
{
name: "invoke my-tool-by-id with nil response",
api: "http://127.0.0.1:5000/api/tool/my-tool-by-id/invoke",
requestHeader: map[string]string{},
requestBody: bytes.NewBuffer([]byte(`{"id": 4}`)),
want: configs.myToolById4Want,
isErr: false,
name: "invoke my-tool-by-id with nil response",
api: "http://127.0.0.1:5000/api/tool/my-tool-by-id/invoke",
enabled: true,
requestHeader: map[string]string{},
requestBody: bytes.NewBuffer([]byte(`{"id": 4}`)),
wantBody: configs.myToolById4Want,
wantStatusCode: http.StatusOK,
},
{
name: "invoke my-tool-by-name with nil response",
api: "http://127.0.0.1:5000/api/tool/my-tool-by-name/invoke",
requestHeader: map[string]string{},
requestBody: bytes.NewBuffer([]byte(`{}`)),
want: configs.nullWant,
isErr: !configs.supportOptionalNullParam,
name: "invoke my-tool-by-name with nil response",
api: "http://127.0.0.1:5000/api/tool/my-tool-by-name/invoke",
enabled: configs.supportOptionalNullParam,
requestHeader: map[string]string{},
requestBody: bytes.NewBuffer([]byte(`{}`)),
wantBody: configs.nullWant,
wantStatusCode: http.StatusOK,
},
{
name: "Invoke my-tool without parameters",
api: "http://127.0.0.1:5000/api/tool/my-tool/invoke",
requestHeader: map[string]string{},
requestBody: bytes.NewBuffer([]byte(`{}`)),
isErr: true,
name: "Invoke my-tool without parameters",
api: "http://127.0.0.1:5000/api/tool/my-tool/invoke",
enabled: true,
requestHeader: map[string]string{},
requestBody: bytes.NewBuffer([]byte(`{}`)),
wantBody: "",
wantStatusCode: http.StatusBadRequest,
},
{
name: "Invoke my-tool with insufficient parameters",
api: "http://127.0.0.1:5000/api/tool/my-tool/invoke",
requestHeader: map[string]string{},
requestBody: bytes.NewBuffer([]byte(`{"id": 1}`)),
isErr: true,
name: "Invoke my-tool with insufficient parameters",
api: "http://127.0.0.1:5000/api/tool/my-tool/invoke",
enabled: true,
requestHeader: map[string]string{},
requestBody: bytes.NewBuffer([]byte(`{"id": 1}`)),
wantBody: "",
wantStatusCode: http.StatusBadRequest,
},
{
name: "invoke my-array-tool",
api: "http://127.0.0.1:5000/api/tool/my-array-tool/invoke",
requestHeader: map[string]string{},
requestBody: bytes.NewBuffer([]byte(`{"idArray": [1,2,3], "nameArray": ["Alice", "Sid", "RandomName"], "cmdArray": ["HGETALL", "row3"]}`)),
want: configs.myToolId3NameAliceWant,
isErr: !configs.supportArrayParam,
name: "invoke my-array-tool",
api: "http://127.0.0.1:5000/api/tool/my-array-tool/invoke",
enabled: configs.supportArrayParam,
requestHeader: map[string]string{},
requestBody: bytes.NewBuffer([]byte(`{"idArray": [1,2,3], "nameArray": ["Alice", "Sid", "RandomName"], "cmdArray": ["HGETALL", "row3"]}`)),
wantBody: configs.myToolId3NameAliceWant,
wantStatusCode: http.StatusOK,
},
{
name: "Invoke my-auth-tool with auth token",
api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke",
requestHeader: map[string]string{"my-google-auth_token": idToken},
requestBody: bytes.NewBuffer([]byte(`{}`)),
want: "[{\"name\":\"Alice\"}]",
isErr: false,
name: "Invoke my-auth-tool with auth token",
api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke",
enabled: true,
requestHeader: map[string]string{"my-google-auth_token": idToken},
requestBody: bytes.NewBuffer([]byte(`{}`)),
wantBody: "[{\"name\":\"Alice\"}]",
wantStatusCode: http.StatusOK,
},
{
name: "Invoke my-auth-tool with invalid auth token",
api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke",
requestHeader: map[string]string{"my-google-auth_token": "INVALID_TOKEN"},
requestBody: bytes.NewBuffer([]byte(`{}`)),
isErr: true,
name: "Invoke my-auth-tool with invalid auth token",
api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke",
enabled: true,
requestHeader: map[string]string{"my-google-auth_token": "INVALID_TOKEN"},
requestBody: bytes.NewBuffer([]byte(`{}`)),
wantStatusCode: http.StatusUnauthorized,
},
{
name: "Invoke my-auth-tool without auth token",
api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke",
requestHeader: map[string]string{},
requestBody: bytes.NewBuffer([]byte(`{}`)),
isErr: true,
name: "Invoke my-auth-tool without auth token",
api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke",
enabled: true,
requestHeader: map[string]string{},
requestBody: bytes.NewBuffer([]byte(`{}`)),
wantStatusCode: http.StatusUnauthorized,
},
{
name: "Invoke my-auth-required-tool with auth token",
api: "http://127.0.0.1:5000/api/tool/my-auth-required-tool/invoke",
enabled: true,
requestHeader: map[string]string{"my-google-auth_token": idToken},
requestBody: bytes.NewBuffer([]byte(`{}`)),
isErr: false,
want: select1Want,
wantBody: select1Want,
wantStatusCode: http.StatusOK,
},
{
name: "Invoke my-auth-required-tool with invalid auth token",
api: "http://127.0.0.1:5000/api/tool/my-auth-required-tool/invoke",
requestHeader: map[string]string{"my-google-auth_token": "INVALID_TOKEN"},
requestBody: bytes.NewBuffer([]byte(`{}`)),
isErr: true,
name: "Invoke my-auth-required-tool with invalid auth token",
api: "http://127.0.0.1:5000/api/tool/my-auth-required-tool/invoke",
enabled: true,
requestHeader: map[string]string{"my-google-auth_token": "INVALID_TOKEN"},
requestBody: bytes.NewBuffer([]byte(`{}`)),
wantStatusCode: http.StatusUnauthorized,
},
{
name: "Invoke my-auth-required-tool without auth token",
api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke",
requestHeader: map[string]string{},
requestBody: bytes.NewBuffer([]byte(`{}`)),
isErr: true,
name: "Invoke my-auth-required-tool without auth token",
api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke",
enabled: true,
requestHeader: map[string]string{},
requestBody: bytes.NewBuffer([]byte(`{}`)),
wantStatusCode: http.StatusUnauthorized,
},
}
for _, tc := range invokeTcs {
t.Run(tc.name, func(t *testing.T) {
if !tc.enabled {
return
}
// Send Tool invocation request
req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody)
if err != nil {
t.Fatalf("unable to create request: %s", err)
}
req.Header.Add("Content-type", "application/json")
// Add headers
for k, v := range tc.requestHeader {
req.Header.Add(k, v)
}
@@ -391,19 +412,22 @@ func RunToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTestOp
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if tc.isErr {
return
}
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
// Check status code
if resp.StatusCode != tc.wantStatusCode {
body, _ := io.ReadAll(resp.Body)
t.Errorf("StatusCode mismatch: got %d, want %d. Response body: %s", resp.StatusCode, tc.wantStatusCode, string(body))
}
// skip response body check
if tc.wantBody == "" {
return
}
// Check response body
var body map[string]interface{}
err = json.NewDecoder(resp.Body).Decode(&body)
if err != nil {
t.Fatalf("error parsing response body")
t.Fatalf("error parsing response body: %s", err)
}
got, ok := body["result"].(string)
@@ -411,8 +435,8 @@ func RunToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTestOp
t.Fatalf("unable to find result in response body")
}
if got != tc.want {
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
if got != tc.wantBody {
t.Fatalf("unexpected value: got %q, want %q", got, tc.wantBody)
}
})
}
@@ -776,22 +800,21 @@ func RunMCPToolCallMethod(t *testing.T, myFailToolWant string, options ...McpTes
}
sessionId := RunInitialize(t, "2024-11-05")
header := map[string]string{}
if sessionId != "" {
header["Mcp-Session-Id"] = sessionId
}
// Test tool invoke endpoint
invokeTcs := []struct {
name string
api string
requestBody jsonrpc.JSONRPCRequest
requestHeader map[string]string
want string
name string
api string
enabled bool // switch to turn on/off the test case
requestBody jsonrpc.JSONRPCRequest
requestHeader map[string]string
wantStatusCode int
wantBody string
}{
{
name: "MCP Invoke my-tool",
api: "http://127.0.0.1:5000/mcp",
enabled: true,
requestHeader: map[string]string{},
requestBody: jsonrpc.JSONRPCRequest{
Jsonrpc: "2.0",
@@ -807,11 +830,13 @@ func RunMCPToolCallMethod(t *testing.T, myFailToolWant string, options ...McpTes
},
},
},
want: configs.myToolId3NameAliceWant,
wantStatusCode: http.StatusOK,
wantBody: configs.myToolId3NameAliceWant,
},
{
name: "MCP Invoke invalid tool",
api: "http://127.0.0.1:5000/mcp",
enabled: true,
requestHeader: map[string]string{},
requestBody: jsonrpc.JSONRPCRequest{
Jsonrpc: "2.0",
@@ -824,11 +849,13 @@ func RunMCPToolCallMethod(t *testing.T, myFailToolWant string, options ...McpTes
"arguments": map[string]any{},
},
},
want: `{"jsonrpc":"2.0","id":"invalid-tool","error":{"code":-32602,"message":"invalid tool name: tool with name \"foo\" does not exist"}}`,
wantStatusCode: http.StatusOK,
wantBody: `{"jsonrpc":"2.0","id":"invalid-tool","error":{"code":-32602,"message":"invalid tool name: tool with name \"foo\" does not exist"}}`,
},
{
name: "MCP Invoke my-tool without parameters",
api: "http://127.0.0.1:5000/mcp",
enabled: true,
requestHeader: map[string]string{},
requestBody: jsonrpc.JSONRPCRequest{
Jsonrpc: "2.0",
@@ -841,11 +868,13 @@ func RunMCPToolCallMethod(t *testing.T, myFailToolWant string, options ...McpTes
"arguments": map[string]any{},
},
},
want: `{"jsonrpc":"2.0","id":"invoke-without-parameter","error":{"code":-32602,"message":"provided parameters were invalid: parameter \"id\" is required"}}`,
wantStatusCode: http.StatusOK,
wantBody: `{"jsonrpc":"2.0","id":"invoke-without-parameter","error":{"code":-32602,"message":"provided parameters were invalid: parameter \"id\" is required"}}`,
},
{
name: "MCP Invoke my-tool with insufficient parameters",
api: "http://127.0.0.1:5000/mcp",
enabled: true,
requestHeader: map[string]string{},
requestBody: jsonrpc.JSONRPCRequest{
Jsonrpc: "2.0",
@@ -858,11 +887,13 @@ func RunMCPToolCallMethod(t *testing.T, myFailToolWant string, options ...McpTes
"arguments": map[string]any{"id": 1},
},
},
want: `{"jsonrpc":"2.0","id":"invoke-insufficient-parameter","error":{"code":-32602,"message":"provided parameters were invalid: parameter \"name\" is required"}}`,
wantStatusCode: http.StatusOK,
wantBody: `{"jsonrpc":"2.0","id":"invoke-insufficient-parameter","error":{"code":-32602,"message":"provided parameters were invalid: parameter \"name\" is required"}}`,
},
{
name: "MCP Invoke my-auth-required-tool",
api: "http://127.0.0.1:5000/mcp",
enabled: true,
requestHeader: map[string]string{},
requestBody: jsonrpc.JSONRPCRequest{
Jsonrpc: "2.0",
@@ -875,11 +906,13 @@ func RunMCPToolCallMethod(t *testing.T, myFailToolWant string, options ...McpTes
"arguments": map[string]any{},
},
},
want: "{\"jsonrpc\":\"2.0\",\"id\":\"invoke my-auth-required-tool\",\"error\":{\"code\":-32600,\"message\":\"unauthorized Tool call: `authRequired` is set for the target Tool\"}}",
wantStatusCode: http.StatusUnauthorized,
wantBody: "{\"jsonrpc\":\"2.0\",\"id\":\"invoke my-auth-required-tool\",\"error\":{\"code\":-32600,\"message\":\"unauthorized Tool call: `authRequired` is set for the target Tool but isn't supported through MCP Tool call: unauthorized\"}}",
},
{
name: "MCP Invoke my-fail-tool",
api: "http://127.0.0.1:5000/mcp",
enabled: true,
requestHeader: map[string]string{},
requestBody: jsonrpc.JSONRPCRequest{
Jsonrpc: "2.0",
@@ -892,36 +925,56 @@ func RunMCPToolCallMethod(t *testing.T, myFailToolWant string, options ...McpTes
"arguments": map[string]any{"id": 1},
},
},
want: myFailToolWant,
wantStatusCode: http.StatusOK,
wantBody: myFailToolWant,
},
}
for _, tc := range invokeTcs {
t.Run(tc.name, func(t *testing.T) {
if !tc.enabled {
return
}
reqMarshal, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("unexpected error during marshaling of request body")
}
_, respBody := runRequest(t, http.MethodPost, tc.api, bytes.NewBuffer(reqMarshal), header)
got := string(bytes.TrimSpace(respBody))
// add headers
headers := map[string]string{}
if sessionId != "" {
headers["Mcp-Session-Id"] = sessionId
}
for key, value := range tc.requestHeader {
headers[key] = value
}
if !strings.Contains(got, tc.want) {
t.Fatalf("Expected substring not found:\ngot: %q\nwant: %q (to be contained within got)", got, tc.want)
httpResponse, respBody := runRequest(t, http.MethodPost, tc.api, bytes.NewBuffer(reqMarshal), headers)
// Check status code
if httpResponse.StatusCode != tc.wantStatusCode {
t.Errorf("StatusCode mismatch: got %d, want %d", httpResponse.StatusCode, tc.wantStatusCode)
}
// Check response body
got := string(bytes.TrimSpace(respBody))
if !strings.Contains(got, tc.wantBody) {
t.Fatalf("Expected substring not found:\ngot: %q\nwant: %q (to be contained within got)", got, tc.wantBody)
}
})
}
}
func runRequest(t *testing.T, method, url string, body io.Reader, header map[string]string) (*http.Response, []byte) {
func runRequest(t *testing.T, method, url string, body io.Reader, headers map[string]string) (*http.Response, []byte) {
// Send request
req, err := http.NewRequest(method, url, body)
if err != nil {
t.Fatalf("unable to create request: %s", err)
}
req.Header.Add("Content-type", "application/json")
for k, v := range header {
req.Header.Add(k, v)
req.Header.Set("Content-type", "application/json")
for k, v := range headers {
req.Header.Set(k, v)
}
resp, err := http.DefaultClient.Do(req)