Files
genai-toolbox/internal/tools/tools.go
Dr. Strangelove 18017d6545 feat: support alternate accessToken header name (#1968)
## Description

This commit allows a tool to read an alternate authorization
token from a different header of the HTTP request.

This is initially being built for the Looker integration. Looker
uses its own OAuth token. When deploying MCP Toolbox to Cloud
Run, the default token in the "Authorization" header is for
authentication with Cloud Run. An alternate token can be put into
another header by a client such as ADK or any other client that
can programmatically set HTTP headers. This token will be used
to authenticate with Looker.

If needed, other sources can use this by setting the header name
in the source config, passing it into the tool config, and returning
the header name from the tool's GetAuthTokenHeaderName() method.

## PR Checklist

> Thank you for opening a Pull Request! Before submitting your PR, there
> are a few things you can do to make sure it goes smoothly:

- [x] Make sure you reviewed
  [CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md)
- [x] Make sure to open an issue as a
  [bug/issue](https://github.com/googleapis/genai-toolbox/issues/new/choose)
  before writing your code! That way we can discuss the change, evaluate
  designs, and agree on the general idea
- [x] Ensure the tests and linter pass
- [x] Code coverage does not decrease (if any source code was changed)
- [x] Appropriate docs were updated (if necessary)
- [x] Make sure to add `!` if this involves a breaking change

🛠️ Fixes #1540
2025-11-19 23:00:13 +00:00

143 lines
4.7 KiB
Go

// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tools
import (
"context"
"fmt"
"slices"
"strings"
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/util"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
)
// ToolConfigFactory defines the signature for a function that creates and
// decodes a specific tool's configuration. It takes the context, the tool's
// name, and a YAML decoder to parse the config, and returns the decoded
// ToolConfig or an error. Factories are stored per kind by Register and
// looked up by DecodeConfig.
type ToolConfigFactory func(ctx context.Context, name string, decoder *yaml.Decoder) (ToolConfig, error)
// toolRegistry maps a tool kind to the factory that decodes configs of that
// kind. Entries are added by Register and read by DecodeConfig.
var toolRegistry = map[string]ToolConfigFactory{}
// Register records factory as the config decoder for tools of the given kind.
// Tool packages typically call it from an init function.
//
// The first registration for a kind wins: when kind is still free the factory
// is stored and Register reports true; when kind is already taken the existing
// factory is left untouched and Register reports false.
//
// NOTE(review): toolRegistry is a plain map with no visible locking, so
// Register is only safe to call during single-goroutine setup (e.g. init).
func Register(kind string, factory ToolConfigFactory) bool {
	if _, taken := toolRegistry[kind]; !taken {
		toolRegistry[kind] = factory
		return true
	}
	return false
}
// DecodeConfig resolves the factory registered for kind and delegates decoding
// of the tool called name to it, reading the config from decoder.
//
// It returns an error when no factory has been registered for kind, or when
// the factory fails; the factory's error is wrapped so callers can unwrap it.
func DecodeConfig(ctx context.Context, kind string, name string, decoder *yaml.Decoder) (ToolConfig, error) {
	newConfig, ok := toolRegistry[kind]
	if !ok {
		return nil, fmt.Errorf("unknown tool kind: %q", kind)
	}

	cfg, err := newConfig(ctx, name, decoder)
	if err != nil {
		return nil, fmt.Errorf("unable to parse tool %q as kind %q: %w", name, kind, err)
	}
	return cfg, nil
}
// ToolConfig is a tool's decoded configuration, as produced by a
// ToolConfigFactory via DecodeConfig, before it is turned into a Tool.
type ToolConfig interface {
	// Initialize builds the runnable Tool from this config using the given
	// sources (presumably keyed by source name — confirm against callers).
	Initialize(srcs map[string]sources.Source) (Tool, error)

	// ToolConfigKind reports the kind of this config; presumably the same
	// kind string it was registered under with Register — confirm.
	ToolConfigKind() string
}
// AccessToken holds the raw value of the HTTP header a tool's access token is
// read from, e.g. "Bearer <token>".
type AccessToken string

// ParseBearerToken extracts the credential from an AccessToken of the form
// "Bearer <token>".
//
// The scheme name is matched case-insensitively, and (per RFC 6750 section
// 2.1, which allows 1*SP) it may be separated from the token by one or more
// spaces. It returns an error wrapping util.ErrUnauthorized when the scheme is
// not "Bearer", the token is missing or empty, or the token itself contains a
// space.
func (token AccessToken) ParseBearerToken() (string, error) {
	scheme, credentials, found := strings.Cut(string(token), " ")
	// Splitting on a single space used to reject "Bearer  abc" and accept an
	// empty credential for "Bearer "; drop extra separator spaces and require
	// a non-empty, space-free token instead.
	credentials = strings.TrimLeft(credentials, " ")
	if !found || !strings.EqualFold(scheme, "bearer") || credentials == "" || strings.Contains(credentials, " ") {
		return "", fmt.Errorf("authorization header must be in the format 'Bearer <token>': %w", util.ErrUnauthorized)
	}
	return credentials, nil
}
// Tool is the runtime interface implemented by every initialized tool; it is
// what ToolConfig.Initialize returns.
type Tool interface {
	// Invocation. ParseParams presumably turns raw request data into the
	// ParamValues later passed to Invoke — confirm against the server code.
	ParseParams(map[string]any, map[string]map[string]any) (parameters.ParamValues, error)
	Invoke(ctx context.Context, params parameters.ParamValues, accessToken AccessToken) (any, error)

	// Descriptions advertised to clients (Client SDKs and MCP clients).
	Manifest() Manifest
	McpManifest() McpManifest

	// Authorization. GetAuthTokenHeaderName names the HTTP request header the
	// access token is read from, which lets a tool use an alternate header
	// instead of the default "Authorization" header.
	Authorized([]string) bool
	RequiresClientAuthorization() bool
	GetAuthTokenHeaderName() string

	// ToConfig returns the ToolConfig this tool corresponds to.
	ToConfig() ToolConfig
}
// Manifest is the representation of tools sent to Client SDKs.
// encoding/json emits fields in declaration order, so reordering the fields
// changes the serialized key order.
type Manifest struct {
	// Description is the human-readable description of the tool.
	Description string `json:"description"`
	// Parameters holds one manifest entry per tool parameter.
	Parameters []parameters.ParameterManifest `json:"parameters"`
	// AuthRequired lists auth service names required to invoke the tool;
	// presumably checked with IsAuthorized (any one match suffices) — confirm.
	AuthRequired []string `json:"authRequired"`
}
// McpManifest is the definition of a tool that an MCP client can call.
type McpManifest struct {
	// The name of the tool.
	Name string `json:"name"`
	// A human-readable description of the tool.
	Description string `json:"description,omitempty"`
	// A JSON Schema object defining the expected parameters for the tool.
	// NOTE(review): if parameters.McpToolsSchema is a struct type,
	// encoding/json ignores omitempty for it and always emits the field — confirm.
	InputSchema parameters.McpToolsSchema `json:"inputSchema,omitempty"`
	// Metadata carries toolbox-specific extensions under the "_meta" key;
	// GetMcpManifest fills in "toolbox/authInvoke" and "toolbox/authParam"
	// and leaves it nil (omitted) when neither applies.
	Metadata map[string]any `json:"_meta,omitempty"`
}
// GetMcpManifest assembles the MCP manifest for the tool called name.
//
// The input schema and any auth-parameter information come from
// params.McpManifest. Toolbox-specific auth details are published under
// "_meta": "toolbox/authInvoke" when authInvoke is non-empty and
// "toolbox/authParam" when params reports auth parameters. When neither
// applies, Metadata stays nil so the "_meta" key is omitted from JSON.
func GetMcpManifest(name, desc string, authInvoke []string, params parameters.Parameters) McpManifest {
	schema, authParamInfo := params.McpManifest()

	meta := map[string]any{}
	if len(authInvoke) != 0 {
		meta["toolbox/authInvoke"] = authInvoke
	}
	if len(authParamInfo) != 0 {
		meta["toolbox/authParam"] = authParamInfo
	}

	manifest := McpManifest{Name: name, Description: desc, InputSchema: schema}
	if len(meta) != 0 {
		manifest.Metadata = meta
	}
	return manifest
}
// IsAuthorized reports whether a tool invocation request is authorized.
//
// A tool with no required auth sources is always authorized. Otherwise the
// request is authorized as soon as any one of authRequiredSources appears in
// verifiedAuthServices.
func IsAuthorized(authRequiredSources []string, verifiedAuthServices []string) bool {
	if len(authRequiredSources) == 0 {
		return true
	}
	return slices.ContainsFunc(authRequiredSources, func(required string) bool {
		return slices.Contains(verifiedAuthServices, required)
	})
}