mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-10 07:58:12 -05:00
feat(server): implement Tool call auth error propagation (#1235)
For Toolbox protocol: Before - return 400 error for all tool invocation errors. After - Propagate auth-related errors (401 & 403) to the client if using client credentials. If using ADC, raise 500 error instead. For MCP protocol: Before - return 200 with error message in the response body. After - Propagate auth-related errors (401 & 403) to the client if using client credentials. If using ADC, raise 500 error instead.
This commit is contained in:
@@ -16,8 +16,10 @@ package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
@@ -210,6 +212,12 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
params, err := tool.ParseParams(data, claimsFromAuth)
|
||||
if err != nil {
|
||||
// If auth error, return 401
|
||||
if errors.Is(err, tools.ErrUnauthorized) {
|
||||
s.logger.DebugContext(ctx, fmt.Sprintf("error parsing authenticated parameters from ID token: %s", err))
|
||||
_ = render.Render(w, r, newErrResponse(err, http.StatusUnauthorized))
|
||||
return
|
||||
}
|
||||
err = fmt.Errorf("provided parameters were invalid: %w", err)
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
_ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest))
|
||||
@@ -222,7 +230,33 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
accessToken := tools.AccessToken(r.Header.Get("Authorization"))
|
||||
|
||||
res, err := tool.Invoke(ctx, params, accessToken)
|
||||
|
||||
// Determine what error to return to the users.
|
||||
if err != nil {
|
||||
errStr := err.Error()
|
||||
var statusCode int
|
||||
|
||||
// Upstream API auth error propagation
|
||||
switch {
|
||||
case strings.Contains(errStr, "Error 401"):
|
||||
statusCode = http.StatusUnauthorized
|
||||
case strings.Contains(errStr, "Error 403"):
|
||||
statusCode = http.StatusForbidden
|
||||
}
|
||||
|
||||
if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden {
|
||||
if tool.RequiresClientAuthorization() {
|
||||
// Propagate the original 401/403 error.
|
||||
s.logger.DebugContext(ctx, fmt.Sprintf("error invoking tool. Client credentials lack authorization to the source: %v", err))
|
||||
_ = render.Render(w, r, newErrResponse(err, statusCode))
|
||||
return
|
||||
}
|
||||
// ADC lacking permission or credentials configuration error.
|
||||
internalErr := fmt.Errorf("unexpected auth error occured during Tool invocation: %w", err)
|
||||
s.logger.ErrorContext(ctx, internalErr.Error())
|
||||
_ = render.Render(w, r, newErrResponse(internalErr, http.StatusInternalServerError))
|
||||
return
|
||||
}
|
||||
err = fmt.Errorf("error while invoking tool: %w", err)
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
_ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest))
|
||||
|
||||
@@ -19,9 +19,11 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -405,6 +407,11 @@ func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
accessToken := tools.AccessToken(r.Header.Get("Authorization"))
|
||||
|
||||
v, res, err := processMcpMessage(ctx, body, s, protocolVersion, toolsetName, accessToken)
|
||||
|
||||
if err != nil {
|
||||
s.logger.DebugContext(ctx, fmt.Errorf("error invoking tool: %w", err).Error())
|
||||
}
|
||||
|
||||
// notifications will return empty string
|
||||
if res == nil {
|
||||
// Notifications do not expect a response
|
||||
@@ -412,9 +419,6 @@ func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
}
|
||||
|
||||
// for v20250326, add the `Mcp-Session-Id` header
|
||||
if v == v20250326.PROTOCOL_VERSION {
|
||||
@@ -434,6 +438,22 @@ func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
s.logger.DebugContext(ctx, "unable to add to event queue")
|
||||
}
|
||||
}
|
||||
if rpcResponse, ok := res.(jsonrpc.JSONRPCError); ok {
|
||||
code := rpcResponse.Error.Code
|
||||
switch {
|
||||
case code == jsonrpc.INTERNAL_ERROR:
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
case code == jsonrpc.INVALID_REQUEST:
|
||||
errStr := err.Error()
|
||||
if errors.Is(err, tools.ErrUnauthorized) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
} else if strings.Contains(errStr, "Error 401") {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
} else if strings.Contains(errStr, "Error 403") {
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// send HTTP response
|
||||
render.JSON(w, r, res)
|
||||
|
||||
@@ -18,7 +18,9 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
@@ -67,7 +69,7 @@ func toolsListHandler(id jsonrpc.RequestId, toolset tools.Toolset, body []byte)
|
||||
}
|
||||
|
||||
// toolsCallHandler generate a response for tools call.
|
||||
func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, tools map[string]tools.Tool, body []byte, accessToken tools.AccessToken) (any, error) {
|
||||
func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, toolsMap map[string]tools.Tool, body []byte, accessToken tools.AccessToken) (any, error) {
|
||||
// retrieve logger from context
|
||||
logger, err := util.LoggerFromContext(ctx)
|
||||
if err != nil {
|
||||
@@ -83,7 +85,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, tools map[strin
|
||||
toolName := req.Params.Name
|
||||
toolArgument := req.Params.Arguments
|
||||
logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName))
|
||||
tool, ok := tools[toolName]
|
||||
tool, ok := toolsMap[toolName]
|
||||
if !ok {
|
||||
err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
|
||||
@@ -114,13 +116,28 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, tools map[strin
|
||||
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
|
||||
|
||||
if !tool.Authorized([]string{}) {
|
||||
err = fmt.Errorf("unauthorized Tool call: `authRequired` is set for the target Tool")
|
||||
err = fmt.Errorf("unauthorized Tool call: `authRequired` is set for the target Tool but isn't supported through MCP Tool call: %w", tools.ErrUnauthorized)
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
|
||||
// run tool invocation and generate response.
|
||||
results, err := tool.Invoke(ctx, params, accessToken)
|
||||
if err != nil {
|
||||
errStr := err.Error()
|
||||
// Missing authService tokens.
|
||||
if errors.Is(err, tools.ErrUnauthorized) {
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
// Upstream auth error
|
||||
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
|
||||
if tool.RequiresClientAuthorization() {
|
||||
// Error with client credentials should pass down to the client
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
// Auth error with ADC should raise internal 500 error
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||
}
|
||||
|
||||
text := TextContent{
|
||||
Type: "text",
|
||||
Text: err.Error(),
|
||||
|
||||
@@ -18,7 +18,9 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
@@ -67,7 +69,7 @@ func toolsListHandler(id jsonrpc.RequestId, toolset tools.Toolset, body []byte)
|
||||
}
|
||||
|
||||
// toolsCallHandler generate a response for tools call.
|
||||
func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, tools map[string]tools.Tool, body []byte, accessToken tools.AccessToken) (any, error) {
|
||||
func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, toolsMap map[string]tools.Tool, body []byte, accessToken tools.AccessToken) (any, error) {
|
||||
// retrieve logger from context
|
||||
logger, err := util.LoggerFromContext(ctx)
|
||||
if err != nil {
|
||||
@@ -83,7 +85,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, tools map[strin
|
||||
toolName := req.Params.Name
|
||||
toolArgument := req.Params.Arguments
|
||||
logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName))
|
||||
tool, ok := tools[toolName]
|
||||
tool, ok := toolsMap[toolName]
|
||||
if !ok {
|
||||
err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
|
||||
@@ -114,13 +116,27 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, tools map[strin
|
||||
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
|
||||
|
||||
if !tool.Authorized([]string{}) {
|
||||
err = fmt.Errorf("unauthorized Tool call: `authRequired` is set for the target Tool")
|
||||
err = fmt.Errorf("unauthorized Tool call: `authRequired` is set for the target Tool but isn't supported through MCP Tool call: %w", tools.ErrUnauthorized)
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
|
||||
// run tool invocation and generate response.
|
||||
results, err := tool.Invoke(ctx, params, accessToken)
|
||||
if err != nil {
|
||||
errStr := err.Error()
|
||||
// Missing authService tokens.
|
||||
if errors.Is(err, tools.ErrUnauthorized) {
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
// Upstream auth error
|
||||
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
|
||||
if tool.RequiresClientAuthorization() {
|
||||
// Error with client credentials should pass down to the client
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
// Auth error with ADC should raise internal 500 error
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||
}
|
||||
text := TextContent{
|
||||
Type: "text",
|
||||
Text: err.Error(),
|
||||
|
||||
@@ -18,7 +18,9 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
@@ -67,7 +69,7 @@ func toolsListHandler(id jsonrpc.RequestId, toolset tools.Toolset, body []byte)
|
||||
}
|
||||
|
||||
// toolsCallHandler generate a response for tools call.
|
||||
func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, tools map[string]tools.Tool, body []byte, accessToken tools.AccessToken) (any, error) {
|
||||
func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, toolsMap map[string]tools.Tool, body []byte, accessToken tools.AccessToken) (any, error) {
|
||||
// retrieve logger from context
|
||||
logger, err := util.LoggerFromContext(ctx)
|
||||
if err != nil {
|
||||
@@ -83,7 +85,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, tools map[strin
|
||||
toolName := req.Params.Name
|
||||
toolArgument := req.Params.Arguments
|
||||
logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName))
|
||||
tool, ok := tools[toolName]
|
||||
tool, ok := toolsMap[toolName]
|
||||
if !ok {
|
||||
err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
|
||||
@@ -114,13 +116,27 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, tools map[strin
|
||||
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
|
||||
|
||||
if !tool.Authorized([]string{}) {
|
||||
err = fmt.Errorf("unauthorized Tool call: `authRequired` is set for the target Tool")
|
||||
err = fmt.Errorf("unauthorized Tool call: `authRequired` is set for the target Tool but isn't supported through MCP Tool call: %w", tools.ErrUnauthorized)
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
|
||||
// run tool invocation and generate response.
|
||||
results, err := tool.Invoke(ctx, params, accessToken)
|
||||
if err != nil {
|
||||
errStr := err.Error()
|
||||
// Missing authService tokens.
|
||||
if errors.Is(err, tools.ErrUnauthorized) {
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
// Upstream auth error
|
||||
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
|
||||
if tool.RequiresClientAuthorization() {
|
||||
// Error with client credentials should pass down to the client
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
// Auth error with ADC should raise internal 500 error
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||
}
|
||||
text := TextContent{
|
||||
Type: "text",
|
||||
Text: err.Error(),
|
||||
|
||||
@@ -107,7 +107,7 @@ func parseFromAuthService(paramAuthServices []ParamAuthService, claimsMap map[st
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
return nil, fmt.Errorf("missing or invalid authentication header")
|
||||
return nil, fmt.Errorf("missing or invalid authentication header: %w", ErrUnauthorized)
|
||||
}
|
||||
|
||||
// CheckParamRequired checks if a parameter is required based on the required and default field.
|
||||
|
||||
@@ -16,6 +16,7 @@ package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
@@ -91,6 +92,8 @@ type McpManifest struct {
|
||||
InputSchema McpToolsSchema `json:"inputSchema,omitempty"`
|
||||
}
|
||||
|
||||
var ErrUnauthorized = errors.New("unauthorized")
|
||||
|
||||
// Helper function that returns if a tool invocation request is authorized
|
||||
func IsAuthorized(authRequiredSources []string, verifiedAuthServices []string) bool {
|
||||
if len(authRequiredSources) == 0 {
|
||||
|
||||
265
tests/tool.go
265
tests/tool.go
@@ -268,120 +268,141 @@ func RunToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTestOp
|
||||
|
||||
// Test tool invoke endpoint
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
api string
|
||||
requestHeader map[string]string
|
||||
requestBody io.Reader
|
||||
want string
|
||||
isErr bool
|
||||
name string
|
||||
api string
|
||||
enabled bool
|
||||
requestHeader map[string]string
|
||||
requestBody io.Reader
|
||||
wantStatusCode int
|
||||
wantBody string
|
||||
}{
|
||||
{
|
||||
name: "invoke my-simple-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
want: select1Want,
|
||||
isErr: false,
|
||||
name: "invoke my-simple-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/invoke",
|
||||
enabled: true,
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
wantBody: select1Want,
|
||||
wantStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "invoke my-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"id": 3, "name": "Alice"}`)),
|
||||
want: configs.myToolId3NameAliceWant,
|
||||
isErr: false,
|
||||
name: "invoke my-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-tool/invoke",
|
||||
enabled: true,
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"id": 3, "name": "Alice"}`)),
|
||||
wantBody: configs.myToolId3NameAliceWant,
|
||||
wantStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "invoke my-tool-by-id with nil response",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-tool-by-id/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"id": 4}`)),
|
||||
want: configs.myToolById4Want,
|
||||
isErr: false,
|
||||
name: "invoke my-tool-by-id with nil response",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-tool-by-id/invoke",
|
||||
enabled: true,
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"id": 4}`)),
|
||||
wantBody: configs.myToolById4Want,
|
||||
wantStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "invoke my-tool-by-name with nil response",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-tool-by-name/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
want: configs.nullWant,
|
||||
isErr: !configs.supportOptionalNullParam,
|
||||
name: "invoke my-tool-by-name with nil response",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-tool-by-name/invoke",
|
||||
enabled: configs.supportOptionalNullParam,
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
wantBody: configs.nullWant,
|
||||
wantStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "Invoke my-tool without parameters",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
isErr: true,
|
||||
name: "Invoke my-tool without parameters",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-tool/invoke",
|
||||
enabled: true,
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
wantBody: "",
|
||||
wantStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "Invoke my-tool with insufficient parameters",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"id": 1}`)),
|
||||
isErr: true,
|
||||
name: "Invoke my-tool with insufficient parameters",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-tool/invoke",
|
||||
enabled: true,
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"id": 1}`)),
|
||||
wantBody: "",
|
||||
wantStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "invoke my-array-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-array-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"idArray": [1,2,3], "nameArray": ["Alice", "Sid", "RandomName"], "cmdArray": ["HGETALL", "row3"]}`)),
|
||||
want: configs.myToolId3NameAliceWant,
|
||||
isErr: !configs.supportArrayParam,
|
||||
name: "invoke my-array-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-array-tool/invoke",
|
||||
enabled: configs.supportArrayParam,
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"idArray": [1,2,3], "nameArray": ["Alice", "Sid", "RandomName"], "cmdArray": ["HGETALL", "row3"]}`)),
|
||||
wantBody: configs.myToolId3NameAliceWant,
|
||||
wantStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "Invoke my-auth-tool with auth token",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke",
|
||||
requestHeader: map[string]string{"my-google-auth_token": idToken},
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
want: "[{\"name\":\"Alice\"}]",
|
||||
isErr: false,
|
||||
name: "Invoke my-auth-tool with auth token",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke",
|
||||
enabled: true,
|
||||
requestHeader: map[string]string{"my-google-auth_token": idToken},
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
wantBody: "[{\"name\":\"Alice\"}]",
|
||||
wantStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "Invoke my-auth-tool with invalid auth token",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke",
|
||||
requestHeader: map[string]string{"my-google-auth_token": "INVALID_TOKEN"},
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
isErr: true,
|
||||
name: "Invoke my-auth-tool with invalid auth token",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke",
|
||||
enabled: true,
|
||||
requestHeader: map[string]string{"my-google-auth_token": "INVALID_TOKEN"},
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
wantStatusCode: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "Invoke my-auth-tool without auth token",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
isErr: true,
|
||||
name: "Invoke my-auth-tool without auth token",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke",
|
||||
enabled: true,
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
wantStatusCode: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "Invoke my-auth-required-tool with auth token",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-required-tool/invoke",
|
||||
enabled: true,
|
||||
requestHeader: map[string]string{"my-google-auth_token": idToken},
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
isErr: false,
|
||||
want: select1Want,
|
||||
|
||||
wantBody: select1Want,
|
||||
wantStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "Invoke my-auth-required-tool with invalid auth token",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-required-tool/invoke",
|
||||
requestHeader: map[string]string{"my-google-auth_token": "INVALID_TOKEN"},
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
isErr: true,
|
||||
name: "Invoke my-auth-required-tool with invalid auth token",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-required-tool/invoke",
|
||||
enabled: true,
|
||||
requestHeader: map[string]string{"my-google-auth_token": "INVALID_TOKEN"},
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
wantStatusCode: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "Invoke my-auth-required-tool without auth token",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
isErr: true,
|
||||
name: "Invoke my-auth-required-tool without auth token",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke",
|
||||
enabled: true,
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
wantStatusCode: http.StatusUnauthorized,
|
||||
},
|
||||
}
|
||||
for _, tc := range invokeTcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if !tc.enabled {
|
||||
return
|
||||
}
|
||||
// Send Tool invocation request
|
||||
req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create request: %s", err)
|
||||
}
|
||||
req.Header.Add("Content-type", "application/json")
|
||||
// Add headers
|
||||
for k, v := range tc.requestHeader {
|
||||
req.Header.Add(k, v)
|
||||
}
|
||||
@@ -391,19 +412,22 @@ func RunToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTestOp
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if tc.isErr {
|
||||
return
|
||||
}
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
// Check status code
|
||||
if resp.StatusCode != tc.wantStatusCode {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
t.Errorf("StatusCode mismatch: got %d, want %d. Response body: %s", resp.StatusCode, tc.wantStatusCode, string(body))
|
||||
}
|
||||
|
||||
// skip response body check
|
||||
if tc.wantBody == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// Check response body
|
||||
var body map[string]interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing response body")
|
||||
t.Fatalf("error parsing response body: %s", err)
|
||||
}
|
||||
|
||||
got, ok := body["result"].(string)
|
||||
@@ -411,8 +435,8 @@ func RunToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTestOp
|
||||
t.Fatalf("unable to find result in response body")
|
||||
}
|
||||
|
||||
if got != tc.want {
|
||||
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
|
||||
if got != tc.wantBody {
|
||||
t.Fatalf("unexpected value: got %q, want %q", got, tc.wantBody)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -776,22 +800,21 @@ func RunMCPToolCallMethod(t *testing.T, myFailToolWant string, options ...McpTes
|
||||
}
|
||||
|
||||
sessionId := RunInitialize(t, "2024-11-05")
|
||||
header := map[string]string{}
|
||||
if sessionId != "" {
|
||||
header["Mcp-Session-Id"] = sessionId
|
||||
}
|
||||
|
||||
// Test tool invoke endpoint
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
api string
|
||||
requestBody jsonrpc.JSONRPCRequest
|
||||
requestHeader map[string]string
|
||||
want string
|
||||
name string
|
||||
api string
|
||||
enabled bool // switch to turn on/off the test case
|
||||
requestBody jsonrpc.JSONRPCRequest
|
||||
requestHeader map[string]string
|
||||
wantStatusCode int
|
||||
wantBody string
|
||||
}{
|
||||
{
|
||||
name: "MCP Invoke my-tool",
|
||||
api: "http://127.0.0.1:5000/mcp",
|
||||
enabled: true,
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: jsonrpc.JSONRPCRequest{
|
||||
Jsonrpc: "2.0",
|
||||
@@ -807,11 +830,13 @@ func RunMCPToolCallMethod(t *testing.T, myFailToolWant string, options ...McpTes
|
||||
},
|
||||
},
|
||||
},
|
||||
want: configs.myToolId3NameAliceWant,
|
||||
wantStatusCode: http.StatusOK,
|
||||
wantBody: configs.myToolId3NameAliceWant,
|
||||
},
|
||||
{
|
||||
name: "MCP Invoke invalid tool",
|
||||
api: "http://127.0.0.1:5000/mcp",
|
||||
enabled: true,
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: jsonrpc.JSONRPCRequest{
|
||||
Jsonrpc: "2.0",
|
||||
@@ -824,11 +849,13 @@ func RunMCPToolCallMethod(t *testing.T, myFailToolWant string, options ...McpTes
|
||||
"arguments": map[string]any{},
|
||||
},
|
||||
},
|
||||
want: `{"jsonrpc":"2.0","id":"invalid-tool","error":{"code":-32602,"message":"invalid tool name: tool with name \"foo\" does not exist"}}`,
|
||||
wantStatusCode: http.StatusOK,
|
||||
wantBody: `{"jsonrpc":"2.0","id":"invalid-tool","error":{"code":-32602,"message":"invalid tool name: tool with name \"foo\" does not exist"}}`,
|
||||
},
|
||||
{
|
||||
name: "MCP Invoke my-tool without parameters",
|
||||
api: "http://127.0.0.1:5000/mcp",
|
||||
enabled: true,
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: jsonrpc.JSONRPCRequest{
|
||||
Jsonrpc: "2.0",
|
||||
@@ -841,11 +868,13 @@ func RunMCPToolCallMethod(t *testing.T, myFailToolWant string, options ...McpTes
|
||||
"arguments": map[string]any{},
|
||||
},
|
||||
},
|
||||
want: `{"jsonrpc":"2.0","id":"invoke-without-parameter","error":{"code":-32602,"message":"provided parameters were invalid: parameter \"id\" is required"}}`,
|
||||
wantStatusCode: http.StatusOK,
|
||||
wantBody: `{"jsonrpc":"2.0","id":"invoke-without-parameter","error":{"code":-32602,"message":"provided parameters were invalid: parameter \"id\" is required"}}`,
|
||||
},
|
||||
{
|
||||
name: "MCP Invoke my-tool with insufficient parameters",
|
||||
api: "http://127.0.0.1:5000/mcp",
|
||||
enabled: true,
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: jsonrpc.JSONRPCRequest{
|
||||
Jsonrpc: "2.0",
|
||||
@@ -858,11 +887,13 @@ func RunMCPToolCallMethod(t *testing.T, myFailToolWant string, options ...McpTes
|
||||
"arguments": map[string]any{"id": 1},
|
||||
},
|
||||
},
|
||||
want: `{"jsonrpc":"2.0","id":"invoke-insufficient-parameter","error":{"code":-32602,"message":"provided parameters were invalid: parameter \"name\" is required"}}`,
|
||||
wantStatusCode: http.StatusOK,
|
||||
wantBody: `{"jsonrpc":"2.0","id":"invoke-insufficient-parameter","error":{"code":-32602,"message":"provided parameters were invalid: parameter \"name\" is required"}}`,
|
||||
},
|
||||
{
|
||||
name: "MCP Invoke my-auth-required-tool",
|
||||
api: "http://127.0.0.1:5000/mcp",
|
||||
enabled: true,
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: jsonrpc.JSONRPCRequest{
|
||||
Jsonrpc: "2.0",
|
||||
@@ -875,11 +906,13 @@ func RunMCPToolCallMethod(t *testing.T, myFailToolWant string, options ...McpTes
|
||||
"arguments": map[string]any{},
|
||||
},
|
||||
},
|
||||
want: "{\"jsonrpc\":\"2.0\",\"id\":\"invoke my-auth-required-tool\",\"error\":{\"code\":-32600,\"message\":\"unauthorized Tool call: `authRequired` is set for the target Tool\"}}",
|
||||
wantStatusCode: http.StatusUnauthorized,
|
||||
wantBody: "{\"jsonrpc\":\"2.0\",\"id\":\"invoke my-auth-required-tool\",\"error\":{\"code\":-32600,\"message\":\"unauthorized Tool call: `authRequired` is set for the target Tool but isn't supported through MCP Tool call: unauthorized\"}}",
|
||||
},
|
||||
{
|
||||
name: "MCP Invoke my-fail-tool",
|
||||
api: "http://127.0.0.1:5000/mcp",
|
||||
enabled: true,
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: jsonrpc.JSONRPCRequest{
|
||||
Jsonrpc: "2.0",
|
||||
@@ -892,36 +925,56 @@ func RunMCPToolCallMethod(t *testing.T, myFailToolWant string, options ...McpTes
|
||||
"arguments": map[string]any{"id": 1},
|
||||
},
|
||||
},
|
||||
want: myFailToolWant,
|
||||
wantStatusCode: http.StatusOK,
|
||||
wantBody: myFailToolWant,
|
||||
},
|
||||
}
|
||||
for _, tc := range invokeTcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if !tc.enabled {
|
||||
return
|
||||
}
|
||||
reqMarshal, err := json.Marshal(tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error during marshaling of request body")
|
||||
}
|
||||
|
||||
_, respBody := runRequest(t, http.MethodPost, tc.api, bytes.NewBuffer(reqMarshal), header)
|
||||
got := string(bytes.TrimSpace(respBody))
|
||||
// add headers
|
||||
headers := map[string]string{}
|
||||
if sessionId != "" {
|
||||
headers["Mcp-Session-Id"] = sessionId
|
||||
}
|
||||
for key, value := range tc.requestHeader {
|
||||
headers[key] = value
|
||||
}
|
||||
|
||||
if !strings.Contains(got, tc.want) {
|
||||
t.Fatalf("Expected substring not found:\ngot: %q\nwant: %q (to be contained within got)", got, tc.want)
|
||||
httpResponse, respBody := runRequest(t, http.MethodPost, tc.api, bytes.NewBuffer(reqMarshal), headers)
|
||||
|
||||
// Check status code
|
||||
if httpResponse.StatusCode != tc.wantStatusCode {
|
||||
t.Errorf("StatusCode mismatch: got %d, want %d", httpResponse.StatusCode, tc.wantStatusCode)
|
||||
}
|
||||
|
||||
// Check response body
|
||||
got := string(bytes.TrimSpace(respBody))
|
||||
if !strings.Contains(got, tc.wantBody) {
|
||||
t.Fatalf("Expected substring not found:\ngot: %q\nwant: %q (to be contained within got)", got, tc.wantBody)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func runRequest(t *testing.T, method, url string, body io.Reader, header map[string]string) (*http.Response, []byte) {
|
||||
func runRequest(t *testing.T, method, url string, body io.Reader, headers map[string]string) (*http.Response, []byte) {
|
||||
// Send request
|
||||
req, err := http.NewRequest(method, url, body)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create request: %s", err)
|
||||
}
|
||||
|
||||
req.Header.Add("Content-type", "application/json")
|
||||
for k, v := range header {
|
||||
req.Header.Add(k, v)
|
||||
req.Header.Set("Content-type", "application/json")
|
||||
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
|
||||
Reference in New Issue
Block a user