refactor: decouple Source from Tool (#2204)

This PR update the linking mechanism between Source and Tool.

Tools are directly linked to their Source, either by pointing to the
Source's functions or by assigning values from the source during Tool's
initialization. However, the existing approach means that any
modification to the Source after Tool's initialization might not be
reflected. To address this limitation, each tool should only store a
name reference to the Source, rather than direct link or assigned
values.

Tools will provide interface for `compatibleSource`. This will be used
to determine if a Source is compatible with the Tool.
```
type compatibleSource interface{
    Client() http.Client
    ProjectID() string
}
```

During `Invoke()`, the tool will run the following operations:
* retrieve Source from the `resourceManager` with source's named defined
in Tool's config
* validate Source via `compatibleSource interface{}`
* run the remaining `Invoke()` function. Fields that are needed is
retrieved directly from the source.

With this update, resource manager is also added as input to other
Tool's function that require access to source (e.g.
`RequiresClientAuthorization()`).
This commit is contained in:
Yuan Teoh
2025-12-19 21:27:55 -08:00
committed by GitHub
parent 7daa4111f4
commit 967a72da11
203 changed files with 4242 additions and 6391 deletions

View File

@@ -172,7 +172,14 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
accessToken := tools.AccessToken(r.Header.Get("Authorization"))
// Check if this specific tool requires the standard authorization header
if tool.RequiresClientAuthorization(s.ResourceMgr) {
clientAuth, err := tool.RequiresClientAuthorization(s.ResourceMgr)
if err != nil {
errMsg := fmt.Errorf("error during invocation: %w", err)
s.logger.DebugContext(ctx, errMsg.Error())
_ = render.Render(w, r, newErrResponse(errMsg, http.StatusNotFound))
return
}
if clientAuth {
if accessToken == "" {
err = fmt.Errorf("tool requires client authorization but access token is missing from the request header")
s.logger.DebugContext(ctx, err.Error())
@@ -255,7 +262,7 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
}
if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden {
if tool.RequiresClientAuthorization(s.ResourceMgr) {
if clientAuth {
// Propagate the original 401/403 error.
s.logger.DebugContext(ctx, fmt.Sprintf("error invoking tool. Client credentials lack authorization to the source: %v", err))
_ = render.Render(w, r, newErrResponse(err, statusCode))

View File

@@ -77,9 +77,9 @@ func (t MockTool) Authorized(verifiedAuthServices []string) bool {
return !t.unauthorized
}
func (t MockTool) RequiresClientAuthorization(tools.SourceProvider) bool {
func (t MockTool) RequiresClientAuthorization(tools.SourceProvider) (bool, error) {
// defaulted to false
return t.requiresClientAuthrorization
return t.requiresClientAuthrorization, nil
}
func (t MockTool) McpManifest() tools.McpManifest {
@@ -119,8 +119,8 @@ func (t MockTool) McpManifest() tools.McpManifest {
return mcpManifest
}
func (t MockTool) GetAuthTokenHeaderName() string {
return "Authorization"
func (t MockTool) GetAuthTokenHeaderName(tools.SourceProvider) (string, error) {
return "Authorization", nil
}
// MockPrompt is used to mock prompts in tests

View File

@@ -108,10 +108,20 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
}
// Get access token
accessToken := tools.AccessToken(header.Get(tool.GetAuthTokenHeaderName()))
authTokenHeadername, err := tool.GetAuthTokenHeaderName(resourceMgr)
if err != nil {
errMsg := fmt.Errorf("error during invocation: %w", err)
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
}
accessToken := tools.AccessToken(header.Get(authTokenHeadername))
// Check if this specific tool requires the standard authorization header
if tool.RequiresClientAuthorization(resourceMgr) {
clientAuth, err := tool.RequiresClientAuthorization(resourceMgr)
if err != nil {
errMsg := fmt.Errorf("error during invocation: %w", err)
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
}
if clientAuth {
if accessToken == "" {
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
}
@@ -183,7 +193,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
}
// Upstream auth error
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
if tool.RequiresClientAuthorization(resourceMgr) {
if clientAuth {
// Error with client credentials should pass down to the client
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}

View File

@@ -108,10 +108,20 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
}
// Get access token
accessToken := tools.AccessToken(header.Get(tool.GetAuthTokenHeaderName()))
authTokenHeadername, err := tool.GetAuthTokenHeaderName(resourceMgr)
if err != nil {
errMsg := fmt.Errorf("error during invocation: %w", err)
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
}
accessToken := tools.AccessToken(header.Get(authTokenHeadername))
// Check if this specific tool requires the standard authorization header
if tool.RequiresClientAuthorization(resourceMgr) {
clientAuth, err := tool.RequiresClientAuthorization(resourceMgr)
if err != nil {
errMsg := fmt.Errorf("error during invocation: %w", err)
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
}
if clientAuth {
if accessToken == "" {
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
}
@@ -183,7 +193,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
}
// Upstream auth error
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
if tool.RequiresClientAuthorization(resourceMgr) {
if clientAuth {
// Error with client credentials should pass down to the client
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}

View File

@@ -101,10 +101,20 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
}
// Get access token
accessToken := tools.AccessToken(header.Get(tool.GetAuthTokenHeaderName()))
authTokenHeadername, err := tool.GetAuthTokenHeaderName(resourceMgr)
if err != nil {
errMsg := fmt.Errorf("error during invocation: %w", err)
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
}
accessToken := tools.AccessToken(header.Get(authTokenHeadername))
// Check if this specific tool requires the standard authorization header
if tool.RequiresClientAuthorization(resourceMgr) {
clientAuth, err := tool.RequiresClientAuthorization(resourceMgr)
if err != nil {
errMsg := fmt.Errorf("error during invocation: %w", err)
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
}
if clientAuth {
if accessToken == "" {
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
}
@@ -176,7 +186,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
}
// Upstream auth error
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
if tool.RequiresClientAuthorization(resourceMgr) {
if clientAuth {
// Error with client credentials should pass down to the client
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}