mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-02-04 04:05:22 -05:00
Compare commits
4 Commits
err-api
...
temp-Delet
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ff874ef385 | ||
|
|
a6335c3797 | ||
|
|
63cdea2cd0 | ||
|
|
a18fe045dd |
@@ -414,10 +414,10 @@ See [Usage Examples](../reference/cli.md#examples).
|
|||||||
entries.
|
entries.
|
||||||
* **Dataplex Editor** (`roles/dataplex.editor`) to modify entries.
|
* **Dataplex Editor** (`roles/dataplex.editor`) to modify entries.
|
||||||
* **Tools:**
|
* **Tools:**
|
||||||
* `search_entries`: Searches for entries in Dataplex Catalog.
|
* `dataplex_search_entries`: Searches for entries in Dataplex Catalog.
|
||||||
* `lookup_entry`: Retrieves a specific entry from Dataplex
|
* `dataplex_lookup_entry`: Retrieves a specific entry from Dataplex
|
||||||
Catalog.
|
Catalog.
|
||||||
* `search_aspect_types`: Finds aspect types relevant to the
|
* `dataplex_search_aspect_types`: Finds aspect types relevant to the
|
||||||
query.
|
query.
|
||||||
|
|
||||||
## Firestore
|
## Firestore
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
"github.com/go-chi/chi/v5/middleware"
|
"github.com/go-chi/chi/v5/middleware"
|
||||||
@@ -234,10 +235,8 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
|||||||
params, err := parameters.ParseParams(tool.GetParameters(), data, claimsFromAuth)
|
params, err := parameters.ParseParams(tool.GetParameters(), data, claimsFromAuth)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// If auth error, return 401
|
// If auth error, return 401
|
||||||
errMsg := fmt.Sprintf("error parsing authenticated parameters from ID token: %w", err)
|
if errors.Is(err, util.ErrUnauthorized) {
|
||||||
var clientServerErr *util.ClientServerError
|
s.logger.DebugContext(ctx, fmt.Sprintf("error parsing authenticated parameters from ID token: %s", err))
|
||||||
if errors.As(err, &clientServerErr) && clientServerErr.Code == http.StatusUnauthorized {
|
|
||||||
s.logger.DebugContext(ctx, errMsg)
|
|
||||||
_ = render.Render(w, r, newErrResponse(err, http.StatusUnauthorized))
|
_ = render.Render(w, r, newErrResponse(err, http.StatusUnauthorized))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -260,50 +259,35 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
// Determine what error to return to the users.
|
// Determine what error to return to the users.
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var tbErr util.ToolboxError
|
errStr := err.Error()
|
||||||
|
var statusCode int
|
||||||
|
|
||||||
if errors.As(err, &tbErr) {
|
// Upstream API auth error propagation
|
||||||
switch tbErr.Category() {
|
switch {
|
||||||
case util.CategoryAgent:
|
case strings.Contains(errStr, "Error 401"):
|
||||||
// Agent Errors -> 200 OK
|
statusCode = http.StatusUnauthorized
|
||||||
s.logger.DebugContext(ctx, fmt.Sprintf("Tool invocation agent error: %v", err))
|
case strings.Contains(errStr, "Error 403"):
|
||||||
_ = render.Render(w, r, newErrResponse(err, http.StatusOK))
|
statusCode = http.StatusForbidden
|
||||||
return
|
|
||||||
|
|
||||||
case util.CategoryServer:
|
|
||||||
// Server Errors -> Check the specific code inside
|
|
||||||
var clientServerErr *util.ClientServerError
|
|
||||||
statusCode := http.StatusInternalServerError // Default to 500
|
|
||||||
|
|
||||||
if errors.As(err, &clientServerErr) {
|
|
||||||
if clientServerErr.Code != 0 {
|
|
||||||
statusCode = clientServerErr.Code
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process auth error
|
|
||||||
if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden {
|
if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden {
|
||||||
if clientAuth {
|
if clientAuth {
|
||||||
// Token error, pass through 401/403
|
// Propagate the original 401/403 error.
|
||||||
s.logger.DebugContext(ctx, fmt.Sprintf("Client credentials lack authorization: %v", err))
|
s.logger.DebugContext(ctx, fmt.Sprintf("error invoking tool. Client credentials lack authorization to the source: %v", err))
|
||||||
_ = render.Render(w, r, newErrResponse(err, statusCode))
|
_ = render.Render(w, r, newErrResponse(err, statusCode))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// ADC/Config error, return 500
|
// ADC lacking permission or credentials configuration error.
|
||||||
statusCode = http.StatusInternalServerError
|
internalErr := fmt.Errorf("unexpected auth error occured during Tool invocation: %w", err)
|
||||||
}
|
s.logger.ErrorContext(ctx, internalErr.Error())
|
||||||
|
_ = render.Render(w, r, newErrResponse(internalErr, http.StatusInternalServerError))
|
||||||
s.logger.ErrorContext(ctx, fmt.Sprintf("Tool invocation server error: %v", err))
|
|
||||||
_ = render.Render(w, r, newErrResponse(err, statusCode))
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
err = fmt.Errorf("error while invoking tool: %w", err)
|
||||||
// Unknown error -> 500
|
s.logger.DebugContext(ctx, err.Error())
|
||||||
s.logger.ErrorContext(ctx, fmt.Sprintf("Tool invocation unknown error: %v", err))
|
_ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest))
|
||||||
_ = render.Render(w, r, newErrResponse(err, http.StatusInternalServerError))
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
resMarshal, err := json.Marshal(res)
|
resMarshal, err := json.Marshal(res)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -443,20 +444,18 @@ func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
|||||||
code := rpcResponse.Error.Code
|
code := rpcResponse.Error.Code
|
||||||
switch code {
|
switch code {
|
||||||
case jsonrpc.INTERNAL_ERROR:
|
case jsonrpc.INTERNAL_ERROR:
|
||||||
// Map Internal RPC Error (-32603) to HTTP 500
|
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
case jsonrpc.INVALID_REQUEST:
|
case jsonrpc.INVALID_REQUEST:
|
||||||
var clientServerErr *util.ClientServerError
|
errStr := err.Error()
|
||||||
if errors.As(err, &clientServerErr) {
|
if errors.Is(err, util.ErrUnauthorized) {
|
||||||
switch clientServerErr.Code {
|
|
||||||
case http.StatusUnauthorized:
|
|
||||||
w.WriteHeader(http.StatusUnauthorized)
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
case http.StatusForbidden:
|
} else if strings.Contains(errStr, "Error 401") {
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
} else if strings.Contains(errStr, "Error 403") {
|
||||||
w.WriteHeader(http.StatusForbidden)
|
w.WriteHeader(http.StatusForbidden)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// send HTTP response
|
// send HTTP response
|
||||||
render.JSON(w, r, res)
|
render.JSON(w, r, res)
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||||
@@ -123,11 +124,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
}
|
}
|
||||||
if clientAuth {
|
if clientAuth {
|
||||||
if accessToken == "" {
|
if accessToken == "" {
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.NewClientServerError(
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
|
||||||
"missing access token in the 'Authorization' header",
|
|
||||||
http.StatusUnauthorized,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -175,11 +172,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
// Check if any of the specified auth services is verified
|
// Check if any of the specified auth services is verified
|
||||||
isAuthorized := tool.Authorized(verifiedAuthServices)
|
isAuthorized := tool.Authorized(verifiedAuthServices)
|
||||||
if !isAuthorized {
|
if !isAuthorized {
|
||||||
err = util.NewClientServerError(
|
err = fmt.Errorf("unauthorized Tool call: Please make sure your specify correct auth headers: %w", util.ErrUnauthorized)
|
||||||
"unauthorized Tool call: Please make sure you specify correct auth headers",
|
|
||||||
http.StatusUnauthorized,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
}
|
}
|
||||||
logger.DebugContext(ctx, "tool invocation authorized")
|
logger.DebugContext(ctx, "tool invocation authorized")
|
||||||
@@ -201,13 +194,21 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
// run tool invocation and generate response.
|
// run tool invocation and generate response.
|
||||||
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var tbErr util.ToolboxError
|
errStr := err.Error()
|
||||||
|
// Missing authService tokens.
|
||||||
|
if errors.Is(err, util.ErrUnauthorized) {
|
||||||
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
|
}
|
||||||
|
// Upstream auth error
|
||||||
|
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
|
||||||
|
if clientAuth {
|
||||||
|
// Error with client credentials should pass down to the client
|
||||||
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
|
}
|
||||||
|
// Auth error with ADC should raise internal 500 error
|
||||||
|
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||||
|
}
|
||||||
|
|
||||||
if errors.As(err, &tbErr) {
|
|
||||||
switch tbErr.Category() {
|
|
||||||
case util.CategoryAgent:
|
|
||||||
// MCP - Tool execution error
|
|
||||||
// Return SUCCESS but with IsError: true
|
|
||||||
text := TextContent{
|
text := TextContent{
|
||||||
Type: "text",
|
Type: "text",
|
||||||
Text: err.Error(),
|
Text: err.Error(),
|
||||||
@@ -217,28 +218,6 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
Id: id,
|
Id: id,
|
||||||
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
||||||
}, nil
|
}, nil
|
||||||
|
|
||||||
case util.CategoryServer:
|
|
||||||
// MCP Spec - Protocol error
|
|
||||||
// Return JSON-RPC ERROR
|
|
||||||
var clientServerErr *util.ClientServerError
|
|
||||||
rpcCode := jsonrpc.INTERNAL_ERROR // Default to Internal Error (-32603)
|
|
||||||
|
|
||||||
if errors.As(err, &clientServerErr) {
|
|
||||||
if clientServerErr.Code == http.StatusUnauthorized || clientServerErr.Code == http.StatusForbidden {
|
|
||||||
if clientAuth {
|
|
||||||
rpcCode = jsonrpc.INVALID_REQUEST
|
|
||||||
} else {
|
|
||||||
rpcCode = jsonrpc.INTERNAL_ERROR
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return jsonrpc.NewError(id, rpcCode, err.Error(), nil), err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Unknown error -> 500
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
content := make([]TextContent, 0)
|
content := make([]TextContent, 0)
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||||
@@ -123,11 +124,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
}
|
}
|
||||||
if clientAuth {
|
if clientAuth {
|
||||||
if accessToken == "" {
|
if accessToken == "" {
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.NewClientServerError(
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
|
||||||
"missing access token in the 'Authorization' header",
|
|
||||||
http.StatusUnauthorized,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -175,11 +172,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
// Check if any of the specified auth services is verified
|
// Check if any of the specified auth services is verified
|
||||||
isAuthorized := tool.Authorized(verifiedAuthServices)
|
isAuthorized := tool.Authorized(verifiedAuthServices)
|
||||||
if !isAuthorized {
|
if !isAuthorized {
|
||||||
err = util.NewClientServerError(
|
err = fmt.Errorf("unauthorized Tool call: Please make sure your specify correct auth headers: %w", util.ErrUnauthorized)
|
||||||
"unauthorized Tool call: Please make sure you specify correct auth headers",
|
|
||||||
http.StatusUnauthorized,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
}
|
}
|
||||||
logger.DebugContext(ctx, "tool invocation authorized")
|
logger.DebugContext(ctx, "tool invocation authorized")
|
||||||
@@ -201,13 +194,20 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
// run tool invocation and generate response.
|
// run tool invocation and generate response.
|
||||||
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var tbErr util.ToolboxError
|
errStr := err.Error()
|
||||||
|
// Missing authService tokens.
|
||||||
if errors.As(err, &tbErr) {
|
if errors.Is(err, util.ErrUnauthorized) {
|
||||||
switch tbErr.Category() {
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
case util.CategoryAgent:
|
}
|
||||||
// MCP - Tool execution error
|
// Upstream auth error
|
||||||
// Return SUCCESS but with IsError: true
|
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
|
||||||
|
if clientAuth {
|
||||||
|
// Error with client credentials should pass down to the client
|
||||||
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
|
}
|
||||||
|
// Auth error with ADC should raise internal 500 error
|
||||||
|
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||||
|
}
|
||||||
text := TextContent{
|
text := TextContent{
|
||||||
Type: "text",
|
Type: "text",
|
||||||
Text: err.Error(),
|
Text: err.Error(),
|
||||||
@@ -217,29 +217,8 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
Id: id,
|
Id: id,
|
||||||
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
||||||
}, nil
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
case util.CategoryServer:
|
|
||||||
// MCP Spec - Protocol error
|
|
||||||
// Return JSON-RPC ERROR
|
|
||||||
var clientServerErr *util.ClientServerError
|
|
||||||
rpcCode := jsonrpc.INTERNAL_ERROR // Default to Internal Error (-32603)
|
|
||||||
|
|
||||||
if errors.As(err, &clientServerErr) {
|
|
||||||
if clientServerErr.Code == http.StatusUnauthorized || clientServerErr.Code == http.StatusForbidden {
|
|
||||||
if clientAuth {
|
|
||||||
rpcCode = jsonrpc.INVALID_REQUEST
|
|
||||||
} else {
|
|
||||||
rpcCode = jsonrpc.INTERNAL_ERROR
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return jsonrpc.NewError(id, rpcCode, err.Error(), nil), err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Unknown error -> 500
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
content := make([]TextContent, 0)
|
content := make([]TextContent, 0)
|
||||||
|
|
||||||
sliceRes, ok := results.([]any)
|
sliceRes, ok := results.([]any)
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||||
@@ -116,12 +117,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
}
|
}
|
||||||
if clientAuth {
|
if clientAuth {
|
||||||
if accessToken == "" {
|
if accessToken == "" {
|
||||||
errMsg := "missing access token in the 'Authorization' header"
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, errMsg, nil), util.NewClientServerError(
|
|
||||||
errMsg,
|
|
||||||
http.StatusUnauthorized,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -169,11 +165,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
// Check if any of the specified auth services is verified
|
// Check if any of the specified auth services is verified
|
||||||
isAuthorized := tool.Authorized(verifiedAuthServices)
|
isAuthorized := tool.Authorized(verifiedAuthServices)
|
||||||
if !isAuthorized {
|
if !isAuthorized {
|
||||||
err = util.NewClientServerError(
|
err = fmt.Errorf("unauthorized Tool call: Please make sure your specify correct auth headers: %w", util.ErrUnauthorized)
|
||||||
"unauthorized Tool call: Please make sure you specify correct auth headers",
|
|
||||||
http.StatusUnauthorized,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
}
|
}
|
||||||
logger.DebugContext(ctx, "tool invocation authorized")
|
logger.DebugContext(ctx, "tool invocation authorized")
|
||||||
@@ -195,13 +187,20 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
// run tool invocation and generate response.
|
// run tool invocation and generate response.
|
||||||
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var tbErr util.ToolboxError
|
errStr := err.Error()
|
||||||
|
// Missing authService tokens.
|
||||||
if errors.As(err, &tbErr) {
|
if errors.Is(err, util.ErrUnauthorized) {
|
||||||
switch tbErr.Category() {
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
case util.CategoryAgent:
|
}
|
||||||
// MCP - Tool execution error
|
// Upstream auth error
|
||||||
// Return SUCCESS but with IsError: true
|
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
|
||||||
|
if clientAuth {
|
||||||
|
// Error with client credentials should pass down to the client
|
||||||
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
|
}
|
||||||
|
// Auth error with ADC should raise internal 500 error
|
||||||
|
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||||
|
}
|
||||||
text := TextContent{
|
text := TextContent{
|
||||||
Type: "text",
|
Type: "text",
|
||||||
Text: err.Error(),
|
Text: err.Error(),
|
||||||
@@ -211,28 +210,6 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
Id: id,
|
Id: id,
|
||||||
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
||||||
}, nil
|
}, nil
|
||||||
|
|
||||||
case util.CategoryServer:
|
|
||||||
// MCP Spec - Protocol error
|
|
||||||
// Return JSON-RPC ERROR
|
|
||||||
var clientServerErr *util.ClientServerError
|
|
||||||
rpcCode := jsonrpc.INTERNAL_ERROR // Default to Internal Error (-32603)
|
|
||||||
|
|
||||||
if errors.As(err, &clientServerErr) {
|
|
||||||
if clientServerErr.Code == http.StatusUnauthorized || clientServerErr.Code == http.StatusForbidden {
|
|
||||||
if clientAuth {
|
|
||||||
rpcCode = jsonrpc.INVALID_REQUEST
|
|
||||||
} else {
|
|
||||||
rpcCode = jsonrpc.INTERNAL_ERROR
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return jsonrpc.NewError(id, rpcCode, err.Error(), nil), err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Unknown error -> 500
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
content := make([]TextContent, 0)
|
content := make([]TextContent, 0)
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||||
@@ -116,11 +117,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
}
|
}
|
||||||
if clientAuth {
|
if clientAuth {
|
||||||
if accessToken == "" {
|
if accessToken == "" {
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.NewClientServerError(
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
|
||||||
"missing access token in the 'Authorization' header",
|
|
||||||
http.StatusUnauthorized,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -168,11 +165,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
// Check if any of the specified auth services is verified
|
// Check if any of the specified auth services is verified
|
||||||
isAuthorized := tool.Authorized(verifiedAuthServices)
|
isAuthorized := tool.Authorized(verifiedAuthServices)
|
||||||
if !isAuthorized {
|
if !isAuthorized {
|
||||||
err = util.NewClientServerError(
|
err = fmt.Errorf("unauthorized Tool call: Please make sure your specify correct auth headers: %w", util.ErrUnauthorized)
|
||||||
"unauthorized Tool call: Please make sure you specify correct auth headers",
|
|
||||||
http.StatusUnauthorized,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
}
|
}
|
||||||
logger.DebugContext(ctx, "tool invocation authorized")
|
logger.DebugContext(ctx, "tool invocation authorized")
|
||||||
@@ -194,13 +187,20 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
// run tool invocation and generate response.
|
// run tool invocation and generate response.
|
||||||
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var tbErr util.ToolboxError
|
errStr := err.Error()
|
||||||
|
// Missing authService tokens.
|
||||||
if errors.As(err, &tbErr) {
|
if errors.Is(err, util.ErrUnauthorized) {
|
||||||
switch tbErr.Category() {
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
case util.CategoryAgent:
|
}
|
||||||
// MCP - Tool execution error
|
// Upstream auth error
|
||||||
// Return SUCCESS but with IsError: true
|
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
|
||||||
|
if clientAuth {
|
||||||
|
// Error with client credentials should pass down to the client
|
||||||
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
|
}
|
||||||
|
// Auth error with ADC should raise internal 500 error
|
||||||
|
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||||
|
}
|
||||||
text := TextContent{
|
text := TextContent{
|
||||||
Type: "text",
|
Type: "text",
|
||||||
Text: err.Error(),
|
Text: err.Error(),
|
||||||
@@ -210,28 +210,6 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
Id: id,
|
Id: id,
|
||||||
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
||||||
}, nil
|
}, nil
|
||||||
|
|
||||||
case util.CategoryServer:
|
|
||||||
// MCP Spec - Protocol error
|
|
||||||
// Return JSON-RPC ERROR
|
|
||||||
var clientServerErr *util.ClientServerError
|
|
||||||
rpcCode := jsonrpc.INTERNAL_ERROR // Default to Internal Error (-32603)
|
|
||||||
|
|
||||||
if errors.As(err, &clientServerErr) {
|
|
||||||
if clientServerErr.Code == http.StatusUnauthorized || clientServerErr.Code == http.StatusForbidden {
|
|
||||||
if clientAuth {
|
|
||||||
rpcCode = jsonrpc.INVALID_REQUEST
|
|
||||||
} else {
|
|
||||||
rpcCode = jsonrpc.INTERNAL_ERROR
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return jsonrpc.NewError(id, rpcCode, err.Error(), nil), err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Unknown error -> 500
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
content := make([]TextContent, 0)
|
content := make([]TextContent, 0)
|
||||||
|
|||||||
@@ -184,7 +184,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
if source.UseClientAuthorization() {
|
if source.UseClientAuthorization() {
|
||||||
// Use client-side access token
|
// Use client-side access token
|
||||||
if accessToken == "" {
|
if accessToken == "" {
|
||||||
return nil, util.NewClientServerError("tool is configured for client OAuth but no token was provided in the request header", http.StatusUnauthorized, nil)
|
return nil, fmt.Errorf("tool is configured for client OAuth but no token was provided in the request header: %w", util.ErrUnauthorized)
|
||||||
}
|
}
|
||||||
tokenStr, err = accessToken.ParseBearerToken()
|
tokenStr, err = accessToken.ParseBearerToken()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ package tools
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -81,7 +80,7 @@ type AccessToken string
|
|||||||
func (token AccessToken) ParseBearerToken() (string, error) {
|
func (token AccessToken) ParseBearerToken() (string, error) {
|
||||||
headerParts := strings.Split(string(token), " ")
|
headerParts := strings.Split(string(token), " ")
|
||||||
if len(headerParts) != 2 || strings.ToLower(headerParts[0]) != "bearer" {
|
if len(headerParts) != 2 || strings.ToLower(headerParts[0]) != "bearer" {
|
||||||
return "", util.NewClientServerError("authorization header must be in the format 'Bearer <token>'", http.StatusUnauthorized, nil)
|
return "", fmt.Errorf("authorization header must be in the format 'Bearer <token>': %w", util.ErrUnauthorized)
|
||||||
}
|
}
|
||||||
return headerParts[1], nil
|
return headerParts[1], nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,61 +0,0 @@
|
|||||||
// Copyright 2026 Google LLC
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
package util
|
|
||||||
|
|
||||||
import "fmt"
|
|
||||||
|
|
||||||
type ErrorCategory string
|
|
||||||
|
|
||||||
const (
|
|
||||||
CategoryAgent ErrorCategory = "AGENT_ERROR"
|
|
||||||
CategoryServer ErrorCategory = "SERVER_ERROR"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ToolboxError is the interface all custom errors must satisfy
|
|
||||||
type ToolboxError interface {
|
|
||||||
error
|
|
||||||
Category() ErrorCategory
|
|
||||||
}
|
|
||||||
|
|
||||||
// Agent Errors return 200 to the sender
|
|
||||||
type AgentError struct {
|
|
||||||
Msg string
|
|
||||||
Cause error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *AgentError) Error() string { return e.Msg }
|
|
||||||
|
|
||||||
func (e *AgentError) Category() ErrorCategory { return CategoryAgent }
|
|
||||||
|
|
||||||
func (e *AgentError) Unwrap() error { return e.Cause }
|
|
||||||
|
|
||||||
func NewAgentError(msg string, cause error) *AgentError {
|
|
||||||
return &AgentError{Msg: msg, Cause: cause}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClientServerError returns 4XX/5XX error code
|
|
||||||
type ClientServerError struct {
|
|
||||||
Msg string
|
|
||||||
Code int
|
|
||||||
Cause error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *ClientServerError) Error() string { return fmt.Sprintf("%s: %v", e.Msg, e.Cause) }
|
|
||||||
|
|
||||||
func (e *ClientServerError) Category() ErrorCategory { return CategoryServer }
|
|
||||||
|
|
||||||
func (e *ClientServerError) Unwrap() error { return e.Cause }
|
|
||||||
|
|
||||||
func NewClientServerError(msg string, code int, cause error) *ClientServerError {
|
|
||||||
return &ClientServerError{Msg: msg, Code: code, Cause: cause}
|
|
||||||
}
|
|
||||||
@@ -19,7 +19,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
"slices"
|
"slices"
|
||||||
@@ -119,7 +118,7 @@ func parseFromAuthService(paramAuthServices []ParamAuthService, claimsMap map[st
|
|||||||
}
|
}
|
||||||
return v, nil
|
return v, nil
|
||||||
}
|
}
|
||||||
return nil, util.NewClientServerError("missing or invalid authentication header", http.StatusUnauthorized, nil)
|
return nil, fmt.Errorf("missing or invalid authentication header: %w", util.ErrUnauthorized)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CheckParamRequired checks if a parameter is required based on the required and default field.
|
// CheckParamRequired checks if a parameter is required based on the required and default field.
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -187,3 +188,5 @@ func InstrumentationFromContext(ctx context.Context) (*telemetry.Instrumentation
|
|||||||
}
|
}
|
||||||
return nil, fmt.Errorf("unable to retrieve instrumentation")
|
return nil, fmt.Errorf("unable to retrieve instrumentation")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var ErrUnauthorized = errors.New("unauthorized")
|
||||||
|
|||||||
@@ -614,6 +614,8 @@ func GetMySQLWants() (string, string, string, string) {
|
|||||||
// SetupPostgresSQLTable creates and inserts data into a table of tool
|
// SetupPostgresSQLTable creates and inserts data into a table of tool
|
||||||
// compatible with postgres-sql tool
|
// compatible with postgres-sql tool
|
||||||
func SetupPostgresSQLTable(t *testing.T, ctx context.Context, pool *pgxpool.Pool, createStatement, insertStatement, tableName string, params []any) func(*testing.T) {
|
func SetupPostgresSQLTable(t *testing.T, ctx context.Context, pool *pgxpool.Pool, createStatement, insertStatement, tableName string, params []any) func(*testing.T) {
|
||||||
|
|
||||||
|
|
||||||
err := pool.Ping(ctx)
|
err := pool.Ping(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to connect to test database: %s", err)
|
t.Fatalf("unable to connect to test database: %s", err)
|
||||||
@@ -621,9 +623,10 @@ func SetupPostgresSQLTable(t *testing.T, ctx context.Context, pool *pgxpool.Pool
|
|||||||
|
|
||||||
// Create table
|
// Create table
|
||||||
_, err = pool.Query(ctx, createStatement)
|
_, err = pool.Query(ctx, createStatement)
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
|
// t.Fatalf("unable to create test table %s: %s", tableName, err)
|
||||||
|
// }
|
||||||
t.Fatalf("unable to create test table %s: %s", tableName, err)
|
t.Fatalf("unable to create test table %s: %s", tableName, err)
|
||||||
}
|
|
||||||
|
|
||||||
// Insert test data
|
// Insert test data
|
||||||
_, err = pool.Query(ctx, insertStatement, params...)
|
_, err = pool.Query(ctx, insertStatement, params...)
|
||||||
|
|||||||
@@ -94,7 +94,10 @@ func TestPostgres(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// cleanup test environment
|
// cleanup test environment
|
||||||
tests.CleanupPostgresTables(t, ctx, pool)
|
// tests.CleanupPostgresTables(t, ctx, pool)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
tests.CleanupPostgresTables(t, context.Background(), pool)
|
||||||
|
})
|
||||||
|
|
||||||
// create table name with UUID
|
// create table name with UUID
|
||||||
tableNameParam := "param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
tableNameParam := "param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||||
|
|||||||
Reference in New Issue
Block a user