Compare commits

..

3 Commits

Author SHA1 Message Date
Juexin Wang
9bb4eee494 refactor(tools/cloudgda): update to use google-cloud-go sdk types
Removes types.go and uses geminidataanalyticspb types with wrapper structs for YAML decoding.
2026-01-15 10:24:10 -08:00
Juexin Wang
9f5b04cf73 Refactor Cloud GDA source to support per-request client authorization 2026-01-14 16:10:50 -08:00
Juexin Wang
66d6b58c4f intorduce Go SDK for GDA tool and source 2026-01-14 16:10:50 -08:00
103 changed files with 218 additions and 1415 deletions

View File

@@ -20,7 +20,6 @@ The native SDKs can be combined with MCP clients in many cases.
Toolbox currently supports the following versions of MCP specification:
* [2025-11-25](https://modelcontextprotocol.io/specification/2025-11-25)
* [2025-06-18](https://modelcontextprotocol.io/specification/2025-06-18)
* [2025-03-26](https://modelcontextprotocol.io/specification/2025-03-26)
* [2024-11-05](https://modelcontextprotocol.io/specification/2024-11-05)

2
go.mod
View File

@@ -12,7 +12,7 @@ require (
cloud.google.com/go/dataplex v1.28.0
cloud.google.com/go/dataproc/v2 v2.15.0
cloud.google.com/go/firestore v1.20.0
cloud.google.com/go/geminidataanalytics v0.3.0
cloud.google.com/go/geminidataanalytics v0.5.0
cloud.google.com/go/longrunning v0.7.0
cloud.google.com/go/spanner v1.86.1
github.com/ClickHouse/clickhouse-go/v2 v2.40.3

4
go.sum
View File

@@ -311,8 +311,8 @@ cloud.google.com/go/gaming v1.6.0/go.mod h1:YMU1GEvA39Qt3zWGyAVA9bpYz/yAhTvaQ1t2
cloud.google.com/go/gaming v1.7.0/go.mod h1:LrB8U7MHdGgFG851iHAfqUdLcKBdQ55hzXy9xBJz0+w=
cloud.google.com/go/gaming v1.8.0/go.mod h1:xAqjS8b7jAVW0KFYeRUxngo9My3f33kFmua++Pi+ggM=
cloud.google.com/go/gaming v1.9.0/go.mod h1:Fc7kEmCObylSWLO334NcO+O9QMDyz+TKC4v1D7X+Bc0=
cloud.google.com/go/geminidataanalytics v0.3.0 h1:2Wi/kqFb5OLuEGH7q+/miE19VTqK1MYHjBEHENap9HI=
cloud.google.com/go/geminidataanalytics v0.3.0/go.mod h1:QRc0b6ywyc3Z7S3etFgslz7hippkW/jRvtops5rKqIg=
cloud.google.com/go/geminidataanalytics v0.5.0 h1:+1usY81Cb+hE8BokpqCM7EgJtRCKzUKx7FvrHbT5hCA=
cloud.google.com/go/geminidataanalytics v0.5.0/go.mod h1:QRc0b6ywyc3Z7S3etFgslz7hippkW/jRvtops5rKqIg=
cloud.google.com/go/gkebackup v0.2.0/go.mod h1:XKvv/4LfG829/B8B7xRkk8zRrOEbKtEam6yNfuQNH60=
cloud.google.com/go/gkebackup v0.3.0/go.mod h1:n/E671i1aOQvUxT541aTkCwExO/bTer2HDlj4TsBRAo=
cloud.google.com/go/gkebackup v0.4.0/go.mod h1:byAyBGUwYGEEww7xsbnUTBHIYcOPy/PgUWUtOeRm9Vg=

View File

@@ -27,21 +27,19 @@ import (
v20241105 "github.com/googleapis/genai-toolbox/internal/server/mcp/v20241105"
v20250326 "github.com/googleapis/genai-toolbox/internal/server/mcp/v20250326"
v20250618 "github.com/googleapis/genai-toolbox/internal/server/mcp/v20250618"
v20251125 "github.com/googleapis/genai-toolbox/internal/server/mcp/v20251125"
"github.com/googleapis/genai-toolbox/internal/server/resources"
"github.com/googleapis/genai-toolbox/internal/tools"
)
// LATEST_PROTOCOL_VERSION is the latest version of the MCP protocol supported.
// Update the version used in InitializeResponse when this value is updated.
const LATEST_PROTOCOL_VERSION = v20251125.PROTOCOL_VERSION
const LATEST_PROTOCOL_VERSION = v20250618.PROTOCOL_VERSION
// SUPPORTED_PROTOCOL_VERSIONS is the MCP protocol versions that are supported.
var SUPPORTED_PROTOCOL_VERSIONS = []string{
v20241105.PROTOCOL_VERSION,
v20250326.PROTOCOL_VERSION,
v20250618.PROTOCOL_VERSION,
v20251125.PROTOCOL_VERSION,
}
// InitializeResponse runs capability negotiation and protocol version agreement.
@@ -104,8 +102,6 @@ func NotificationHandler(ctx context.Context, body []byte) error {
// This is the Operation phase of the lifecycle for MCP client-server connections.
func ProcessMethod(ctx context.Context, mcpVersion string, id jsonrpc.RequestId, method string, toolset tools.Toolset, promptset prompts.Promptset, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) {
switch mcpVersion {
case v20251125.PROTOCOL_VERSION:
return v20251125.ProcessMethod(ctx, id, method, toolset, promptset, resourceMgr, body, header)
case v20250618.PROTOCOL_VERSION:
return v20250618.ProcessMethod(ctx, id, method, toolset, promptset, resourceMgr, body, header)
case v20250326.PROTOCOL_VERSION:

View File

@@ -1,326 +0,0 @@
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package v20251125
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"github.com/googleapis/genai-toolbox/internal/prompts"
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
"github.com/googleapis/genai-toolbox/internal/server/resources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util"
)
// ProcessMethod returns a response for the request.
func ProcessMethod(ctx context.Context, id jsonrpc.RequestId, method string, toolset tools.Toolset, promptset prompts.Promptset, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) {
switch method {
case PING:
return pingHandler(id)
case TOOLS_LIST:
return toolsListHandler(id, toolset, body)
case TOOLS_CALL:
return toolsCallHandler(ctx, id, resourceMgr, body, header)
case PROMPTS_LIST:
return promptsListHandler(ctx, id, promptset, body)
case PROMPTS_GET:
return promptsGetHandler(ctx, id, resourceMgr, body)
default:
err := fmt.Errorf("invalid method %s", method)
return jsonrpc.NewError(id, jsonrpc.METHOD_NOT_FOUND, err.Error(), nil), err
}
}
// pingHandler handles the "ping" method by returning an empty response.
func pingHandler(id jsonrpc.RequestId) (any, error) {
return jsonrpc.JSONRPCResponse{
Jsonrpc: jsonrpc.JSONRPC_VERSION,
Id: id,
Result: struct{}{},
}, nil
}
func toolsListHandler(id jsonrpc.RequestId, toolset tools.Toolset, body []byte) (any, error) {
var req ListToolsRequest
if err := json.Unmarshal(body, &req); err != nil {
err = fmt.Errorf("invalid mcp tools list request: %w", err)
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}
result := ListToolsResult{
Tools: toolset.McpManifest,
}
return jsonrpc.JSONRPCResponse{
Jsonrpc: jsonrpc.JSONRPC_VERSION,
Id: id,
Result: result,
}, nil
}
// toolsCallHandler generate a response for tools call.
func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) {
authServices := resourceMgr.GetAuthServiceMap()
// retrieve logger from context
logger, err := util.LoggerFromContext(ctx)
if err != nil {
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
}
var req CallToolRequest
if err = json.Unmarshal(body, &req); err != nil {
err = fmt.Errorf("invalid mcp tools call request: %w", err)
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}
toolName := req.Params.Name
toolArgument := req.Params.Arguments
logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName))
tool, ok := resourceMgr.GetTool(toolName)
if !ok {
err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
}
// Get access token
authTokenHeadername, err := tool.GetAuthTokenHeaderName(resourceMgr)
if err != nil {
errMsg := fmt.Errorf("error during invocation: %w", err)
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
}
accessToken := tools.AccessToken(header.Get(authTokenHeadername))
// Check if this specific tool requires the standard authorization header
clientAuth, err := tool.RequiresClientAuthorization(resourceMgr)
if err != nil {
errMsg := fmt.Errorf("error during invocation: %w", err)
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
}
if clientAuth {
if accessToken == "" {
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
}
}
// marshal arguments and decode it using decodeJSON instead to prevent loss between floats/int.
aMarshal, err := json.Marshal(toolArgument)
if err != nil {
err = fmt.Errorf("unable to marshal tools argument: %w", err)
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
}
var data map[string]any
if err = util.DecodeJSON(bytes.NewBuffer(aMarshal), &data); err != nil {
err = fmt.Errorf("unable to decode tools argument: %w", err)
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
}
// Tool authentication
// claimsFromAuth maps the name of the authservice to the claims retrieved from it.
claimsFromAuth := make(map[string]map[string]any)
// if using stdio, header will be nil and auth will not be supported
if header != nil {
for _, aS := range authServices {
claims, err := aS.GetClaimsFromHeader(ctx, header)
if err != nil {
logger.DebugContext(ctx, err.Error())
continue
}
if claims == nil {
// authService not present in header
continue
}
claimsFromAuth[aS.GetName()] = claims
}
}
// Tool authorization check
verifiedAuthServices := make([]string, len(claimsFromAuth))
i := 0
for k := range claimsFromAuth {
verifiedAuthServices[i] = k
i++
}
// Check if any of the specified auth services is verified
isAuthorized := tool.Authorized(verifiedAuthServices)
if !isAuthorized {
err = fmt.Errorf("unauthorized Tool call: Please make sure your specify correct auth headers: %w", util.ErrUnauthorized)
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}
logger.DebugContext(ctx, "tool invocation authorized")
params, err := tool.ParseParams(data, claimsFromAuth)
if err != nil {
err = fmt.Errorf("provided parameters were invalid: %w", err)
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
}
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
// run tool invocation and generate response.
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
if err != nil {
errStr := err.Error()
// Missing authService tokens.
if errors.Is(err, util.ErrUnauthorized) {
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}
// Upstream auth error
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
if clientAuth {
// Error with client credentials should pass down to the client
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}
// Auth error with ADC should raise internal 500 error
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
}
text := TextContent{
Type: "text",
Text: err.Error(),
}
return jsonrpc.JSONRPCResponse{
Jsonrpc: jsonrpc.JSONRPC_VERSION,
Id: id,
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
}, nil
}
content := make([]TextContent, 0)
sliceRes, ok := results.([]any)
if !ok {
sliceRes = []any{results}
}
for _, d := range sliceRes {
text := TextContent{Type: "text"}
dM, err := json.Marshal(d)
if err != nil {
text.Text = fmt.Sprintf("fail to marshal: %s, result: %s", err, d)
} else {
text.Text = string(dM)
}
content = append(content, text)
}
return jsonrpc.JSONRPCResponse{
Jsonrpc: jsonrpc.JSONRPC_VERSION,
Id: id,
Result: CallToolResult{Content: content},
}, nil
}
// promptsListHandler handles the "prompts/list" method.
func promptsListHandler(ctx context.Context, id jsonrpc.RequestId, promptset prompts.Promptset, body []byte) (any, error) {
// retrieve logger from context
logger, err := util.LoggerFromContext(ctx)
if err != nil {
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
}
logger.DebugContext(ctx, "handling prompts/list request")
var req ListPromptsRequest
if err := json.Unmarshal(body, &req); err != nil {
err = fmt.Errorf("invalid mcp prompts list request: %w", err)
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}
result := ListPromptsResult{
Prompts: promptset.McpManifest,
}
logger.DebugContext(ctx, fmt.Sprintf("returning %d prompts", len(promptset.McpManifest)))
return jsonrpc.JSONRPCResponse{
Jsonrpc: jsonrpc.JSONRPC_VERSION,
Id: id,
Result: result,
}, nil
}
// promptsGetHandler handles the "prompts/get" method.
func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *resources.ResourceManager, body []byte) (any, error) {
// retrieve logger from context
logger, err := util.LoggerFromContext(ctx)
if err != nil {
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
}
logger.DebugContext(ctx, "handling prompts/get request")
var req GetPromptRequest
if err := json.Unmarshal(body, &req); err != nil {
err = fmt.Errorf("invalid mcp prompts/get request: %w", err)
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
}
promptName := req.Params.Name
logger.DebugContext(ctx, fmt.Sprintf("prompt name: %s", promptName))
prompt, ok := resourceMgr.GetPrompt(promptName)
if !ok {
err := fmt.Errorf("prompt with name %q does not exist", promptName)
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
}
// Parse the arguments provided in the request.
argValues, err := prompt.ParseArgs(req.Params.Arguments, nil)
if err != nil {
err = fmt.Errorf("invalid arguments for prompt %q: %w", promptName, err)
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
}
logger.DebugContext(ctx, fmt.Sprintf("parsed args: %v", argValues))
// Substitute the argument values into the prompt's messages.
substituted, err := prompt.SubstituteParams(argValues)
if err != nil {
err = fmt.Errorf("error substituting params for prompt %q: %w", promptName, err)
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
}
// Cast the result to the expected []prompts.Message type.
substitutedMessages, ok := substituted.([]prompts.Message)
if !ok {
err = fmt.Errorf("internal error: SubstituteParams returned unexpected type")
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
}
logger.DebugContext(ctx, "substituted params successfully")
// Format the response messages into the required structure.
promptMessages := make([]PromptMessage, len(substitutedMessages))
for i, msg := range substitutedMessages {
promptMessages[i] = PromptMessage{
Role: msg.Role,
Content: TextContent{
Type: "text",
Text: msg.Content,
},
}
}
result := GetPromptResult{
Description: prompt.Manifest().Description,
Messages: promptMessages,
}
return jsonrpc.JSONRPCResponse{
Jsonrpc: jsonrpc.JSONRPC_VERSION,
Id: id,
Result: result,
}, nil
}

View File

@@ -1,219 +0,0 @@
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package v20251125
import (
"github.com/googleapis/genai-toolbox/internal/prompts"
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
"github.com/googleapis/genai-toolbox/internal/tools"
)
// SERVER_NAME is the server name used in Implementation.
const SERVER_NAME = "Toolbox"
// PROTOCOL_VERSION is the version of the MCP protocol in this package.
const PROTOCOL_VERSION = "2025-11-25"
// methods that are supported.
const (
PING = "ping"
TOOLS_LIST = "tools/list"
TOOLS_CALL = "tools/call"
PROMPTS_LIST = "prompts/list"
PROMPTS_GET = "prompts/get"
)
/* Empty result */
// EmptyResult represents a response that indicates success but carries no data.
type EmptyResult jsonrpc.Result
/* Pagination */
// Cursor is an opaque token used to represent a cursor for pagination.
type Cursor string
type PaginatedRequest struct {
jsonrpc.Request
Params struct {
// An opaque token representing the current pagination position.
// If provided, the server should return results starting after this cursor.
Cursor Cursor `json:"cursor,omitempty"`
} `json:"params,omitempty"`
}
type PaginatedResult struct {
jsonrpc.Result
// An opaque token representing the pagination position after the last returned result.
// If present, there may be more results available.
NextCursor Cursor `json:"nextCursor,omitempty"`
}
/* Tools */
// Sent from the client to request a list of tools the server has.
type ListToolsRequest struct {
PaginatedRequest
}
// The server's response to a tools/list request from the client.
type ListToolsResult struct {
PaginatedResult
Tools []tools.McpManifest `json:"tools"`
}
// Used by the client to invoke a tool provided by the server.
type CallToolRequest struct {
jsonrpc.Request
Params struct {
Name string `json:"name"`
Arguments map[string]any `json:"arguments,omitempty"`
} `json:"params,omitempty"`
}
// The sender or recipient of messages and data in a conversation.
type Role string
const (
RoleUser Role = "user"
RoleAssistant Role = "assistant"
)
// Base for objects that include optional annotations for the client.
// The client can use annotations to inform how objects are used or displayed
type Annotated struct {
Annotations *struct {
// Describes who the intended customer of this object or data is.
// It can include multiple entries to indicate content useful for multiple
// audiences (e.g., `["user", "assistant"]`).
Audience []Role `json:"audience,omitempty"`
// Describes how important this data is for operating the server.
//
// A value of 1 means "most important," and indicates that the data is
// effectively required, while 0 means "least important," and indicates that
// the data is entirely optional.
//
// @TJS-type number
// @minimum 0
// @maximum 1
Priority float64 `json:"priority,omitempty"`
} `json:"annotations,omitempty"`
}
// TextContent represents text provided to or from an LLM.
type TextContent struct {
Annotated
Type string `json:"type"`
// The text content of the message.
Text string `json:"text"`
}
// The server's response to a tool call.
//
// Any errors that originate from the tool SHOULD be reported inside the result
// object, with `isError` set to true, _not_ as an MCP protocol-level error
// response. Otherwise, the LLM would not be able to see that an error occurred
// and self-correct.
//
// However, any errors in _finding_ the tool, an error indicating that the
// server does not support tool calls, or any other exceptional conditions,
// should be reported as an MCP error response.
type CallToolResult struct {
jsonrpc.Result
// Could be either a TextContent, ImageContent, or EmbeddedResources
// For Toolbox, we will only be sending TextContent
Content []TextContent `json:"content"`
// Whether the tool call ended in an error.
// If not set, this is assumed to be false (the call was successful).
//
// Any errors that originate from the tool SHOULD be reported inside the result
// object, with `isError` set to true, _not_ as an MCP protocol-level error
// response. Otherwise, the LLM would not be able to see that an error occurred
// and self-correct.
//
// However, any errors in _finding_ the tool, an error indicating that the
// server does not support tool calls, or any other exceptional conditions,
// should be reported as an MCP error response.
IsError bool `json:"isError,omitempty"`
// An optional JSON object that represents the structured result of the tool call.
StructuredContent map[string]any `json:"structuredContent,omitempty"`
}
// Additional properties describing a Tool to clients.
//
// NOTE: all properties in ToolAnnotations are **hints**.
// They are not guaranteed to provide a faithful description of
// tool behavior (including descriptive properties like `title`).
//
// Clients should never make tool use decisions based on ToolAnnotations
// received from untrusted servers.
type ToolAnnotations struct {
// A human-readable title for the tool.
Title string `json:"title,omitempty"`
// If true, the tool does not modify its environment.
// Default: false
ReadOnlyHint bool `json:"readOnlyHint,omitempty"`
// If true, the tool may perform destructive updates to its environment.
// If false, the tool performs only additive updates.
// (This property is meaningful only when `readOnlyHint == false`)
// Default: true
DestructiveHint bool `json:"destructiveHint,omitempty"`
// If true, calling the tool repeatedly with the same arguments
// will have no additional effect on the its environment.
// (This property is meaningful only when `readOnlyHint == false`)
// Default: false
IdempotentHint bool `json:"idempotentHint,omitempty"`
// If true, this tool may interact with an "open world" of external
// entities. If false, the tool's domain of interaction is closed.
// For example, the world of a web search tool is open, whereas that
// of a memory tool is not.
// Default: true
OpenWorldHint bool `json:"openWorldHint,omitempty"`
}
/* Prompts */
// Sent from the client to request a list of prompts the server has.
type ListPromptsRequest struct {
PaginatedRequest
}
// The server's response to a prompts/list request from the client.
type ListPromptsResult struct {
PaginatedResult
Prompts []prompts.McpManifest `json:"prompts"`
}
// Used by the client to get a prompt provided by the server.
type GetPromptRequest struct {
jsonrpc.Request
Params struct {
Name string `json:"name"`
Arguments map[string]any `json:"arguments,omitempty"`
} `json:"params"`
}
// The server's response to a prompts/get request from the client.
type GetPromptResult struct {
jsonrpc.Result
Description string `json:"description,omitempty"`
Messages []PromptMessage `json:"messages"`
}
// Describes a message returned as part of a prompt.
type PromptMessage struct {
Role string `json:"role"`
Content TextContent `json:"content"`
}

View File

@@ -37,7 +37,6 @@ const jsonrpcVersion = "2.0"
const protocolVersion20241105 = "2024-11-05"
const protocolVersion20250326 = "2025-03-26"
const protocolVersion20250618 = "2025-06-18"
const protocolVersion20251125 = "2025-11-25"
const serverName = "Toolbox"
var basicInputSchema = map[string]any{
@@ -486,23 +485,6 @@ func TestMcpEndpoint(t *testing.T) {
},
},
},
{
name: "version 2025-11-25",
protocol: protocolVersion20251125,
idHeader: false,
initWant: map[string]any{
"jsonrpc": "2.0",
"id": "mcp-initialize",
"result": map[string]any{
"protocolVersion": "2025-11-25",
"capabilities": map[string]any{
"tools": map[string]any{"listChanged": false},
"prompts": map[string]any{"listChanged": false},
},
"serverInfo": map[string]any{"name": serverName, "version": fakeVersionString},
},
},
},
}
for _, vtc := range versTestCases {
t.Run(vtc.name, func(t *testing.T) {
@@ -512,7 +494,8 @@ func TestMcpEndpoint(t *testing.T) {
if sessionId != "" {
header["Mcp-Session-Id"] = sessionId
}
if vtc.protocol != protocolVersion20241105 && vtc.protocol != protocolVersion20250326 {
if vtc.protocol == protocolVersion20250618 {
header["MCP-Protocol-Version"] = vtc.protocol
}

View File

@@ -304,14 +304,10 @@ func hostCheck(allowedHosts map[string]struct{}) func(http.Handler) http.Handler
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, hasWildcard := allowedHosts["*"]
hostname := r.Host
if host, _, err := net.SplitHostPort(r.Host); err == nil {
hostname = host
}
_, hostIsAllowed := allowedHosts[hostname]
_, hostIsAllowed := allowedHosts[r.Host]
if !hasWildcard && !hostIsAllowed {
// Return 403 Forbidden to block the attack
http.Error(w, "Invalid Host header", http.StatusForbidden)
// Return 400 Bad Request or 403 Forbidden to block the attack
http.Error(w, "Invalid Host header", http.StatusBadRequest)
return
}
next.ServeHTTP(w, r)
@@ -410,11 +406,7 @@ func NewServer(ctx context.Context, cfg ServerConfig) (*Server, error) {
}
allowedHostsMap := make(map[string]struct{}, len(cfg.AllowedHosts))
for _, h := range cfg.AllowedHosts {
hostname := h
if host, _, err := net.SplitHostPort(h); err == nil {
hostname = host
}
allowedHostsMap[hostname] = struct{}{}
allowedHostsMap[h] = struct{}{}
}
r.Use(hostCheck(allowedHostsMap))

View File

@@ -14,23 +14,20 @@
package cloudgda
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
geminidataanalytics "cloud.google.com/go/geminidataanalytics/apiv1beta"
"cloud.google.com/go/geminidataanalytics/apiv1beta/geminidataanalyticspb"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/util"
"go.opentelemetry.io/otel/trace"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
"google.golang.org/api/option"
)
const SourceKind string = "cloud-gemini-data-analytics"
const Endpoint string = "https://geminidataanalytics.googleapis.com"
// validate interface
var _ sources.SourceConfig = Config{}
@@ -67,29 +64,19 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
return nil, fmt.Errorf("error in User Agent retrieval: %s", err)
}
var client *http.Client
if r.UseClientOAuth {
client = &http.Client{
Transport: util.NewUserAgentRoundTripper(ua, http.DefaultTransport),
}
} else {
// Use Application Default Credentials
// Scope: "https://www.googleapis.com/auth/cloud-platform" is generally sufficient for GDA
creds, err := google.FindDefaultCredentials(ctx, "https://www.googleapis.com/auth/cloud-platform")
if err != nil {
return nil, fmt.Errorf("failed to find default credentials: %w", err)
}
baseClient := oauth2.NewClient(ctx, creds.TokenSource)
baseClient.Transport = util.NewUserAgentRoundTripper(ua, baseClient.Transport)
client = baseClient
}
s := &Source{
Config: r,
Client: client,
BaseURL: Endpoint,
userAgent: ua,
}
if !r.UseClientOAuth {
client, err := geminidataanalytics.NewDataChatClient(ctx, option.WithUserAgent(ua))
if err != nil {
return nil, fmt.Errorf("failed to create DataChatClient: %w", err)
}
s.Client = client
}
return s, nil
}
@@ -97,8 +84,7 @@ var _ sources.Source = &Source{}
type Source struct {
Config
Client *http.Client
BaseURL string
Client *geminidataanalytics.DataChatClient
userAgent string
}
@@ -114,63 +100,34 @@ func (s *Source) GetProjectID() string {
return s.ProjectID
}
func (s *Source) GetBaseURL() string {
return s.BaseURL
}
func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Client, error) {
if s.UseClientOAuth {
if accessToken == "" {
return nil, fmt.Errorf("client-side OAuth is enabled but no access token was provided")
}
token := &oauth2.Token{AccessToken: accessToken}
baseClient := oauth2.NewClient(ctx, oauth2.StaticTokenSource(token))
baseClient.Transport = util.NewUserAgentRoundTripper(s.userAgent, baseClient.Transport)
return baseClient, nil
}
return s.Client, nil
}
func (s *Source) UseClientAuthorization() bool {
return s.UseClientOAuth
}
func (s *Source) RunQuery(ctx context.Context, tokenStr string, bodyBytes []byte) (any, error) {
// The API endpoint itself always uses the "global" location.
apiLocation := "global"
apiParent := fmt.Sprintf("projects/%s/locations/%s", s.GetProjectID(), apiLocation)
apiURL := fmt.Sprintf("%s/v1beta/%s:queryData", s.GetBaseURL(), apiParent)
client, err := s.GetClient(ctx, tokenStr)
func (s *Source) RunQuery(ctx context.Context, tokenStr string, req *geminidataanalyticspb.QueryDataRequest) (*geminidataanalyticspb.QueryDataResponse, error) {
client, cleanup, err := s.GetClient(ctx, tokenStr)
if err != nil {
return nil, fmt.Errorf("failed to get HTTP client: %w", err)
return nil, err
}
defer cleanup()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(bodyBytes))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to execute request: %w", err)
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(respBody))
}
var result map[string]any
if err := json.Unmarshal(respBody, &result); err != nil {
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
}
return result, nil
return client.QueryData(ctx, req)
}
func (s *Source) GetClient(ctx context.Context, tokenStr string) (*geminidataanalytics.DataChatClient, func(), error) {
if s.UseClientOAuth {
if tokenStr == "" {
return nil, nil, fmt.Errorf("client-side OAuth is enabled but no access token was provided")
}
token := &oauth2.Token{AccessToken: tokenStr}
client, err := geminidataanalytics.NewDataChatClient(ctx,
option.WithUserAgent(s.userAgent),
option.WithTokenSource(oauth2.StaticTokenSource(token)),
)
if err != nil {
return nil, nil, fmt.Errorf("failed to create per-request DataChatClient: %w", err)
}
return client, func() { client.Close() }, nil
}
return s.Client, func() {}, nil
}

View File

@@ -181,11 +181,9 @@ func TestInitialize(t *testing.T) {
if gdaSrc.Client == nil && !tc.wantClientOAuth {
t.Fatal("expected non-nil HTTP client for ADC, got nil")
}
// When client OAuth is true, the source's client should be initialized with a base HTTP client
// that includes the user agent round tripper, but not the OAuth token. The token-aware
// client is created by GetClient.
if gdaSrc.Client == nil && tc.wantClientOAuth {
t.Fatal("expected non-nil HTTP client for client OAuth config, got nil")
// When client OAuth is true, the source's client should be nil.
if gdaSrc.Client != nil && tc.wantClientOAuth {
t.Fatal("expected nil HTTP client for client OAuth config, got non-nil")
}
// Test UseClientAuthorization method
@@ -195,15 +193,16 @@ func TestInitialize(t *testing.T) {
// Test GetClient with accessToken for client OAuth scenarios
if tc.wantClientOAuth {
client, err := gdaSrc.GetClient(ctx, "dummy-token")
client, cleanup, err := gdaSrc.GetClient(ctx, "dummy-token")
if err != nil {
t.Fatalf("GetClient with token failed: %v", err)
}
defer cleanup()
if client == nil {
t.Fatal("expected non-nil HTTP client from GetClient with token, got nil")
}
// Ensure passing empty token with UseClientOAuth enabled returns error
_, err = gdaSrc.GetClient(ctx, "")
_, _, err = gdaSrc.GetClient(ctx, "")
if err == nil || err.Error() != "client-side OAuth is enabled but no access token was provided" {
t.Errorf("expected 'client-side OAuth is enabled but no access token was provided' error, got: %v", err)
}

View File

@@ -198,7 +198,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -204,7 +204,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -209,7 +209,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -180,7 +180,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -184,7 +184,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -184,7 +184,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -174,7 +174,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -179,7 +179,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -179,7 +179,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -303,7 +303,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -175,7 +175,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -150,7 +150,3 @@ var _ tools.Tool = Tool{}
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -19,11 +19,13 @@ import (
"encoding/json"
"fmt"
"cloud.google.com/go/geminidataanalytics/apiv1beta/geminidataanalyticspb"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/protobuf/encoding/protojson"
)
const kind string = "cloud-gemini-data-analytics-query"
@@ -60,7 +62,49 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
GetProjectID() string
UseClientAuthorization() bool
RunQuery(context.Context, string, []byte) (any, error)
RunQuery(context.Context, string, *geminidataanalyticspb.QueryDataRequest) (*geminidataanalyticspb.QueryDataResponse, error)
}
// QueryDataContext wraps geminidataanalyticspb.QueryDataContext to support YAML decoding via protojson.
type QueryDataContext struct {
*geminidataanalyticspb.QueryDataContext
}
func (q *QueryDataContext) UnmarshalYAML(b []byte) error {
var raw map[string]any
if err := yaml.Unmarshal(b, &raw); err != nil {
return fmt.Errorf("failed to unmarshal context from yaml: %w", err)
}
jsonBytes, err := json.Marshal(raw)
if err != nil {
return fmt.Errorf("failed to marshal context map: %w", err)
}
q.QueryDataContext = &geminidataanalyticspb.QueryDataContext{}
if err := protojson.Unmarshal(jsonBytes, q.QueryDataContext); err != nil {
return fmt.Errorf("failed to unmarshal context to proto: %w", err)
}
return nil
}
// GenerationOptions wraps geminidataanalyticspb.GenerationOptions to support YAML decoding via protojson.
type GenerationOptions struct {
*geminidataanalyticspb.GenerationOptions
}
func (g *GenerationOptions) UnmarshalYAML(b []byte) error {
var raw map[string]any
if err := yaml.Unmarshal(b, &raw); err != nil {
return fmt.Errorf("failed to unmarshal generation options from yaml: %w", err)
}
jsonBytes, err := json.Marshal(raw)
if err != nil {
return fmt.Errorf("failed to marshal generation options map: %w", err)
}
g.GenerationOptions = &geminidataanalyticspb.GenerationOptions{}
if err := protojson.Unmarshal(jsonBytes, g.GenerationOptions); err != nil {
return fmt.Errorf("failed to unmarshal generation options to proto: %w", err)
}
return nil
}
type Config struct {
@@ -97,12 +141,14 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
}
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil)
return Tool{
t := Tool{
Config: cfg,
AllParams: allParameters,
manifest: tools.Manifest{Description: cfg.Description, Parameters: allParameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}, nil
}
return t, nil
}
// validate interface
@@ -145,18 +191,20 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
// The parent in the request payload uses the tool's configured location.
payloadParent := fmt.Sprintf("projects/%s/locations/%s", source.GetProjectID(), t.Location)
payload := &QueryDataRequest{
Parent: payloadParent,
Prompt: query,
Context: t.Context,
GenerationOptions: t.GenerationOptions,
req := &geminidataanalyticspb.QueryDataRequest{
Parent: payloadParent,
Prompt: query,
}
bodyBytes, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("failed to marshal request payload: %w", err)
if t.Context != nil {
req.Context = t.Context.QueryDataContext
}
return source.RunQuery(ctx, tokenStr, bodyBytes)
if t.GenerationOptions != nil {
req.GenerationOptions = t.GenerationOptions.GenerationOptions
}
return source.RunQuery(ctx, tokenStr, req)
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
@@ -190,7 +238,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(_ tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -16,19 +16,16 @@ package cloudgda_test
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"testing"
"cloud.google.com/go/geminidataanalytics/apiv1beta/geminidataanalyticspb"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/server/resources"
"github.com/googleapis/genai-toolbox/internal/sources"
cloudgdasrc "github.com/googleapis/genai-toolbox/internal/sources/cloudgda"
"github.com/googleapis/genai-toolbox/internal/testutils"
"github.com/googleapis/genai-toolbox/internal/tools"
cloudgdatool "github.com/googleapis/genai-toolbox/internal/tools/cloudgda"
@@ -74,23 +71,29 @@ func TestParseFromYaml(t *testing.T) {
Location: "us-central1",
AuthRequired: []string{},
Context: &cloudgdatool.QueryDataContext{
DatasourceReferences: &cloudgdatool.DatasourceReferences{
SpannerReference: &cloudgdatool.SpannerReference{
DatabaseReference: &cloudgdatool.SpannerDatabaseReference{
ProjectID: "cloud-db-nl2sql",
Region: "us-central1",
InstanceID: "evalbench",
DatabaseID: "financial",
Engine: cloudgdatool.SpannerEngineGoogleSQL,
},
AgentContextReference: &cloudgdatool.AgentContextReference{
ContextSetID: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates",
QueryDataContext: &geminidataanalyticspb.QueryDataContext{
DatasourceReferences: &geminidataanalyticspb.DatasourceReferences{
References: &geminidataanalyticspb.DatasourceReferences_SpannerReference{
SpannerReference: &geminidataanalyticspb.SpannerReference{
DatabaseReference: &geminidataanalyticspb.SpannerDatabaseReference{
ProjectId: "cloud-db-nl2sql",
Region: "us-central1",
InstanceId: "evalbench",
DatabaseId: "financial",
Engine: geminidataanalyticspb.SpannerDatabaseReference_GOOGLE_SQL,
},
AgentContextReference: &geminidataanalyticspb.AgentContextReference{
ContextSetId: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates",
},
},
},
},
},
},
GenerationOptions: &cloudgdatool.GenerationOptions{
GenerateQueryResult: true,
GenerationOptions: &geminidataanalyticspb.GenerationOptions{
GenerateQueryResult: true,
},
},
},
},
@@ -108,68 +111,63 @@ func TestParseFromYaml(t *testing.T) {
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got.Tools) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Tools)
if !cmp.Equal(tc.want, got.Tools, cmpopts.IgnoreUnexported(geminidataanalyticspb.QueryDataContext{}, geminidataanalyticspb.DatasourceReferences{}, geminidataanalyticspb.SpannerReference{}, geminidataanalyticspb.SpannerDatabaseReference{}, geminidataanalyticspb.AgentContextReference{}, geminidataanalyticspb.GenerationOptions{}, geminidataanalyticspb.DatasourceReferences_SpannerReference{})) {
t.Errorf("incorrect parse: want %v, got %v", tc.want, got.Tools)
}
})
}
}
// authRoundTripper is a mock http.RoundTripper that adds a dummy Authorization header.
type authRoundTripper struct {
Token string
Next http.RoundTripper
// fakeSource implements the compatibleSource interface for testing.
type fakeSource struct {
projectID string
useClientOAuth bool
expectedQuery string
expectedParent string
response *geminidataanalyticspb.QueryDataResponse
}
func (rt *authRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
newReq := *req
newReq.Header = make(http.Header)
for k, v := range req.Header {
newReq.Header[k] = v
}
newReq.Header.Set("Authorization", rt.Token)
if rt.Next == nil {
return http.DefaultTransport.RoundTrip(&newReq)
}
return rt.Next.RoundTrip(&newReq)
func (f *fakeSource) GetProjectID() string {
return f.projectID
}
type mockSource struct {
kind string
client *http.Client // Can be used to inject a specific client
baseURL string // BaseURL is needed to implement sources.Source.BaseURL
config cloudgdasrc.Config // to return from ToConfig
func (f *fakeSource) UseClientAuthorization() bool {
return f.useClientOAuth
}
func (m *mockSource) SourceKind() string { return m.kind }
func (m *mockSource) ToConfig() sources.SourceConfig { return m.config }
func (m *mockSource) GetClient(ctx context.Context, token string) (*http.Client, error) {
if m.client != nil {
return m.client, nil
}
// Default client for testing if not explicitly set
transport := &http.Transport{}
authTransport := &authRoundTripper{
Token: "Bearer test-access-token", // Dummy token
Next: transport,
}
return &http.Client{Transport: authTransport}, nil
func (f *fakeSource) SourceKind() string {
return "fake-gda-source"
}
func (m *mockSource) UseClientAuthorization() bool { return false }
func (m *mockSource) Initialize(ctx context.Context, tracer interface{}) (sources.Source, error) {
return m, nil
func (f *fakeSource) ToConfig() sources.SourceConfig {
return nil
}
func (f *fakeSource) Initialize(ctx context.Context, tracer interface{}) (sources.Source, error) {
return f, nil
}
func (f *fakeSource) RunQuery(ctx context.Context, token string, req *geminidataanalyticspb.QueryDataRequest) (*geminidataanalyticspb.QueryDataResponse, error) {
if req.Prompt != f.expectedQuery {
return nil, fmt.Errorf("unexpected query: got %q, want %q", req.Prompt, f.expectedQuery)
}
if req.Parent != f.expectedParent {
return nil, fmt.Errorf("unexpected parent: got %q, want %q", req.Parent, f.expectedParent)
}
// Basic validation of context/options could be added here if needed,
// but the test case mainly checks if they are passed correctly via successful invocation.
return f.response, nil
}
func (m *mockSource) BaseURL() string { return m.baseURL }
func TestInitialize(t *testing.T) {
t.Parallel()
// Minimal fake source
fake := &fakeSource{projectID: "test-project"}
srcs := map[string]sources.Source{
"gda-api-source": &cloudgdasrc.Source{
Config: cloudgdasrc.Config{Name: "gda-api-source", Kind: cloudgdasrc.SourceKind, ProjectID: "test-project"},
Client: &http.Client{},
BaseURL: cloudgdasrc.Endpoint,
},
"gda-api-source": fake,
}
tcs := []struct {
@@ -188,9 +186,6 @@ func TestInitialize(t *testing.T) {
},
}
// Add an incompatible source for testing
srcs["incompatible-source"] = &mockSource{kind: "another-kind"}
for _, tc := range tcs {
tc := tc
t.Run(tc.desc, func(t *testing.T) {
@@ -207,92 +202,27 @@ func TestInitialize(t *testing.T) {
func TestInvoke(t *testing.T) {
t.Parallel()
// Mock the HTTP client and server for Invoke testing
serverMux := http.NewServeMux()
// Update expected URL path to include the location "us-central1"
serverMux.HandleFunc("/v1beta/projects/test-project/locations/global:queryData", func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("expected POST method, got %s", r.Method)
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
if r.Header.Get("Content-Type") != "application/json" {
t.Errorf("expected Content-Type application/json, got %s", r.Header.Get("Content-Type"))
http.Error(w, "Bad request", http.StatusBadRequest)
return
}
// Read and unmarshal the request body
bodyBytes, err := io.ReadAll(r.Body)
if err != nil {
t.Errorf("failed to read request body: %v", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
var reqPayload cloudgdatool.QueryDataRequest
if err := json.Unmarshal(bodyBytes, &reqPayload); err != nil {
t.Errorf("failed to unmarshal request payload: %v", err)
http.Error(w, "Bad request", http.StatusBadRequest)
return
}
projectID := "test-project"
location := "us-central1"
query := "How many accounts who have region in Prague are eligible for loans?"
expectedParent := fmt.Sprintf("projects/%s/locations/%s", projectID, location)
// Verify expected fields
if r.Header.Get("Authorization") == "" {
t.Errorf("expected Authorization header, got empty")
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
if reqPayload.Prompt != "How many accounts who have region in Prague are eligible for loans?" {
t.Errorf("unexpected prompt: %s", reqPayload.Prompt)
}
// Verify payload's parent uses the tool's configured location
if reqPayload.Parent != fmt.Sprintf("projects/%s/locations/%s", "test-project", "us-central1") {
t.Errorf("unexpected payload parent: got %q, want %q", reqPayload.Parent, fmt.Sprintf("projects/%s/locations/%s", "test-project", "us-central1"))
}
// Verify context from config
if reqPayload.Context == nil ||
reqPayload.Context.DatasourceReferences == nil ||
reqPayload.Context.DatasourceReferences.SpannerReference == nil ||
reqPayload.Context.DatasourceReferences.SpannerReference.DatabaseReference == nil ||
reqPayload.Context.DatasourceReferences.SpannerReference.DatabaseReference.ProjectID != "cloud-db-nl2sql" {
t.Errorf("unexpected context: %v", reqPayload.Context)
}
// Verify generation options from config
if reqPayload.GenerationOptions == nil || !reqPayload.GenerationOptions.GenerateQueryResult {
t.Errorf("unexpected generation options: %v", reqPayload.GenerationOptions)
}
// Simulate a successful response
resp := map[string]any{
"queryResult": "SELECT count(*) FROM accounts WHERE region = 'Prague' AND eligible_for_loans = true;",
"naturalLanguageAnswer": "There are 5 accounts in Prague eligible for loans.",
}
_ = json.NewEncoder(w).Encode(resp)
})
mockServer := httptest.NewServer(serverMux)
defer mockServer.Close()
ctx := testutils.ContextWithUserAgent(context.Background(), "test-user-agent")
// Create an authenticated client that uses the mock server
authTransport := &authRoundTripper{
Token: "Bearer test-access-token",
Next: mockServer.Client().Transport,
// Prepare expected response
expectedResp := &geminidataanalyticspb.QueryDataResponse{
GeneratedQuery: "SELECT count(*) FROM accounts WHERE region = 'Prague' AND eligible_for_loans = true;",
NaturalLanguageAnswer: "There are 5 accounts in Prague eligible for loans.",
}
authClient := &http.Client{Transport: authTransport}
// Create a real cloudgdasrc.Source but inject the authenticated client
mockGdaSource := &cloudgdasrc.Source{
Config: cloudgdasrc.Config{Name: "mock-gda-source", Kind: cloudgdasrc.SourceKind, ProjectID: "test-project"},
Client: authClient,
BaseURL: mockServer.URL,
fake := &fakeSource{
projectID: projectID,
expectedQuery: query,
expectedParent: expectedParent,
response: expectedResp,
}
srcs := map[string]sources.Source{
"mock-gda-source": mockGdaSource,
"mock-gda-source": fake,
}
// Initialize the tool config with context
@@ -301,25 +231,31 @@ func TestInvoke(t *testing.T) {
Kind: "cloud-gemini-data-analytics-query",
Source: "mock-gda-source",
Description: "Query Gemini Data Analytics",
Location: "us-central1", // Set location for the test
Location: location,
Context: &cloudgdatool.QueryDataContext{
DatasourceReferences: &cloudgdatool.DatasourceReferences{
SpannerReference: &cloudgdatool.SpannerReference{
DatabaseReference: &cloudgdatool.SpannerDatabaseReference{
ProjectID: "cloud-db-nl2sql",
Region: "us-central1",
InstanceID: "evalbench",
DatabaseID: "financial",
Engine: cloudgdatool.SpannerEngineGoogleSQL,
},
AgentContextReference: &cloudgdatool.AgentContextReference{
ContextSetID: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates",
QueryDataContext: &geminidataanalyticspb.QueryDataContext{
DatasourceReferences: &geminidataanalyticspb.DatasourceReferences{
References: &geminidataanalyticspb.DatasourceReferences_SpannerReference{
SpannerReference: &geminidataanalyticspb.SpannerReference{
DatabaseReference: &geminidataanalyticspb.SpannerDatabaseReference{
ProjectId: "cloud-db-nl2sql",
Region: "us-central1",
InstanceId: "evalbench",
DatabaseId: "financial",
Engine: geminidataanalyticspb.SpannerDatabaseReference_GOOGLE_SQL,
},
AgentContextReference: &geminidataanalyticspb.AgentContextReference{
ContextSetId: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates",
},
},
},
},
},
},
GenerationOptions: &cloudgdatool.GenerationOptions{
GenerateQueryResult: true,
GenerationOptions: &geminidataanalyticspb.GenerationOptions{
GenerateQueryResult: true,
},
},
}
@@ -330,24 +266,25 @@ func TestInvoke(t *testing.T) {
// Prepare parameters for invocation - ONLY query
params := parameters.ParamValues{
{Name: "query", Value: "How many accounts who have region in Prague are eligible for loans?"},
{Name: "query", Value: query},
}
resourceMgr := resources.NewResourceManager(srcs, nil, nil, nil, nil, nil, nil)
ctx := testutils.ContextWithUserAgent(context.Background(), "test-user-agent")
// Invoke the tool
result, err := tool.Invoke(ctx, resourceMgr, params, "") // No accessToken needed for ADC client
result, err := tool.Invoke(ctx, resourceMgr, params, "")
if err != nil {
t.Fatalf("tool invocation failed: %v", err)
}
// Validate the result
expectedResult := map[string]any{
"queryResult": "SELECT count(*) FROM accounts WHERE region = 'Prague' AND eligible_for_loans = true;",
"naturalLanguageAnswer": "There are 5 accounts in Prague eligible for loans.",
gotResp, ok := result.(*geminidataanalyticspb.QueryDataResponse)
if !ok {
t.Fatalf("expected result type *geminidataanalyticspb.QueryDataResponse, got %T", result)
}
if !cmp.Equal(expectedResult, result) {
t.Errorf("unexpected result: got %v, want %v", result, expectedResult)
if diff := cmp.Diff(expectedResp, gotResp, cmpopts.IgnoreUnexported(geminidataanalyticspb.QueryDataResponse{})); diff != "" {
t.Errorf("unexpected result mismatch (-want +got):\n%s", diff)
}
}

View File

@@ -1,116 +0,0 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cloudgda
// See full service definition at: https://github.com/googleapis/googleapis/blob/master/google/cloud/geminidataanalytics/v1beta/data_chat_service.proto
// QueryDataRequest represents the JSON body for the queryData API
type QueryDataRequest struct {
Parent string `json:"parent"`
Prompt string `json:"prompt"`
Context *QueryDataContext `json:"context,omitempty"`
GenerationOptions *GenerationOptions `json:"generationOptions,omitempty"`
}
// QueryDataContext reflects the proto definition for the query context.
type QueryDataContext struct {
DatasourceReferences *DatasourceReferences `json:"datasourceReferences,omitempty" yaml:"datasourceReferences,omitempty"`
}
// DatasourceReferences reflects the proto definition for datasource references, using a oneof.
type DatasourceReferences struct {
SpannerReference *SpannerReference `json:"spannerReference,omitempty" yaml:"spannerReference,omitempty"`
AlloyDBReference *AlloyDBReference `json:"alloydb,omitempty" yaml:"alloydb,omitempty"`
CloudSQLReference *CloudSQLReference `json:"cloudSqlReference,omitempty" yaml:"cloudSqlReference,omitempty"`
}
// SpannerReference reflects the proto definition for Spanner database reference.
type SpannerReference struct {
DatabaseReference *SpannerDatabaseReference `json:"databaseReference,omitempty" yaml:"databaseReference,omitempty"`
AgentContextReference *AgentContextReference `json:"agentContextReference,omitempty" yaml:"agentContextReference,omitempty"`
}
// SpannerDatabaseReference reflects the proto definition for a Spanner database reference.
type SpannerDatabaseReference struct {
Engine SpannerEngine `json:"engine,omitempty" yaml:"engine,omitempty"`
ProjectID string `json:"projectId,omitempty" yaml:"projectId,omitempty"`
Region string `json:"region,omitempty" yaml:"region,omitempty"`
InstanceID string `json:"instanceId,omitempty" yaml:"instanceId,omitempty"`
DatabaseID string `json:"databaseId,omitempty" yaml:"databaseId,omitempty"`
TableIDs []string `json:"tableIds,omitempty" yaml:"tableIds,omitempty"`
}
// SpannerEngine represents the engine of the Spanner instance.
type SpannerEngine string
const (
SpannerEngineUnspecified SpannerEngine = "ENGINE_UNSPECIFIED"
SpannerEngineGoogleSQL SpannerEngine = "GOOGLE_SQL"
SpannerEnginePostgreSQL SpannerEngine = "POSTGRESQL"
)
// AlloyDBReference reflects the proto definition for an AlloyDB database reference.
type AlloyDBReference struct {
DatabaseReference *AlloyDBDatabaseReference `json:"databaseReference,omitempty" yaml:"databaseReference,omitempty"`
AgentContextReference *AgentContextReference `json:"agentContextReference,omitempty" yaml:"agentContextReference,omitempty"`
}
// AlloyDBDatabaseReference reflects the proto definition for an AlloyDB database reference.
type AlloyDBDatabaseReference struct {
ProjectID string `json:"projectId,omitempty" yaml:"projectId,omitempty"`
Region string `json:"region,omitempty" yaml:"region,omitempty"`
ClusterID string `json:"clusterId,omitempty" yaml:"clusterId,omitempty"`
InstanceID string `json:"instanceId,omitempty" yaml:"instanceId,omitempty"`
DatabaseID string `json:"databaseId,omitempty" yaml:"databaseId,omitempty"`
TableIDs []string `json:"tableIds,omitempty" yaml:"tableIds,omitempty"`
}
// CloudSQLReference reflects the proto definition for a Cloud SQL database reference.
type CloudSQLReference struct {
DatabaseReference *CloudSQLDatabaseReference `json:"databaseReference,omitempty" yaml:"databaseReference,omitempty"`
AgentContextReference *AgentContextReference `json:"agentContextReference,omitempty" yaml:"agentContextReference,omitempty"`
}
// CloudSQLDatabaseReference reflects the proto definition for a Cloud SQL database reference.
type CloudSQLDatabaseReference struct {
Engine CloudSQLEngine `json:"engine,omitempty" yaml:"engine,omitempty"`
ProjectID string `json:"projectId,omitempty" yaml:"projectId,omitempty"`
Region string `json:"region,omitempty" yaml:"region,omitempty"`
InstanceID string `json:"instanceId,omitempty" yaml:"instanceId,omitempty"`
DatabaseID string `json:"databaseId,omitempty" yaml:"databaseId,omitempty"`
TableIDs []string `json:"tableIds,omitempty" yaml:"tableIds,omitempty"`
}
// CloudSQLEngine represents the engine of the Cloud SQL instance.
type CloudSQLEngine string
const (
CloudSQLEngineUnspecified CloudSQLEngine = "ENGINE_UNSPECIFIED"
CloudSQLEnginePostgreSQL CloudSQLEngine = "POSTGRESQL"
CloudSQLEngineMySQL CloudSQLEngine = "MYSQL"
)
// AgentContextReference reflects the proto definition for agent context.
type AgentContextReference struct {
ContextSetID string `json:"contextSetId,omitempty" yaml:"contextSetId,omitempty"`
}
// GenerationOptions reflects the proto definition for generation options.
type GenerationOptions struct {
GenerateQueryResult bool `json:"generateQueryResult" yaml:"generateQueryResult"`
GenerateNaturalLanguageAnswer bool `json:"generateNaturalLanguageAnswer" yaml:"generateNaturalLanguageAnswer"`
GenerateExplanation bool `json:"generateExplanation" yaml:"generateExplanation"`
GenerateDisambiguationQuestion bool `json:"generateDisambiguationQuestion" yaml:"generateDisambiguationQuestion"`
}

View File

@@ -142,7 +142,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -193,7 +193,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -266,7 +266,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -133,7 +133,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -154,7 +154,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -154,7 +154,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -168,7 +168,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -154,7 +154,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -154,7 +154,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -133,7 +133,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -181,7 +181,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -207,7 +207,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -192,7 +192,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -176,7 +176,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -187,7 +187,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -178,7 +178,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -175,7 +175,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -180,7 +180,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -171,7 +171,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -170,7 +170,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -165,7 +165,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -300,7 +300,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -204,7 +204,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -206,7 +206,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -142,7 +142,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -130,7 +130,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -153,7 +153,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -138,7 +138,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -157,7 +157,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -287,7 +287,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -139,7 +139,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -151,7 +151,3 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -147,7 +147,3 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -152,7 +152,3 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -206,7 +206,3 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -169,7 +169,3 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -150,7 +150,3 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.PayloadParams
}

View File

@@ -148,7 +148,3 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.PayloadParams
}

View File

@@ -158,7 +158,3 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -159,7 +159,3 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -141,7 +141,3 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -399,7 +399,3 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -157,7 +157,3 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -141,7 +141,3 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(_ tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -166,7 +166,3 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(_ tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -234,7 +234,3 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.allParams
}

View File

@@ -180,7 +180,3 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.allParams
}

View File

@@ -303,7 +303,3 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -170,7 +170,3 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.allParams
}

View File

@@ -145,7 +145,3 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -153,7 +153,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.allParams
}

View File

@@ -165,7 +165,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.allParams
}

View File

@@ -166,7 +166,3 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.allParams
}

View File

@@ -20,7 +20,7 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.comcom/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"github.com/jackc/pgx/v5/pgxpool"
@@ -85,7 +85,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
AuthRequired: cfg.AuthRequired,
},
mcpManifest: mcpManifest,
Parameters: params,
}
return t, nil
}
@@ -97,7 +96,6 @@ type Tool struct {
Config
manifest tools.Manifest
mcpManifest tools.McpManifest
Parameters parameters.Parameters
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
@@ -109,11 +107,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
return parameters.ParseParams(t.Parameters, data, claims)
return parameters.ParamValues{}, nil
}
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
return parameters.EmbedParams(ctx, t.Parameters, paramValues, embeddingModelsMap, nil)
return parameters.ParamValues{}, nil
}
func (t Tool) Manifest() tools.Manifest {
@@ -139,7 +137,3 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -235,7 +235,3 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.allParams
}

View File

@@ -164,7 +164,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.allParams
}

View File

@@ -161,7 +161,3 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.allParams
}

View File

@@ -175,7 +175,3 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.allParams
}

View File

@@ -164,7 +164,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.allParams
}

View File

@@ -186,7 +186,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.allParams
}

View File

@@ -198,7 +198,3 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.allParams
}

View File

@@ -212,7 +212,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.allParams
}

View File

@@ -173,7 +173,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.allParams
}

View File

@@ -126,7 +126,3 @@ func (t *Tool) ToConfig() tools.ToolConfig {
func (t *Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t *Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -147,7 +147,3 @@ func (t *Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t *Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -146,7 +146,3 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -151,7 +151,3 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -143,10 +143,6 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
return false, nil
}
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.sourceProvider) (string, error) {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -174,7 +174,3 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -143,7 +143,3 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -163,10 +163,6 @@ func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string,
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}
// GoogleSQL statement for listing graphs
const googleSQLStatement = `
WITH FilterGraphNames AS (

View File

@@ -189,10 +189,6 @@ func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string,
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}
// PostgreSQL statement for listing tables
const postgresqlStatement = `
WITH table_info_cte AS (

View File

@@ -185,7 +185,3 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

View File

@@ -141,7 +141,3 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.Parameters
}

View File

@@ -145,7 +145,3 @@ func (t Tool) ToConfig() tools.ToolConfig {
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
func (t Tool) GetParameters() parameters.Parameters {
return t.AllParams
}

Some files were not shown because too many files have changed in this diff Show More