refactor: Pass Authorization header token to Tool call functions (#1200)

Pass in authorization token to the Tool invocation functions.
Support: https://github.com/googleapis/genai-toolbox/pull/1067
This commit is contained in:
Wenxin Du
2025-08-21 18:20:42 -04:00
committed by GitHub
parent 8ea6a98bd9
commit bffe7b0661
74 changed files with 109 additions and 101 deletions

View File

@@ -217,7 +217,11 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
}
s.logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
res, err := tool.Invoke(ctx, params)
// Extract OAuth access token from the "Authorization" header (currently for
// BigQuery end-user credentials usage only)
accessToken := tools.AccessToken(r.Header.Get("Authorization"))
res, err := tool.Invoke(ctx, params, accessToken)
if err != nil {
err = fmt.Errorf("error while invoking tool: %w", err)
s.logger.DebugContext(ctx, err.Error())

View File

@@ -42,7 +42,7 @@ type MockTool struct {
manifest tools.Manifest
}
func (t MockTool) Invoke(context.Context, tools.ParamValues) (any, error) {
func (t MockTool) Invoke(context.Context, tools.ParamValues, tools.AccessToken) (any, error) {
mock := []any{t.Name}
return mock, nil
}

View File

@@ -34,6 +34,7 @@ import (
mcputil "github.com/googleapis/genai-toolbox/internal/server/mcp/util"
v20241105 "github.com/googleapis/genai-toolbox/internal/server/mcp/v20241105"
v20250326 "github.com/googleapis/genai-toolbox/internal/server/mcp/v20250326"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
@@ -141,7 +142,7 @@ func (s *stdioSession) readInputStream(ctx context.Context) error {
}
return err
}
v, res, err := processMcpMessage(ctx, []byte(line), s.server, s.protocol, "")
v, res, err := processMcpMessage(ctx, []byte(line), s.server, s.protocol, "", "")
if err != nil {
// errors during the processing of message will generate a valid MCP Error response.
// server can continue to run.
@@ -401,7 +402,9 @@ func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
return
}
v, res, err := processMcpMessage(ctx, body, s, protocolVersion, toolsetName)
accessToken := tools.AccessToken(r.Header.Get("Authorization"))
v, res, err := processMcpMessage(ctx, body, s, protocolVersion, toolsetName, accessToken)
// notifications will return empty string
if res == nil {
// Notifications do not expect a response
@@ -437,7 +440,7 @@ func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
}
// processMcpMessage process the messages received from clients
func processMcpMessage(ctx context.Context, body []byte, s *Server, protocolVersion string, toolsetName string) (string, any, error) {
func processMcpMessage(ctx context.Context, body []byte, s *Server, protocolVersion string, toolsetName string, accessToken tools.AccessToken) (string, any, error) {
logger, err := util.LoggerFromContext(ctx)
if err != nil {
return "", jsonrpc.NewError("", jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
@@ -492,7 +495,7 @@ func processMcpMessage(ctx context.Context, body []byte, s *Server, protocolVers
err = fmt.Errorf("toolset does not exist")
return "", jsonrpc.NewError(baseMessage.Id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}
res, err := mcp.ProcessMethod(ctx, protocolVersion, baseMessage.Id, baseMessage.Method, toolset, s.ResourceMgr.GetToolsMap(), body)
res, err := mcp.ProcessMethod(ctx, protocolVersion, baseMessage.Id, baseMessage.Method, toolset, s.ResourceMgr.GetToolsMap(), body, accessToken)
return "", res, err
}
}

View File

@@ -93,14 +93,14 @@ func NotificationHandler(ctx context.Context, body []byte) error {
// ProcessMethod returns a response for the request.
// This is the Operation phase of the lifecycle for MCP client-server connections.
func ProcessMethod(ctx context.Context, mcpVersion string, id jsonrpc.RequestId, method string, toolset tools.Toolset, tools map[string]tools.Tool, body []byte) (any, error) {
func ProcessMethod(ctx context.Context, mcpVersion string, id jsonrpc.RequestId, method string, toolset tools.Toolset, tools map[string]tools.Tool, body []byte, accessToken tools.AccessToken) (any, error) {
switch mcpVersion {
case v20250618.PROTOCOL_VERSION:
return v20250618.ProcessMethod(ctx, id, method, toolset, tools, body)
return v20250618.ProcessMethod(ctx, id, method, toolset, tools, body, accessToken)
case v20250326.PROTOCOL_VERSION:
return v20250326.ProcessMethod(ctx, id, method, toolset, tools, body)
return v20250326.ProcessMethod(ctx, id, method, toolset, tools, body, accessToken)
default:
return v20241105.ProcessMethod(ctx, id, method, toolset, tools, body)
return v20241105.ProcessMethod(ctx, id, method, toolset, tools, body, accessToken)
}
}

View File

@@ -26,14 +26,14 @@ import (
)
// ProcessMethod returns a response for the request.
func ProcessMethod(ctx context.Context, id jsonrpc.RequestId, method string, toolset tools.Toolset, tools map[string]tools.Tool, body []byte) (any, error) {
func ProcessMethod(ctx context.Context, id jsonrpc.RequestId, method string, toolset tools.Toolset, tools map[string]tools.Tool, body []byte, accessToken tools.AccessToken) (any, error) {
switch method {
case PING:
return pingHandler(id)
case TOOLS_LIST:
return toolsListHandler(id, toolset, body)
case TOOLS_CALL:
return toolsCallHandler(ctx, id, tools, body)
return toolsCallHandler(ctx, id, tools, body, accessToken)
default:
err := fmt.Errorf("invalid method %s", method)
return jsonrpc.NewError(id, jsonrpc.METHOD_NOT_FOUND, err.Error(), nil), err
@@ -67,7 +67,7 @@ func toolsListHandler(id jsonrpc.RequestId, toolset tools.Toolset, body []byte)
}
// toolsCallHandler generate a response for tools call.
func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, tools map[string]tools.Tool, body []byte) (any, error) {
func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, tools map[string]tools.Tool, body []byte, accessToken tools.AccessToken) (any, error) {
// retrieve logger from context
logger, err := util.LoggerFromContext(ctx)
if err != nil {
@@ -119,7 +119,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, tools map[strin
}
// run tool invocation and generate response.
results, err := tool.Invoke(ctx, params)
results, err := tool.Invoke(ctx, params, accessToken)
if err != nil {
text := TextContent{
Type: "text",

View File

@@ -26,14 +26,14 @@ import (
)
// ProcessMethod returns a response for the request.
func ProcessMethod(ctx context.Context, id jsonrpc.RequestId, method string, toolset tools.Toolset, tools map[string]tools.Tool, body []byte) (any, error) {
func ProcessMethod(ctx context.Context, id jsonrpc.RequestId, method string, toolset tools.Toolset, tools map[string]tools.Tool, body []byte, accessToken tools.AccessToken) (any, error) {
switch method {
case PING:
return pingHandler(id)
case TOOLS_LIST:
return toolsListHandler(id, toolset, body)
case TOOLS_CALL:
return toolsCallHandler(ctx, id, tools, body)
return toolsCallHandler(ctx, id, tools, body, accessToken)
default:
err := fmt.Errorf("invalid method %s", method)
return jsonrpc.NewError(id, jsonrpc.METHOD_NOT_FOUND, err.Error(), nil), err
@@ -67,7 +67,7 @@ func toolsListHandler(id jsonrpc.RequestId, toolset tools.Toolset, body []byte)
}
// toolsCallHandler generate a response for tools call.
func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, tools map[string]tools.Tool, body []byte) (any, error) {
func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, tools map[string]tools.Tool, body []byte, accessToken tools.AccessToken) (any, error) {
// retrieve logger from context
logger, err := util.LoggerFromContext(ctx)
if err != nil {
@@ -119,7 +119,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, tools map[strin
}
// run tool invocation and generate response.
results, err := tool.Invoke(ctx, params)
results, err := tool.Invoke(ctx, params, accessToken)
if err != nil {
text := TextContent{
Type: "text",

View File

@@ -26,14 +26,14 @@ import (
)
// ProcessMethod returns a response for the request.
func ProcessMethod(ctx context.Context, id jsonrpc.RequestId, method string, toolset tools.Toolset, tools map[string]tools.Tool, body []byte) (any, error) {
func ProcessMethod(ctx context.Context, id jsonrpc.RequestId, method string, toolset tools.Toolset, tools map[string]tools.Tool, body []byte, accessToken tools.AccessToken) (any, error) {
switch method {
case PING:
return pingHandler(id)
case TOOLS_LIST:
return toolsListHandler(id, toolset, body)
case TOOLS_CALL:
return toolsCallHandler(ctx, id, tools, body)
return toolsCallHandler(ctx, id, tools, body, accessToken)
default:
err := fmt.Errorf("invalid method %s", method)
return jsonrpc.NewError(id, jsonrpc.METHOD_NOT_FOUND, err.Error(), nil), err
@@ -67,7 +67,7 @@ func toolsListHandler(id jsonrpc.RequestId, toolset tools.Toolset, body []byte)
}
// toolsCallHandler generate a response for tools call.
func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, tools map[string]tools.Tool, body []byte) (any, error) {
func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, tools map[string]tools.Tool, body []byte, accessToken tools.AccessToken) (any, error) {
// retrieve logger from context
logger, err := util.LoggerFromContext(ctx)
if err != nil {
@@ -119,7 +119,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, tools map[strin
}
// run tool invocation and generate response.
results, err := tool.Invoke(ctx, params)
results, err := tool.Invoke(ctx, params, accessToken)
if err != nil {
text := TextContent{
Type: "text",

View File

@@ -156,7 +156,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
sliceParams := params.AsSlice()
allParamValues := make([]any, len(sliceParams)+1)
allParamValues[0] = fmt.Sprintf("%s", sliceParams[0]) // nl_question

View File

@@ -126,7 +126,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
paramsMap := params.AsMap()
sql, ok := paramsMap["sql"].(string)
if !ok {

View File

@@ -130,7 +130,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
paramsMap := params.AsMap()
historyData, ok := paramsMap["history_data"].(string)
if !ok {

View File

@@ -118,7 +118,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
mapParams := params.AsMap()
projectId, ok := mapParams[projectKey].(string)
if !ok {

View File

@@ -120,7 +120,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
mapParams := params.AsMap()
projectId, ok := mapParams[projectKey].(string)
if !ok {

View File

@@ -118,7 +118,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
mapParams := params.AsMap()
projectId, ok := mapParams[projectKey].(string)
if !ok {

View File

@@ -119,7 +119,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
mapParams := params.AsMap()
projectId, ok := mapParams[projectKey].(string)
if !ok {

View File

@@ -131,7 +131,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
highLevelParams := make([]bigqueryapi.QueryParameter, 0, len(t.Parameters))
lowLevelParams := make([]*bigqueryrestapi.QueryParameter, 0, len(t.Parameters))

View File

@@ -164,7 +164,7 @@ func getMapParamsType(tparams tools.Parameters, params tools.ParamValues) (map[s
return btParamTypes, nil
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
paramsMap := params.AsMap()
newStatement, err := tools.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
if err != nil {

View File

@@ -128,7 +128,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
namedParamsMap := params.AsMap()
newStatement, err := tools.ResolveTemplateParams(t.TemplateParameters, t.Statement, namedParamsMap)
if err != nil {

View File

@@ -132,11 +132,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
paramsMap := params.AsMap()
viewMap := map[int]dataplexpb.EntryView{
1: dataplexpb.EntryView_BASIC,
@@ -181,3 +177,7 @@ func (t Tool) McpManifest() tools.McpManifest {
// Returns the tool MCP manifest
return t.mcpManifest
}
func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}

View File

@@ -119,11 +119,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
// Invoke the tool with the provided parameters
paramsMap := params.AsMap()
query, _ := paramsMap["query"].(string)
@@ -194,3 +190,7 @@ func (t Tool) McpManifest() tools.McpManifest {
// Returns the tool MCP manifest
return t.mcpManifest
}
func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}

View File

@@ -118,11 +118,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
paramsMap := params.AsMap()
query, _ := paramsMap["query"].(string)
pageSize := int32(paramsMap["pageSize"].(int))
@@ -166,3 +162,6 @@ func (t Tool) McpManifest() tools.McpManifest {
// Returns the tool MCP manifest
return t.mcpManifest
}
func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}

View File

@@ -120,7 +120,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
paramsMap := params.AsMapWithDollarPrefix()
resp, err := t.DgraphClient.ExecuteQuery(t.Statement, paramsMap, t.IsQuery, t.Timeout)

View File

@@ -150,7 +150,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
mapParams := params.AsMap()
// Get collection path

View File

@@ -115,7 +115,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
mapParams := params.AsMap()
documentPathsRaw, ok := mapParams[documentPathsKey].([]any)
if !ok {

View File

@@ -115,7 +115,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
mapParams := params.AsMap()
documentPathsRaw, ok := mapParams[documentPathsKey].([]any)
if !ok {

View File

@@ -117,7 +117,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
// Get the latest release for Firestore
releaseName := fmt.Sprintf("projects/%s/releases/cloud.firestore", t.ProjectId)
release, err := t.RulesClient.Projects.Releases.Get(releaseName).Context(ctx).Do()

View File

@@ -116,7 +116,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
mapParams := params.AsMap()
var collectionRefs []*firestoreapi.CollectionRef

View File

@@ -267,7 +267,7 @@ type QueryResponse struct {
}
// Invoke executes the Firestore query based on the provided parameters
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
// Parse parameters
queryParams, err := t.parseQueryParameters(params)
if err != nil {

View File

@@ -157,7 +157,7 @@ type ValidationResult struct {
RawIssues []Issue `json:"rawIssues,omitempty"`
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
mapParams := params.AsMap()
// Get source parameter

View File

@@ -242,7 +242,7 @@ func getHeaders(headerParams tools.Parameters, defaultHeaders map[string]string,
return allHeaders, nil
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
paramsMap := params.AsMap()
// Calculate request body

View File

@@ -127,7 +127,7 @@ var (
visType string = "vis"
)
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
logger, err := util.LoggerFromContext(ctx)
if err != nil {
return nil, fmt.Errorf("unable to get logger from ctx: %s", err)

View File

@@ -119,7 +119,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
logger, err := util.LoggerFromContext(ctx)
if err != nil {
return nil, fmt.Errorf("unable to get logger from ctx: %s", err)

View File

@@ -111,7 +111,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
logger, err := util.LoggerFromContext(ctx)
if err != nil {
return nil, fmt.Errorf("unable to get logger from ctx: %s", err)

View File

@@ -111,7 +111,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
logger, err := util.LoggerFromContext(ctx)
if err != nil {
return nil, fmt.Errorf("unable to get logger from ctx: %s", err)

View File

@@ -111,7 +111,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
logger, err := util.LoggerFromContext(ctx)
if err != nil {
return nil, fmt.Errorf("unable to get logger from ctx: %s", err)

View File

@@ -119,7 +119,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
logger, err := util.LoggerFromContext(ctx)
if err != nil {
return nil, fmt.Errorf("unable to get logger from ctx: %s", err)

View File

@@ -111,7 +111,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
logger, err := util.LoggerFromContext(ctx)
if err != nil {
return nil, fmt.Errorf("unable to get logger from ctx: %s", err)

View File

@@ -110,7 +110,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
logger, err := util.LoggerFromContext(ctx)
if err != nil {
return nil, fmt.Errorf("unable to get logger from ctx: %s", err)

View File

@@ -111,7 +111,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
logger, err := util.LoggerFromContext(ctx)
if err != nil {
return nil, fmt.Errorf("unable to get logger from ctx: %s", err)

View File

@@ -117,7 +117,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
logger, err := util.LoggerFromContext(ctx)
if err != nil {
return nil, fmt.Errorf("unable to get logger from ctx: %s", err)

View File

@@ -124,7 +124,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
logger, err := util.LoggerFromContext(ctx)
if err != nil {
return nil, fmt.Errorf("unable to get logger from ctx: %s", err)

View File

@@ -112,7 +112,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
logger, err := util.LoggerFromContext(ctx)
if err != nil {
return nil, fmt.Errorf("unable to get logger from ctx: %s", err)

View File

@@ -111,7 +111,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
logger, err := util.LoggerFromContext(ctx)
if err != nil {
return nil, fmt.Errorf("unable to get logger from ctx: %s", err)

View File

@@ -118,7 +118,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
logger, err := util.LoggerFromContext(ctx)
if err != nil {
return nil, fmt.Errorf("unable to get logger from ctx: %s", err)

View File

@@ -117,7 +117,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
logger, err := util.LoggerFromContext(ctx)
if err != nil {
return nil, fmt.Errorf("unable to get logger from ctx: %s", err)

View File

@@ -132,7 +132,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
paramsMap := params.AsMap()
pipelineString, err := tools.PopulateTemplateWithJSON("MongoDBAggregatePipeline", t.PipelinePayload, paramsMap)

View File

@@ -133,7 +133,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
paramsMap := params.AsMap()
filterString, err := tools.PopulateTemplateWithJSON("MongoDBDeleteManyFilter", t.FilterPayload, paramsMap)

View File

@@ -132,7 +132,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
paramsMap := params.AsMap()
filterString, err := tools.PopulateTemplateWithJSON("MongoDBDeleteOneFilter", t.FilterPayload, paramsMap)

View File

@@ -181,7 +181,7 @@ func getOptions(sortParameters tools.Parameters, projectPayload string, limit in
return opts, nil
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
paramsMap := params.AsMap()
filterString, err := tools.PopulateTemplateWithJSON("MongoDBFindFilterString", t.FilterPayload, paramsMap)

View File

@@ -173,7 +173,7 @@ func getOptions(sortParameters tools.Parameters, projectPayload string, paramsMa
return opts, nil
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
paramsMap := params.AsMap()
filterString, err := tools.PopulateTemplateWithJSON("MongoDBFindOneFilterString", t.FilterPayload, paramsMap)

View File

@@ -124,7 +124,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
if len(params) == 0 {
return nil, errors.New("no input found")
}

View File

@@ -125,7 +125,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
if len(params) == 0 {
return nil, errors.New("no input found")
}

View File

@@ -143,7 +143,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
paramsMap := params.AsMap()
filterString, err := tools.PopulateTemplateWithJSON("MongoDBUpdateManyFilter", t.FilterPayload, paramsMap)

View File

@@ -144,7 +144,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
paramsMap := params.AsMap()
filterString, err := tools.PopulateTemplateWithJSON("MongoDBUpdateOneFilter", t.FilterPayload, paramsMap)

View File

@@ -117,7 +117,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
paramsMap := params.AsMap()
sql, ok := paramsMap["sql"].(string)
if !ok {

View File

@@ -128,7 +128,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
paramsMap := params.AsMap()
newStatement, err := tools.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
if err != nil {

View File

@@ -117,7 +117,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
paramsMap := params.AsMap()
sql, ok := paramsMap["sql"].(string)
if !ok {

View File

@@ -128,7 +128,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
paramsMap := params.AsMap()
newStatement, err := tools.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
if err != nil {

View File

@@ -119,7 +119,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
paramsMap := params.AsMap()
config := neo4j.ExecuteQueryWithDatabase(t.Database)

View File

@@ -123,7 +123,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
paramsMap := params.AsMap()
cypherStr, ok := paramsMap["cypher"].(string)
if !ok {

View File

@@ -143,7 +143,7 @@ type Tool struct {
// Invoke executes the tool's main logic: fetching the Neo4j schema.
// It first checks the cache for a valid schema before extracting it from the database.
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
// Check if a valid schema is already in the cache.
if cachedSchema, ok := t.cache.Get("schema"); ok {
if schema, ok := cachedSchema.(*types.SchemaInfo); ok {

View File

@@ -115,7 +115,7 @@ type Tool struct {
}
// Invoke executes the SQL statement provided in the parameters.
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
sliceParams := params.AsSlice()
sqlStr, ok := sliceParams[0].(string)
if !ok {

View File

@@ -126,7 +126,7 @@ type Tool struct {
}
// Invoke executes the SQL statement with the provided parameters.
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
paramsMap := params.AsMap()
newStatement, err := tools.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
if err != nil {

View File

@@ -119,7 +119,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
paramsMap := params.AsMap()
sql, ok := paramsMap["sql"].(string)
if !ok {

View File

@@ -129,7 +129,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
paramsMap := params.AsMap()
newStatement, err := tools.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
if err != nil {

View File

@@ -115,7 +115,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
cmds, err := replaceCommandsParams(t.Commands, t.Parameters, params)
if err != nil {
return nil, fmt.Errorf("error replacing commands' parameters: %s", err)

View File

@@ -145,7 +145,7 @@ func processRows(iter *spanner.RowIterator) ([]any, error) {
return out, nil
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
paramsMap := params.AsMap()
sql, ok := paramsMap["sql"].(string)
if !ok {

View File

@@ -167,7 +167,7 @@ func processRows(iter *spanner.RowIterator) ([]any, error) {
return out, nil
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
paramsMap := params.AsMap()
newStatement, err := tools.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
if err != nil {

View File

@@ -125,7 +125,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
paramsMap := params.AsMap()
newStatement, err := tools.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
if err != nil {

View File

@@ -115,7 +115,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
paramsMap := params.AsMap()
sql, ok := paramsMap["sql"].(string)
if !ok {

View File

@@ -126,7 +126,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
paramsMap := params.AsMap()
newStatement, err := tools.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
if err != nil {

View File

@@ -63,8 +63,10 @@ type ToolConfig interface {
Initialize(map[string]sources.Source) (Tool, error)
}
type AccessToken string
type Tool interface {
Invoke(context.Context, ParamValues) (any, error)
Invoke(context.Context, ParamValues, AccessToken) (any, error)
ParseParams(map[string]any, map[string]map[string]any) (ParamValues, error)
Manifest() Manifest
McpManifest() McpManifest

View File

@@ -205,7 +205,7 @@ type Tool struct {
}
// Invoke executes the tool's logic.
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
paramsMap := params.AsMap()
project, ok := paramsMap["project"].(string)

View File

@@ -85,7 +85,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
paramsMap := params.AsMap()
durationStr, ok := paramsMap["duration"].(string)

View File

@@ -114,7 +114,7 @@ type Tool struct {
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
// Replace parameters
commands, err := replaceCommandsParams(t.Commands, t.Parameters, params)
if err != nil {