Compare commits

..

4 Commits

Author SHA1 Message Date
duwenxin99
83657a9d7a add getParameters 2026-01-15 18:09:45 -05:00
Yuan Teoh
d00b6fdf18 chore: update host validation error to 403 (#2306)
Update error code from 400 to 403 according to MCP
[updates](https://github.com/modelcontextprotocol/modelcontextprotocol/pull/1439)
for invalid origin header.

Also updated hostCheck to only check host, not port.

To test, run Toolbox with the following (also work with port number e.g.
`--allowed-host=127.0.0.1:5000`):
```
go run . --allowed-hosts=127.0.0.1 
```

Test with the following:
```
// curl successfully
curl -H "Host: 127.0.0.1:5000" http://127.0.0.1:5000

// curl successfully
curl -H "Host: 127.0.0.1:3000" http://127.0.0.1:5000

// will show Invalid Host Header error
curl -H "Host: attacker:5000" http://127.0.0.1:5000
```
2026-01-15 21:09:40 +00:00
Yuan Teoh
4d23a3bbf2 feat: add new v20251125 version (#2303)
Add new `v20251125` specs for MCP.
https://modelcontextprotocol.io/specification/2025-11-25
2026-01-15 20:14:11 +00:00
Yuan Teoh
5e0999ebf5 feat: add remaining toolbox server flag (#2272)
Add remaining CLI flags for the server published on official mcp
registry.

ref: https://googleapis.github.io/genai-toolbox/reference/cli/

_note: mcp registry do not support shorthand flag (there are no options
to defined an alternate name). The only way is to define them as
separate named arguments but it may not work well since both would try
to set the same underlying value._
2026-01-15 19:30:40 +00:00
103 changed files with 1417 additions and 220 deletions

View File

@@ -20,6 +20,7 @@ The native SDKs can be combined with MCP clients in many cases.
Toolbox currently supports the following versions of MCP specification:
* [2025-11-25](https://modelcontextprotocol.io/specification/2025-11-25)
* [2025-06-18](https://modelcontextprotocol.io/specification/2025-06-18)
* [2025-03-26](https://modelcontextprotocol.io/specification/2025-03-26)
* [2024-11-05](https://modelcontextprotocol.io/specification/2024-11-05)

2
go.mod
View File

@@ -12,7 +12,7 @@ require (
cloud.google.com/go/dataplex v1.28.0
cloud.google.com/go/dataproc/v2 v2.15.0
cloud.google.com/go/firestore v1.20.0
cloud.google.com/go/geminidataanalytics v0.5.0
cloud.google.com/go/geminidataanalytics v0.3.0
cloud.google.com/go/longrunning v0.7.0
cloud.google.com/go/spanner v1.86.1
github.com/ClickHouse/clickhouse-go/v2 v2.40.3

4
go.sum
View File

@@ -311,8 +311,8 @@ cloud.google.com/go/gaming v1.6.0/go.mod h1:YMU1GEvA39Qt3zWGyAVA9bpYz/yAhTvaQ1t2
cloud.google.com/go/gaming v1.7.0/go.mod h1:LrB8U7MHdGgFG851iHAfqUdLcKBdQ55hzXy9xBJz0+w=
cloud.google.com/go/gaming v1.8.0/go.mod h1:xAqjS8b7jAVW0KFYeRUxngo9My3f33kFmua++Pi+ggM=
cloud.google.com/go/gaming v1.9.0/go.mod h1:Fc7kEmCObylSWLO334NcO+O9QMDyz+TKC4v1D7X+Bc0=
cloud.google.com/go/geminidataanalytics v0.5.0 h1:+1usY81Cb+hE8BokpqCM7EgJtRCKzUKx7FvrHbT5hCA=
cloud.google.com/go/geminidataanalytics v0.5.0/go.mod h1:QRc0b6ywyc3Z7S3etFgslz7hippkW/jRvtops5rKqIg=
cloud.google.com/go/geminidataanalytics v0.3.0 h1:2Wi/kqFb5OLuEGH7q+/miE19VTqK1MYHjBEHENap9HI=
cloud.google.com/go/geminidataanalytics v0.3.0/go.mod h1:QRc0b6ywyc3Z7S3etFgslz7hippkW/jRvtops5rKqIg=
cloud.google.com/go/gkebackup v0.2.0/go.mod h1:XKvv/4LfG829/B8B7xRkk8zRrOEbKtEam6yNfuQNH60=
cloud.google.com/go/gkebackup v0.3.0/go.mod h1:n/E671i1aOQvUxT541aTkCwExO/bTer2HDlj4TsBRAo=
cloud.google.com/go/gkebackup v0.4.0/go.mod h1:byAyBGUwYGEEww7xsbnUTBHIYcOPy/PgUWUtOeRm9Vg=

View File

@@ -27,19 +27,21 @@ import (
v20241105 "github.com/googleapis/genai-toolbox/internal/server/mcp/v20241105"
v20250326 "github.com/googleapis/genai-toolbox/internal/server/mcp/v20250326"
v20250618 "github.com/googleapis/genai-toolbox/internal/server/mcp/v20250618"
v20251125 "github.com/googleapis/genai-toolbox/internal/server/mcp/v20251125"
"github.com/googleapis/genai-toolbox/internal/server/resources"
"github.com/googleapis/genai-toolbox/internal/tools"
)
// LATEST_PROTOCOL_VERSION is the latest version of the MCP protocol supported.
// Update the version used in InitializeResponse when this value is updated.
const LATEST_PROTOCOL_VERSION = v20250618.PROTOCOL_VERSION
const LATEST_PROTOCOL_VERSION = v20251125.PROTOCOL_VERSION
// SUPPORTED_PROTOCOL_VERSIONS is the MCP protocol versions that are supported.
var SUPPORTED_PROTOCOL_VERSIONS = []string{
v20241105.PROTOCOL_VERSION,
v20250326.PROTOCOL_VERSION,
v20250618.PROTOCOL_VERSION,
v20251125.PROTOCOL_VERSION,
}
// InitializeResponse runs capability negotiation and protocol version agreement.
@@ -102,6 +104,8 @@ func NotificationHandler(ctx context.Context, body []byte) error {
// This is the Operation phase of the lifecycle for MCP client-server connections.
func ProcessMethod(ctx context.Context, mcpVersion string, id jsonrpc.RequestId, method string, toolset tools.Toolset, promptset prompts.Promptset, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) {
switch mcpVersion {
case v20251125.PROTOCOL_VERSION:
return v20251125.ProcessMethod(ctx, id, method, toolset, promptset, resourceMgr, body, header)
case v20250618.PROTOCOL_VERSION:
return v20250618.ProcessMethod(ctx, id, method, toolset, promptset, resourceMgr, body, header)
case v20250326.PROTOCOL_VERSION:

View File

@@ -0,0 +1,326 @@
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package v20251125
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"github.com/googleapis/genai-toolbox/internal/prompts"
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
"github.com/googleapis/genai-toolbox/internal/server/resources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util"
)
// ProcessMethod returns a response for the request.
func ProcessMethod(ctx context.Context, id jsonrpc.RequestId, method string, toolset tools.Toolset, promptset prompts.Promptset, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) {
switch method {
case PING:
return pingHandler(id)
case TOOLS_LIST:
return toolsListHandler(id, toolset, body)
case TOOLS_CALL:
return toolsCallHandler(ctx, id, resourceMgr, body, header)
case PROMPTS_LIST:
return promptsListHandler(ctx, id, promptset, body)
case PROMPTS_GET:
return promptsGetHandler(ctx, id, resourceMgr, body)
default:
err := fmt.Errorf("invalid method %s", method)
return jsonrpc.NewError(id, jsonrpc.METHOD_NOT_FOUND, err.Error(), nil), err
}
}
// pingHandler handles the "ping" method by returning an empty response.
func pingHandler(id jsonrpc.RequestId) (any, error) {
return jsonrpc.JSONRPCResponse{
Jsonrpc: jsonrpc.JSONRPC_VERSION,
Id: id,
Result: struct{}{},
}, nil
}
func toolsListHandler(id jsonrpc.RequestId, toolset tools.Toolset, body []byte) (any, error) {
var req ListToolsRequest
if err := json.Unmarshal(body, &req); err != nil {
err = fmt.Errorf("invalid mcp tools list request: %w", err)
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}
result := ListToolsResult{
Tools: toolset.McpManifest,
}
return jsonrpc.JSONRPCResponse{
Jsonrpc: jsonrpc.JSONRPC_VERSION,
Id: id,
Result: result,
}, nil
}
// toolsCallHandler generate a response for tools call.
func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) {
authServices := resourceMgr.GetAuthServiceMap()
// retrieve logger from context
logger, err := util.LoggerFromContext(ctx)
if err != nil {
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
}
var req CallToolRequest
if err = json.Unmarshal(body, &req); err != nil {
err = fmt.Errorf("invalid mcp tools call request: %w", err)
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}
toolName := req.Params.Name
toolArgument := req.Params.Arguments
logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName))
tool, ok := resourceMgr.GetTool(toolName)
if !ok {
err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
}
// Get access token
authTokenHeadername, err := tool.GetAuthTokenHeaderName(resourceMgr)
if err != nil {
errMsg := fmt.Errorf("error during invocation: %w", err)
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
}
accessToken := tools.AccessToken(header.Get(authTokenHeadername))
// Check if this specific tool requires the standard authorization header
clientAuth, err := tool.RequiresClientAuthorization(resourceMgr)
if err != nil {
errMsg := fmt.Errorf("error during invocation: %w", err)
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
}
if clientAuth {
if accessToken == "" {
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
}
}
// marshal arguments and decode it using decodeJSON instead to prevent loss between floats/int.
aMarshal, err := json.Marshal(toolArgument)
if err != nil {
err = fmt.Errorf("unable to marshal tools argument: %w", err)
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
}
var data map[string]any
if err = util.DecodeJSON(bytes.NewBuffer(aMarshal), &data); err != nil {
err = fmt.Errorf("unable to decode tools argument: %w", err)
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
}
// Tool authentication
// claimsFromAuth maps the name of the authservice to the claims retrieved from it.
claimsFromAuth := make(map[string]map[string]any)
// if using stdio, header will be nil and auth will not be supported
if header != nil {
for _, aS := range authServices {
claims, err := aS.GetClaimsFromHeader(ctx, header)
if err != nil {
logger.DebugContext(ctx, err.Error())
continue
}
if claims == nil {
// authService not present in header
continue
}
claimsFromAuth[aS.GetName()] = claims
}
}
// Tool authorization check
verifiedAuthServices := make([]string, len(claimsFromAuth))
i := 0
for k := range claimsFromAuth {
verifiedAuthServices[i] = k
i++
}
// Check if any of the specified auth services is verified
isAuthorized := tool.Authorized(verifiedAuthServices)
if !isAuthorized {
err = fmt.Errorf("unauthorized Tool call: Please make sure your specify correct auth headers: %w", util.ErrUnauthorized)
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}
logger.DebugContext(ctx, "tool invocation authorized")
params, err := tool.ParseParams(data, claimsFromAuth)
if err != nil {
err = fmt.Errorf("provided parameters were invalid: %w", err)
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
}
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
// run tool invocation and generate response.
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
if err != nil {
errStr := err.Error()
// Missing authService tokens.
if errors.Is(err, util.ErrUnauthorized) {
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}
// Upstream auth error
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
if clientAuth {
// Error with client credentials should pass down to the client
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}
// Auth error with ADC should raise internal 500 error
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
}
text := TextContent{
Type: "text",
Text: err.Error(),
}
return jsonrpc.JSONRPCResponse{
Jsonrpc: jsonrpc.JSONRPC_VERSION,
Id: id,
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
}, nil
}
content := make([]TextContent, 0)
sliceRes, ok := results.([]any)
if !ok {
sliceRes = []any{results}
}
for _, d := range sliceRes {
text := TextContent{Type: "text"}
dM, err := json.Marshal(d)
if err != nil {
text.Text = fmt.Sprintf("fail to marshal: %s, result: %s", err, d)
} else {
text.Text = string(dM)
}
content = append(content, text)
}
return jsonrpc.JSONRPCResponse{
Jsonrpc: jsonrpc.JSONRPC_VERSION,
Id: id,
Result: CallToolResult{Content: content},
}, nil
}
// promptsListHandler handles the "prompts/list" method.
func promptsListHandler(ctx context.Context, id jsonrpc.RequestId, promptset prompts.Promptset, body []byte) (any, error) {
// retrieve logger from context
logger, err := util.LoggerFromContext(ctx)
if err != nil {
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
}
logger.DebugContext(ctx, "handling prompts/list request")
var req ListPromptsRequest
if err := json.Unmarshal(body, &req); err != nil {
err = fmt.Errorf("invalid mcp prompts list request: %w", err)
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}
result := ListPromptsResult{
Prompts: promptset.McpManifest,
}
logger.DebugContext(ctx, fmt.Sprintf("returning %d prompts", len(promptset.McpManifest)))
return jsonrpc.JSONRPCResponse{
Jsonrpc: jsonrpc.JSONRPC_VERSION,
Id: id,
Result: result,
}, nil
}
// promptsGetHandler handles the "prompts/get" method.
func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *resources.ResourceManager, body []byte) (any, error) {
// retrieve logger from context
logger, err := util.LoggerFromContext(ctx)
if err != nil {
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
}
logger.DebugContext(ctx, "handling prompts/get request")
var req GetPromptRequest
if err := json.Unmarshal(body, &req); err != nil {
err = fmt.Errorf("invalid mcp prompts/get request: %w", err)
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}
promptName := req.Params.Name
logger.DebugContext(ctx, fmt.Sprintf("prompt name: %s", promptName))
prompt, ok := resourceMgr.GetPrompt(promptName)
if !ok {
err := fmt.Errorf("prompt with name %q does not exist", promptName)
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
}
// Parse the arguments provided in the request.
argValues, err := prompt.ParseArgs(req.Params.Arguments, nil)
if err != nil {
err = fmt.Errorf("invalid arguments for prompt %q: %w", promptName, err)
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
}
logger.DebugContext(ctx, fmt.Sprintf("parsed args: %v", argValues))
// Substitute the argument values into the prompt's messages.
substituted, err := prompt.SubstituteParams(argValues)
if err != nil {
err = fmt.Errorf("error substituting params for prompt %q: %w", promptName, err)
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
}
// Cast the result to the expected []prompts.Message type.
substitutedMessages, ok := substituted.([]prompts.Message)
if !ok {
err = fmt.Errorf("internal error: SubstituteParams returned unexpected type")
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
}
logger.DebugContext(ctx, "substituted params successfully")
// Format the response messages into the required structure.
promptMessages := make([]PromptMessage, len(substitutedMessages))
for i, msg := range substitutedMessages {
promptMessages[i] = PromptMessage{
Role: msg.Role,
Content: TextContent{
Type: "text",
Text: msg.Content,
},
}
}
result := GetPromptResult{
Description: prompt.Manifest().Description,
Messages: promptMessages,
}
return jsonrpc.JSONRPCResponse{
Jsonrpc: jsonrpc.JSONRPC_VERSION,
Id: id,
Result: result,
}, nil
}

View File

@@ -0,0 +1,219 @@
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package v20251125
import (
"github.com/googleapis/genai-toolbox/internal/prompts"
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
"github.com/googleapis/genai-toolbox/internal/tools"
)
// SERVER_NAME is the server name used in Implementation.
const SERVER_NAME = "Toolbox"
// PROTOCOL_VERSION is the version of the MCP protocol in this package.
const PROTOCOL_VERSION = "2025-11-25"
// methods that are supported.
const (
PING = "ping"
TOOLS_LIST = "tools/list"
TOOLS_CALL = "tools/call"
PROMPTS_LIST = "prompts/list"
PROMPTS_GET = "prompts/get"
)
/* Empty result */
// EmptyResult represents a response that indicates success but carries no data.
type EmptyResult jsonrpc.Result
/* Pagination */
// Cursor is an opaque token used to represent a cursor for pagination.
type Cursor string
type PaginatedRequest struct {
jsonrpc.Request
Params struct {
// An opaque token representing the current pagination position.
// If provided, the server should return results starting after this cursor.
Cursor Cursor `json:"cursor,omitempty"`
} `json:"params,omitempty"`
}
type PaginatedResult struct {
jsonrpc.Result
// An opaque token representing the pagination position after the last returned result.
// If present, there may be more results available.
NextCursor Cursor `json:"nextCursor,omitempty"`
}
/* Tools */
// Sent from the client to request a list of tools the server has.
type ListToolsRequest struct {
PaginatedRequest
}
// The server's response to a tools/list request from the client.
type ListToolsResult struct {
PaginatedResult
Tools []tools.McpManifest `json:"tools"`
}
// Used by the client to invoke a tool provided by the server.
type CallToolRequest struct {
jsonrpc.Request
Params struct {
Name string `json:"name"`
Arguments map[string]any `json:"arguments,omitempty"`
} `json:"params,omitempty"`
}
// The sender or recipient of messages and data in a conversation.
type Role string
const (
RoleUser Role = "user"
RoleAssistant Role = "assistant"
)
// Base for objects that include optional annotations for the client.
// The client can use annotations to inform how objects are used or displayed
type Annotated struct {
Annotations *struct {
// Describes who the intended customer of this object or data is.
// It can include multiple entries to indicate content useful for multiple
// audiences (e.g., `["user", "assistant"]`).
Audience []Role `json:"audience,omitempty"`
// Describes how important this data is for operating the server.
//
// A value of 1 means "most important," and indicates that the data is
// effectively required, while 0 means "least important," and indicates that
// the data is entirely optional.
//
// @TJS-type number
// @minimum 0
// @maximum 1
Priority float64 `json:"priority,omitempty"`
} `json:"annotations,omitempty"`
}
// TextContent represents text provided to or from an LLM.
type TextContent struct {
Annotated
Type string `json:"type"`
// The text content of the message.
Text string `json:"text"`
}
// The server's response to a tool call.
//
// Any errors that originate from the tool SHOULD be reported inside the result
// object, with `isError` set to true, _not_ as an MCP protocol-level error
// response. Otherwise, the LLM would not be able to see that an error occurred
// and self-correct.
//
// However, any errors in _finding_ the tool, an error indicating that the
// server does not support tool calls, or any other exceptional conditions,
// should be reported as an MCP error response.
type CallToolResult struct {
jsonrpc.Result
// Could be either a TextContent, ImageContent, or EmbeddedResources
// For Toolbox, we will only be sending TextContent
Content []TextContent `json:"content"`
// Whether the tool call ended in an error.
// If not set, this is assumed to be false (the call was successful).
//
// Any errors that originate from the tool SHOULD be reported inside the result
// object, with `isError` set to true, _not_ as an MCP protocol-level error
// response. Otherwise, the LLM would not be able to see that an error occurred
// and self-correct.
//
// However, any errors in _finding_ the tool, an error indicating that the
// server does not support tool calls, or any other exceptional conditions,
// should be reported as an MCP error response.
IsError bool `json:"isError,omitempty"`
// An optional JSON object that represents the structured result of the tool call.
StructuredContent map[string]any `json:"structuredContent,omitempty"`
}
// Additional properties describing a Tool to clients.
//
// NOTE: all properties in ToolAnnotations are **hints**.
// They are not guaranteed to provide a faithful description of
// tool behavior (including descriptive properties like `title`).
//
// Clients should never make tool use decisions based on ToolAnnotations
// received from untrusted servers.
type ToolAnnotations struct {
// A human-readable title for the tool.
Title string `json:"title,omitempty"`
// If true, the tool does not modify its environment.
// Default: false
ReadOnlyHint bool `json:"readOnlyHint,omitempty"`
// If true, the tool may perform destructive updates to its environment.
// If false, the tool performs only additive updates.
// (This property is meaningful only when `readOnlyHint == false`)
// Default: true
DestructiveHint bool `json:"destructiveHint,omitempty"`
// If true, calling the tool repeatedly with the same arguments
// will have no additional effect on the its environment.
// (This property is meaningful only when `readOnlyHint == false`)
// Default: false
IdempotentHint bool `json:"idempotentHint,omitempty"`
// If true, this tool may interact with an "open world" of external
// entities. If false, the tool's domain of interaction is closed.
// For example, the world of a web search tool is open, whereas that
// of a memory tool is not.
// Default: true
OpenWorldHint bool `json:"openWorldHint,omitempty"`
}
/* Prompts */
// Sent from the client to request a list of prompts the server has.
type ListPromptsRequest struct {
PaginatedRequest
}
// The server's response to a prompts/list request from the client.
type ListPromptsResult struct {
PaginatedResult
Prompts []prompts.McpManifest `json:"prompts"`
}
// Used by the client to get a prompt provided by the server.
type GetPromptRequest struct {
jsonrpc.Request
Params struct {
Name string `json:"name"`
Arguments map[string]any `json:"arguments,omitempty"`
} `json:"params"`
}
// The server's response to a prompts/get request from the client.
type GetPromptResult struct {
jsonrpc.Result
Description string `json:"description,omitempty"`
Messages []PromptMessage `json:"messages"`
}
// Describes a message returned as part of a prompt.
type PromptMessage struct {
Role string `json:"role"`
Content TextContent `json:"content"`
}

View File

@@ -37,6 +37,7 @@ const jsonrpcVersion = "2.0"
const protocolVersion20241105 = "2024-11-05"
const protocolVersion20250326 = "2025-03-26"
const protocolVersion20250618 = "2025-06-18"
const protocolVersion20251125 = "2025-11-25"
const serverName = "Toolbox"
var basicInputSchema = map[string]any{
@@ -485,6 +486,23 @@ func TestMcpEndpoint(t *testing.T) {
},
},
},
{
name: "version 2025-11-25",
protocol: protocolVersion20251125,
idHeader: false,
initWant: map[string]any{
"jsonrpc": "2.0",
"id": "mcp-initialize",
"result": map[string]any{
"protocolVersion": "2025-11-25",
"capabilities": map[string]any{
"tools": map[string]any{"listChanged": false},
"prompts": map[string]any{"listChanged": false},
},
"serverInfo": map[string]any{"name": serverName, "version": fakeVersionString},
},
},
},
}
for _, vtc := range versTestCases {
t.Run(vtc.name, func(t *testing.T) {
@@ -494,8 +512,7 @@ func TestMcpEndpoint(t *testing.T) {
if sessionId != "" {
header["Mcp-Session-Id"] = sessionId
}
if vtc.protocol == protocolVersion20250618 {
if vtc.protocol != protocolVersion20241105 && vtc.protocol != protocolVersion20250326 {
header["MCP-Protocol-Version"] = vtc.protocol
}

View File

@@ -304,10 +304,14 @@ func hostCheck(allowedHosts map[string]struct{}) func(http.Handler) http.Handler
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, hasWildcard := allowedHosts["*"]
_, hostIsAllowed := allowedHosts[r.Host]
hostname := r.Host
if host, _, err := net.SplitHostPort(r.Host); err == nil {
hostname = host
}
_, hostIsAllowed := allowedHosts[hostname]
if !hasWildcard && !hostIsAllowed {
// Return 400 Bad Request or 403 Forbidden to block the attack
http.Error(w, "Invalid Host header", http.StatusBadRequest)
// Return 403 Forbidden to block the attack
http.Error(w, "Invalid Host header", http.StatusForbidden)
return
}
next.ServeHTTP(w, r)
@@ -406,7 +410,11 @@ func NewServer(ctx context.Context, cfg ServerConfig) (*Server, error) {
}
allowedHostsMap := make(map[string]struct{}, len(cfg.AllowedHosts))
for _, h := range cfg.AllowedHosts {
allowedHostsMap[h] = struct{}{}
hostname := h
if host, _, err := net.SplitHostPort(h); err == nil {
hostname = host
}
allowedHostsMap[hostname] = struct{}{}
}
r.Use(hostCheck(allowedHostsMap))

View File

@@ -14,20 +14,23 @@
package cloudgda
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
geminidataanalytics "cloud.google.com/go/geminidataanalytics/apiv1beta"
"cloud.google.com/go/geminidataanalytics/apiv1beta/geminidataanalyticspb"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/util"
"go.opentelemetry.io/otel/trace"
"golang.org/x/oauth2"
"google.golang.org/api/option"
"golang.org/x/oauth2/google"
)
const SourceKind string = "cloud-gemini-data-analytics"
const Endpoint string = "https://geminidataanalytics.googleapis.com"
// validate interface
var _ sources.SourceConfig = Config{}
@@ -64,19 +67,29 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
return nil, fmt.Errorf("error in User Agent retrieval: %s", err)
}
var client *http.Client
if r.UseClientOAuth {
client = &http.Client{
Transport: util.NewUserAgentRoundTripper(ua, http.DefaultTransport),
}
} else {
// Use Application Default Credentials
// Scope: "https://www.googleapis.com/auth/cloud-platform" is generally sufficient for GDA
creds, err := google.FindDefaultCredentials(ctx, "https://www.googleapis.com/auth/cloud-platform")
if err != nil {
return nil, fmt.Errorf("failed to find default credentials: %w", err)
}
baseClient := oauth2.NewClient(ctx, creds.TokenSource)
baseClient.Transport = util.NewUserAgentRoundTripper(ua, baseClient.Transport)
client = baseClient
}
s := &Source{
Config: r,
Client: client,
BaseURL: Endpoint,
userAgent: ua,
}
if !r.UseClientOAuth {
client, err := geminidataanalytics.NewDataChatClient(ctx, option.WithUserAgent(ua))
if err != nil {
return nil, fmt.Errorf("failed to create DataChatClient: %w", err)
}
s.Client = client
}
return s, nil
}
@@ -84,7 +97,8 @@ var _ sources.Source = &Source{}
type Source struct {
Config
Client *geminidataanalytics.DataChatClient
Client *http.Client
BaseURL string
userAgent string
}
@@ -100,34 +114,63 @@ func (s *Source) GetProjectID() string {
return s.ProjectID
}
func (s *Source) GetBaseURL() string {
return s.BaseURL
}
func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Client, error) {
if s.UseClientOAuth {
if accessToken == "" {
return nil, fmt.Errorf("client-side OAuth is enabled but no access token was provided")
}
token := &oauth2.Token{AccessToken: accessToken}
baseClient := oauth2.NewClient(ctx, oauth2.StaticTokenSource(token))
baseClient.Transport = util.NewUserAgentRoundTripper(s.userAgent, baseClient.Transport)
return baseClient, nil
}
return s.Client, nil
}
func (s *Source) UseClientAuthorization() bool {
return s.UseClientOAuth
}
func (s *Source) RunQuery(ctx context.Context, tokenStr string, req *geminidataanalyticspb.QueryDataRequest) (*geminidataanalyticspb.QueryDataResponse, error) {
client, cleanup, err := s.GetClient(ctx, tokenStr)
func (s *Source) RunQuery(ctx context.Context, tokenStr string, bodyBytes []byte) (any, error) {
// The API endpoint itself always uses the "global" location.
apiLocation := "global"
apiParent := fmt.Sprintf("projects/%s/locations/%s", s.GetProjectID(), apiLocation)
apiURL := fmt.Sprintf("%s/v1beta/%s:queryData", s.GetBaseURL(), apiParent)
client, err := s.GetClient(ctx, tokenStr)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to get HTTP client: %w", err)
}
defer cleanup()
return client.QueryData(ctx, req)
}
func (s *Source) GetClient(ctx context.Context, tokenStr string) (*geminidataanalytics.DataChatClient, func(), error) {
if s.UseClientOAuth {
if tokenStr == "" {
return nil, nil, fmt.Errorf("client-side OAuth is enabled but no access token was provided")
}
token := &oauth2.Token{AccessToken: tokenStr}
client, err := geminidataanalytics.NewDataChatClient(ctx,
option.WithUserAgent(s.userAgent),
option.WithTokenSource(oauth2.StaticTokenSource(token)),
)
if err != nil {
return nil, nil, fmt.Errorf("failed to create per-request DataChatClient: %w", err)
}
return client, func() { client.Close() }, nil
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(bodyBytes))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
return s.Client, func() {}, nil
req.Header.Set("Content-Type", "application/json")
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to execute request: %w", err)
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(respBody))
}
var result map[string]any
if err := json.Unmarshal(respBody, &result); err != nil {
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
}
return result, nil
}

View File

@@ -181,9 +181,11 @@ func TestInitialize(t *testing.T) {
if gdaSrc.Client == nil && !tc.wantClientOAuth {
t.Fatal("expected non-nil HTTP client for ADC, got nil")
}
// When client OAuth is true, the source's client should be nil.
if gdaSrc.Client != nil && tc.wantClientOAuth {
t.Fatal("expected nil HTTP client for client OAuth config, got non-nil")
// When client OAuth is true, the source's client should be initialized with a base HTTP client
// that includes the user agent round tripper, but not the OAuth token. The token-aware
// client is created by GetClient.
if gdaSrc.Client == nil && tc.wantClientOAuth {
t.Fatal("expected non-nil HTTP client for client OAuth config, got nil")
}
// Test UseClientAuthorization method
@@ -193,16 +195,15 @@ func TestInitialize(t *testing.T) {
// Test GetClient with accessToken for client OAuth scenarios
if tc.wantClientOAuth {
client, cleanup, err := gdaSrc.GetClient(ctx, "dummy-token")
client, err := gdaSrc.GetClient(ctx, "dummy-token")
if err != nil {
t.Fatalf("GetClient with token failed: %v", err)
}
defer cleanup()
if client == nil {
t.Fatal("expected non-nil HTTP client from GetClient with token, got nil")
}
// Ensure passing empty token with UseClientOAuth enabled returns error
_, _, err = gdaSrc.GetClient(ctx, "")
_, err = gdaSrc.GetClient(ctx, "")
if err == nil || err.Error() != "client-side OAuth is enabled but no access token was provided" {
t.Errorf("expected 'client-side OAuth is enabled but no access token was provided' error, got: %v", err)
}

View File

@@ -198,3 +198,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -204,3 +204,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -209,3 +209,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -180,3 +180,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -184,3 +184,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -184,3 +184,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -174,3 +174,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -179,3 +179,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -179,3 +179,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -303,3 +303,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -175,3 +175,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -150,3 +150,7 @@ var _ tools.Tool = Tool{}
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -19,13 +19,11 @@ import (
"encoding/json"
"fmt"
"cloud.google.com/go/geminidataanalytics/apiv1beta/geminidataanalyticspb"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/protobuf/encoding/protojson"
)
const kind string = "cloud-gemini-data-analytics-query"
@@ -62,49 +60,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
GetProjectID() string
UseClientAuthorization() bool
RunQuery(context.Context, string, *geminidataanalyticspb.QueryDataRequest) (*geminidataanalyticspb.QueryDataResponse, error)
}
// QueryDataContext wraps geminidataanalyticspb.QueryDataContext to support YAML decoding via protojson.
type QueryDataContext struct {
*geminidataanalyticspb.QueryDataContext
}
func (q *QueryDataContext) UnmarshalYAML(b []byte) error {
var raw map[string]any
if err := yaml.Unmarshal(b, &raw); err != nil {
return fmt.Errorf("failed to unmarshal context from yaml: %w", err)
}
jsonBytes, err := json.Marshal(raw)
if err != nil {
return fmt.Errorf("failed to marshal context map: %w", err)
}
q.QueryDataContext = &geminidataanalyticspb.QueryDataContext{}
if err := protojson.Unmarshal(jsonBytes, q.QueryDataContext); err != nil {
return fmt.Errorf("failed to unmarshal context to proto: %w", err)
}
return nil
}
// GenerationOptions wraps geminidataanalyticspb.GenerationOptions to support YAML decoding via protojson.
type GenerationOptions struct {
*geminidataanalyticspb.GenerationOptions
}
func (g *GenerationOptions) UnmarshalYAML(b []byte) error {
var raw map[string]any
if err := yaml.Unmarshal(b, &raw); err != nil {
return fmt.Errorf("failed to unmarshal generation options from yaml: %w", err)
}
jsonBytes, err := json.Marshal(raw)
if err != nil {
return fmt.Errorf("failed to marshal generation options map: %w", err)
}
g.GenerationOptions = &geminidataanalyticspb.GenerationOptions{}
if err := protojson.Unmarshal(jsonBytes, g.GenerationOptions); err != nil {
return fmt.Errorf("failed to unmarshal generation options to proto: %w", err)
}
return nil
RunQuery(context.Context, string, []byte) (any, error)
}
type Config struct {
@@ -141,14 +97,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
}
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil)
t := Tool{
return Tool{
Config: cfg,
AllParams: allParameters,
manifest: tools.Manifest{Description: cfg.Description, Parameters: allParameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}, nil
}
// validate interface
@@ -191,20 +145,18 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
// The parent in the request payload uses the tool's configured location.
payloadParent := fmt.Sprintf("projects/%s/locations/%s", source.GetProjectID(), t.Location)
req := &geminidataanalyticspb.QueryDataRequest{
Parent: payloadParent,
Prompt: query,
payload := &QueryDataRequest{
Parent: payloadParent,
Prompt: query,
Context: t.Context,
GenerationOptions: t.GenerationOptions,
}
if t.Context != nil {
req.Context = t.Context.QueryDataContext
bodyBytes, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("failed to marshal request payload: %w", err)
}
if t.GenerationOptions != nil {
req.GenerationOptions = t.GenerationOptions.GenerationOptions
}
return source.RunQuery(ctx, tokenStr, req)
return source.RunQuery(ctx, tokenStr, bodyBytes)
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
@@ -238,3 +190,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(_ tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -16,16 +16,19 @@ package cloudgda_test
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"testing"
"cloud.google.com/go/geminidataanalytics/apiv1beta/geminidataanalyticspb"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/server/resources"
"github.com/googleapis/genai-toolbox/internal/sources"
cloudgdasrc "github.com/googleapis/genai-toolbox/internal/sources/cloudgda"
"github.com/googleapis/genai-toolbox/internal/testutils"
"github.com/googleapis/genai-toolbox/internal/tools"
cloudgdatool "github.com/googleapis/genai-toolbox/internal/tools/cloudgda"
@@ -71,29 +74,23 @@ func TestParseFromYaml(t *testing.T) {
Location: "us-central1",
AuthRequired: []string{},
Context: &cloudgdatool.QueryDataContext{
QueryDataContext: &geminidataanalyticspb.QueryDataContext{
DatasourceReferences: &geminidataanalyticspb.DatasourceReferences{
References: &geminidataanalyticspb.DatasourceReferences_SpannerReference{
SpannerReference: &geminidataanalyticspb.SpannerReference{
DatabaseReference: &geminidataanalyticspb.SpannerDatabaseReference{
ProjectId: "cloud-db-nl2sql",
Region: "us-central1",
InstanceId: "evalbench",
DatabaseId: "financial",
Engine: geminidataanalyticspb.SpannerDatabaseReference_GOOGLE_SQL,
},
AgentContextReference: &geminidataanalyticspb.AgentContextReference{
ContextSetId: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates",
},
},
DatasourceReferences: &cloudgdatool.DatasourceReferences{
SpannerReference: &cloudgdatool.SpannerReference{
DatabaseReference: &cloudgdatool.SpannerDatabaseReference{
ProjectID: "cloud-db-nl2sql",
Region: "us-central1",
InstanceID: "evalbench",
DatabaseID: "financial",
Engine: cloudgdatool.SpannerEngineGoogleSQL,
},
AgentContextReference: &cloudgdatool.AgentContextReference{
ContextSetID: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates",
},
},
},
},
GenerationOptions: &cloudgdatool.GenerationOptions{
GenerationOptions: &geminidataanalyticspb.GenerationOptions{
GenerateQueryResult: true,
},
GenerateQueryResult: true,
},
},
},
@@ -111,63 +108,68 @@ func TestParseFromYaml(t *testing.T) {
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got.Tools, cmpopts.IgnoreUnexported(geminidataanalyticspb.QueryDataContext{}, geminidataanalyticspb.DatasourceReferences{}, geminidataanalyticspb.SpannerReference{}, geminidataanalyticspb.SpannerDatabaseReference{}, geminidataanalyticspb.AgentContextReference{}, geminidataanalyticspb.GenerationOptions{}, geminidataanalyticspb.DatasourceReferences_SpannerReference{})) {
t.Errorf("incorrect parse: want %v, got %v", tc.want, got.Tools)
if !cmp.Equal(tc.want, got.Tools) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Tools)
}
})
}
}
// fakeSource implements the compatibleSource interface for testing.
type fakeSource struct {
projectID string
useClientOAuth bool
expectedQuery string
expectedParent string
response *geminidataanalyticspb.QueryDataResponse
// authRoundTripper is a mock http.RoundTripper that adds a dummy Authorization header.
type authRoundTripper struct {
Token string
Next http.RoundTripper
}
func (f *fakeSource) GetProjectID() string {
return f.projectID
}
func (f *fakeSource) UseClientAuthorization() bool {
return f.useClientOAuth
}
func (f *fakeSource) SourceKind() string {
return "fake-gda-source"
}
func (f *fakeSource) ToConfig() sources.SourceConfig {
return nil
}
func (f *fakeSource) Initialize(ctx context.Context, tracer interface{}) (sources.Source, error) {
return f, nil
}
func (f *fakeSource) RunQuery(ctx context.Context, token string, req *geminidataanalyticspb.QueryDataRequest) (*geminidataanalyticspb.QueryDataResponse, error) {
if req.Prompt != f.expectedQuery {
return nil, fmt.Errorf("unexpected query: got %q, want %q", req.Prompt, f.expectedQuery)
func (rt *authRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
newReq := *req
newReq.Header = make(http.Header)
for k, v := range req.Header {
newReq.Header[k] = v
}
if req.Parent != f.expectedParent {
return nil, fmt.Errorf("unexpected parent: got %q, want %q", req.Parent, f.expectedParent)
newReq.Header.Set("Authorization", rt.Token)
if rt.Next == nil {
return http.DefaultTransport.RoundTrip(&newReq)
}
// Basic validation of context/options could be added here if needed,
// but the test case mainly checks if they are passed correctly via successful invocation.
return f.response, nil
return rt.Next.RoundTrip(&newReq)
}
type mockSource struct {
kind string
client *http.Client // Can be used to inject a specific client
baseURL string // BaseURL is needed to implement sources.Source.BaseURL
config cloudgdasrc.Config // to return from ToConfig
}
func (m *mockSource) SourceKind() string { return m.kind }
func (m *mockSource) ToConfig() sources.SourceConfig { return m.config }
func (m *mockSource) GetClient(ctx context.Context, token string) (*http.Client, error) {
if m.client != nil {
return m.client, nil
}
// Default client for testing if not explicitly set
transport := &http.Transport{}
authTransport := &authRoundTripper{
Token: "Bearer test-access-token", // Dummy token
Next: transport,
}
return &http.Client{Transport: authTransport}, nil
}
func (m *mockSource) UseClientAuthorization() bool { return false }
func (m *mockSource) Initialize(ctx context.Context, tracer interface{}) (sources.Source, error) {
return m, nil
}
func (m *mockSource) BaseURL() string { return m.baseURL }
func TestInitialize(t *testing.T) {
t.Parallel()
// Minimal fake source
fake := &fakeSource{projectID: "test-project"}
srcs := map[string]sources.Source{
"gda-api-source": fake,
"gda-api-source": &cloudgdasrc.Source{
Config: cloudgdasrc.Config{Name: "gda-api-source", Kind: cloudgdasrc.SourceKind, ProjectID: "test-project"},
Client: &http.Client{},
BaseURL: cloudgdasrc.Endpoint,
},
}
tcs := []struct {
@@ -186,6 +188,9 @@ func TestInitialize(t *testing.T) {
},
}
// Add an incompatible source for testing
srcs["incompatible-source"] = &mockSource{kind: "another-kind"}
for _, tc := range tcs {
tc := tc
t.Run(tc.desc, func(t *testing.T) {
@@ -202,27 +207,92 @@ func TestInitialize(t *testing.T) {
func TestInvoke(t *testing.T) {
t.Parallel()
// Mock the HTTP client and server for Invoke testing
serverMux := http.NewServeMux()
// Update expected URL path to include the location "us-central1"
serverMux.HandleFunc("/v1beta/projects/test-project/locations/global:queryData", func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("expected POST method, got %s", r.Method)
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
if r.Header.Get("Content-Type") != "application/json" {
t.Errorf("expected Content-Type application/json, got %s", r.Header.Get("Content-Type"))
http.Error(w, "Bad request", http.StatusBadRequest)
return
}
projectID := "test-project"
location := "us-central1"
query := "How many accounts who have region in Prague are eligible for loans?"
expectedParent := fmt.Sprintf("projects/%s/locations/%s", projectID, location)
// Read and unmarshal the request body
bodyBytes, err := io.ReadAll(r.Body)
if err != nil {
t.Errorf("failed to read request body: %v", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
var reqPayload cloudgdatool.QueryDataRequest
if err := json.Unmarshal(bodyBytes, &reqPayload); err != nil {
t.Errorf("failed to unmarshal request payload: %v", err)
http.Error(w, "Bad request", http.StatusBadRequest)
return
}
// Prepare expected response
expectedResp := &geminidataanalyticspb.QueryDataResponse{
GeneratedQuery: "SELECT count(*) FROM accounts WHERE region = 'Prague' AND eligible_for_loans = true;",
NaturalLanguageAnswer: "There are 5 accounts in Prague eligible for loans.",
// Verify expected fields
if r.Header.Get("Authorization") == "" {
t.Errorf("expected Authorization header, got empty")
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
if reqPayload.Prompt != "How many accounts who have region in Prague are eligible for loans?" {
t.Errorf("unexpected prompt: %s", reqPayload.Prompt)
}
// Verify payload's parent uses the tool's configured location
if reqPayload.Parent != fmt.Sprintf("projects/%s/locations/%s", "test-project", "us-central1") {
t.Errorf("unexpected payload parent: got %q, want %q", reqPayload.Parent, fmt.Sprintf("projects/%s/locations/%s", "test-project", "us-central1"))
}
// Verify context from config
if reqPayload.Context == nil ||
reqPayload.Context.DatasourceReferences == nil ||
reqPayload.Context.DatasourceReferences.SpannerReference == nil ||
reqPayload.Context.DatasourceReferences.SpannerReference.DatabaseReference == nil ||
reqPayload.Context.DatasourceReferences.SpannerReference.DatabaseReference.ProjectID != "cloud-db-nl2sql" {
t.Errorf("unexpected context: %v", reqPayload.Context)
}
// Verify generation options from config
if reqPayload.GenerationOptions == nil || !reqPayload.GenerationOptions.GenerateQueryResult {
t.Errorf("unexpected generation options: %v", reqPayload.GenerationOptions)
}
// Simulate a successful response
resp := map[string]any{
"queryResult": "SELECT count(*) FROM accounts WHERE region = 'Prague' AND eligible_for_loans = true;",
"naturalLanguageAnswer": "There are 5 accounts in Prague eligible for loans.",
}
_ = json.NewEncoder(w).Encode(resp)
})
mockServer := httptest.NewServer(serverMux)
defer mockServer.Close()
ctx := testutils.ContextWithUserAgent(context.Background(), "test-user-agent")
// Create an authenticated client that uses the mock server
authTransport := &authRoundTripper{
Token: "Bearer test-access-token",
Next: mockServer.Client().Transport,
}
authClient := &http.Client{Transport: authTransport}
fake := &fakeSource{
projectID: projectID,
expectedQuery: query,
expectedParent: expectedParent,
response: expectedResp,
// Create a real cloudgdasrc.Source but inject the authenticated client
mockGdaSource := &cloudgdasrc.Source{
Config: cloudgdasrc.Config{Name: "mock-gda-source", Kind: cloudgdasrc.SourceKind, ProjectID: "test-project"},
Client: authClient,
BaseURL: mockServer.URL,
}
srcs := map[string]sources.Source{
"mock-gda-source": fake,
"mock-gda-source": mockGdaSource,
}
// Initialize the tool config with context
@@ -231,31 +301,25 @@ func TestInvoke(t *testing.T) {
Kind: "cloud-gemini-data-analytics-query",
Source: "mock-gda-source",
Description: "Query Gemini Data Analytics",
Location: location,
Location: "us-central1", // Set location for the test
Context: &cloudgdatool.QueryDataContext{
QueryDataContext: &geminidataanalyticspb.QueryDataContext{
DatasourceReferences: &geminidataanalyticspb.DatasourceReferences{
References: &geminidataanalyticspb.DatasourceReferences_SpannerReference{
SpannerReference: &geminidataanalyticspb.SpannerReference{
DatabaseReference: &geminidataanalyticspb.SpannerDatabaseReference{
ProjectId: "cloud-db-nl2sql",
Region: "us-central1",
InstanceId: "evalbench",
DatabaseId: "financial",
Engine: geminidataanalyticspb.SpannerDatabaseReference_GOOGLE_SQL,
},
AgentContextReference: &geminidataanalyticspb.AgentContextReference{
ContextSetId: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates",
},
},
DatasourceReferences: &cloudgdatool.DatasourceReferences{
SpannerReference: &cloudgdatool.SpannerReference{
DatabaseReference: &cloudgdatool.SpannerDatabaseReference{
ProjectID: "cloud-db-nl2sql",
Region: "us-central1",
InstanceID: "evalbench",
DatabaseID: "financial",
Engine: cloudgdatool.SpannerEngineGoogleSQL,
},
AgentContextReference: &cloudgdatool.AgentContextReference{
ContextSetID: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates",
},
},
},
},
GenerationOptions: &cloudgdatool.GenerationOptions{
GenerationOptions: &geminidataanalyticspb.GenerationOptions{
GenerateQueryResult: true,
},
GenerateQueryResult: true,
},
}
@@ -266,25 +330,24 @@ func TestInvoke(t *testing.T) {
// Prepare parameters for invocation - ONLY query
params := parameters.ParamValues{
{Name: "query", Value: query},
{Name: "query", Value: "How many accounts who have region in Prague are eligible for loans?"},
}
resourceMgr := resources.NewResourceManager(srcs, nil, nil, nil, nil, nil, nil)
ctx := testutils.ContextWithUserAgent(context.Background(), "test-user-agent")
// Invoke the tool
result, err := tool.Invoke(ctx, resourceMgr, params, "")
result, err := tool.Invoke(ctx, resourceMgr, params, "") // No accessToken needed for ADC client
if err != nil {
t.Fatalf("tool invocation failed: %v", err)
}
gotResp, ok := result.(*geminidataanalyticspb.QueryDataResponse)
if !ok {
t.Fatalf("expected result type *geminidataanalyticspb.QueryDataResponse, got %T", result)
// Validate the result
expectedResult := map[string]any{
"queryResult": "SELECT count(*) FROM accounts WHERE region = 'Prague' AND eligible_for_loans = true;",
"naturalLanguageAnswer": "There are 5 accounts in Prague eligible for loans.",
}
if diff := cmp.Diff(expectedResp, gotResp, cmpopts.IgnoreUnexported(geminidataanalyticspb.QueryDataResponse{})); diff != "" {
t.Errorf("unexpected result mismatch (-want +got):\n%s", diff)
if !cmp.Equal(expectedResult, result) {
t.Errorf("unexpected result: got %v, want %v", result, expectedResult)
}
}

View File

@@ -0,0 +1,116 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cloudgda
// See full service definition at: https://github.com/googleapis/googleapis/blob/master/google/cloud/geminidataanalytics/v1beta/data_chat_service.proto
// QueryDataRequest represents the JSON body for the queryData API
type QueryDataRequest struct {
Parent string `json:"parent"`
Prompt string `json:"prompt"`
Context *QueryDataContext `json:"context,omitempty"`
GenerationOptions *GenerationOptions `json:"generationOptions,omitempty"`
}
// QueryDataContext reflects the proto definition for the query context.
type QueryDataContext struct {
DatasourceReferences *DatasourceReferences `json:"datasourceReferences,omitempty" yaml:"datasourceReferences,omitempty"`
}
// DatasourceReferences reflects the proto definition for datasource references, using a oneof.
type DatasourceReferences struct {
SpannerReference *SpannerReference `json:"spannerReference,omitempty" yaml:"spannerReference,omitempty"`
AlloyDBReference *AlloyDBReference `json:"alloydb,omitempty" yaml:"alloydb,omitempty"`
CloudSQLReference *CloudSQLReference `json:"cloudSqlReference,omitempty" yaml:"cloudSqlReference,omitempty"`
}
// SpannerReference reflects the proto definition for Spanner database reference.
type SpannerReference struct {
DatabaseReference *SpannerDatabaseReference `json:"databaseReference,omitempty" yaml:"databaseReference,omitempty"`
AgentContextReference *AgentContextReference `json:"agentContextReference,omitempty" yaml:"agentContextReference,omitempty"`
}
// SpannerDatabaseReference reflects the proto definition for a Spanner database reference.
type SpannerDatabaseReference struct {
Engine SpannerEngine `json:"engine,omitempty" yaml:"engine,omitempty"`
ProjectID string `json:"projectId,omitempty" yaml:"projectId,omitempty"`
Region string `json:"region,omitempty" yaml:"region,omitempty"`
InstanceID string `json:"instanceId,omitempty" yaml:"instanceId,omitempty"`
DatabaseID string `json:"databaseId,omitempty" yaml:"databaseId,omitempty"`
TableIDs []string `json:"tableIds,omitempty" yaml:"tableIds,omitempty"`
}
// SpannerEngine represents the engine of the Spanner instance.
type SpannerEngine string
const (
SpannerEngineUnspecified SpannerEngine = "ENGINE_UNSPECIFIED"
SpannerEngineGoogleSQL SpannerEngine = "GOOGLE_SQL"
SpannerEnginePostgreSQL SpannerEngine = "POSTGRESQL"
)
// AlloyDBReference reflects the proto definition for an AlloyDB database reference.
type AlloyDBReference struct {
DatabaseReference *AlloyDBDatabaseReference `json:"databaseReference,omitempty" yaml:"databaseReference,omitempty"`
AgentContextReference *AgentContextReference `json:"agentContextReference,omitempty" yaml:"agentContextReference,omitempty"`
}
// AlloyDBDatabaseReference reflects the proto definition for an AlloyDB database reference.
type AlloyDBDatabaseReference struct {
ProjectID string `json:"projectId,omitempty" yaml:"projectId,omitempty"`
Region string `json:"region,omitempty" yaml:"region,omitempty"`
ClusterID string `json:"clusterId,omitempty" yaml:"clusterId,omitempty"`
InstanceID string `json:"instanceId,omitempty" yaml:"instanceId,omitempty"`
DatabaseID string `json:"databaseId,omitempty" yaml:"databaseId,omitempty"`
TableIDs []string `json:"tableIds,omitempty" yaml:"tableIds,omitempty"`
}
// CloudSQLReference reflects the proto definition for a Cloud SQL database reference.
type CloudSQLReference struct {
DatabaseReference *CloudSQLDatabaseReference `json:"databaseReference,omitempty" yaml:"databaseReference,omitempty"`
AgentContextReference *AgentContextReference `json:"agentContextReference,omitempty" yaml:"agentContextReference,omitempty"`
}
// CloudSQLDatabaseReference reflects the proto definition for a Cloud SQL database reference.
type CloudSQLDatabaseReference struct {
Engine CloudSQLEngine `json:"engine,omitempty" yaml:"engine,omitempty"`
ProjectID string `json:"projectId,omitempty" yaml:"projectId,omitempty"`
Region string `json:"region,omitempty" yaml:"region,omitempty"`
InstanceID string `json:"instanceId,omitempty" yaml:"instanceId,omitempty"`
DatabaseID string `json:"databaseId,omitempty" yaml:"databaseId,omitempty"`
TableIDs []string `json:"tableIds,omitempty" yaml:"tableIds,omitempty"`
}
// CloudSQLEngine represents the engine of the Cloud SQL instance.
type CloudSQLEngine string
const (
CloudSQLEngineUnspecified CloudSQLEngine = "ENGINE_UNSPECIFIED"
CloudSQLEnginePostgreSQL CloudSQLEngine = "POSTGRESQL"
CloudSQLEngineMySQL CloudSQLEngine = "MYSQL"
)
// AgentContextReference reflects the proto definition for agent context.
type AgentContextReference struct {
ContextSetID string `json:"contextSetId,omitempty" yaml:"contextSetId,omitempty"`
}
// GenerationOptions reflects the proto definition for generation options.
type GenerationOptions struct {
GenerateQueryResult bool `json:"generateQueryResult" yaml:"generateQueryResult"`
GenerateNaturalLanguageAnswer bool `json:"generateNaturalLanguageAnswer" yaml:"generateNaturalLanguageAnswer"`
GenerateExplanation bool `json:"generateExplanation" yaml:"generateExplanation"`
GenerateDisambiguationQuestion bool `json:"generateDisambiguationQuestion" yaml:"generateDisambiguationQuestion"`
}

View File

@@ -142,3 +142,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -193,3 +193,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -266,3 +266,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -133,3 +133,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -154,3 +154,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -154,3 +154,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -168,3 +168,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -154,3 +154,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -154,3 +154,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -133,3 +133,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -181,3 +181,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -207,3 +207,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -192,3 +192,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -176,3 +176,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -187,3 +187,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -178,3 +178,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -175,3 +175,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -180,3 +180,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -171,3 +171,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -170,3 +170,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -165,3 +165,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -300,3 +300,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -204,3 +204,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -206,3 +206,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -142,3 +142,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -130,3 +130,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -153,3 +153,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -138,3 +138,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -157,3 +157,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -287,3 +287,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -139,3 +139,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -151,3 +151,7 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -147,3 +147,7 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -152,3 +152,7 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -206,3 +206,7 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -169,3 +169,7 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -150,3 +150,7 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.PayloadParams
}

View File

@@ -148,3 +148,7 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.PayloadParams
}

View File

@@ -158,3 +158,7 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -159,3 +159,7 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -141,3 +141,7 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -399,3 +399,7 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -157,3 +157,7 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -141,3 +141,7 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(_ tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -166,3 +166,7 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(_ tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -234,3 +234,7 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.allParams
}

View File

@@ -180,3 +180,7 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.allParams
}

View File

@@ -303,3 +303,7 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -170,3 +170,7 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.allParams
}

View File

@@ -145,3 +145,7 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -153,3 +153,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.allParams
}

View File

@@ -165,3 +165,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.allParams
}

View File

@@ -166,3 +166,7 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.allParams
}

View File

@@ -20,7 +20,7 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.comcom/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"github.com/jackc/pgx/v5/pgxpool"
@@ -85,6 +85,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
AuthRequired: cfg.AuthRequired,
},
mcpManifest: mcpManifest,
Parameters: params,
}
return t, nil
}
@@ -96,6 +97,7 @@ type Tool struct {
Config
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
@@ -107,11 +109,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
return parameters.ParamValues{}, nil
return parameters.ParseParams(t.Parameters, data, claims)
}
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
return parameters.ParamValues{}, nil
return parameters.EmbedParams(ctx, t.Parameters, paramValues, embeddingModelsMap, nil)
}
func (t Tool) Manifest() tools.Manifest {
@@ -137,3 +139,7 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -235,3 +235,7 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.allParams
}

View File

@@ -164,3 +164,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.allParams
}

View File

@@ -161,3 +161,7 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.allParams
}

View File

@@ -175,3 +175,7 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.allParams
}

View File

@@ -164,3 +164,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.allParams
}

View File

@@ -186,3 +186,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.allParams
}

View File

@@ -198,3 +198,7 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.allParams
}

View File

@@ -212,3 +212,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.allParams
}

View File

@@ -173,3 +173,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.allParams
}

View File

@@ -126,3 +126,7 @@ func (t *Tool) ToConfig() tools.ToolConfig {
func (t *Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t *Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -147,3 +147,7 @@ func (t *Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t *Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -146,3 +146,7 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -151,3 +151,7 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -143,6 +143,10 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
return false, nil
}
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.sourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -174,3 +174,7 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -143,3 +143,7 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -163,6 +163,10 @@ func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string,
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}
// GoogleSQL statement for listing graphs
const googleSQLStatement = `
WITH FilterGraphNames AS (

View File

@@ -189,6 +189,10 @@ func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string,
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}
// PostgreSQL statement for listing tables
const postgresqlStatement = `
WITH table_info_cte AS (

View File

@@ -185,3 +185,7 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -141,3 +141,7 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -145,3 +145,7 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

Some files were not shown because too many files have changed in this diff Show More