mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-02-05 04:35:14 -05:00
Compare commits
19 Commits
spanner-cr
...
err-api
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9ee0491736 | ||
|
|
c733f03717 | ||
|
|
ca082d1566 | ||
|
|
d543381f38 | ||
|
|
f8ea21b428 | ||
|
|
15d2dafdde | ||
|
|
60b768c8ba | ||
|
|
e73768c4db | ||
|
|
1f9cd1b134 | ||
|
|
fc5d3ef805 | ||
|
|
da2c103234 | ||
|
|
0c5285c5c8 | ||
|
|
ac544d0878 | ||
|
|
54f9a3d312 | ||
|
|
62d96a662d | ||
|
|
46244458c4 | ||
|
|
b6fa798610 | ||
|
|
bb58baff70 | ||
|
|
32b2c9366d |
@@ -376,26 +376,6 @@ steps:
|
|||||||
spanner \
|
spanner \
|
||||||
spanner || echo "Integration tests failed." # ignore test failures
|
spanner || echo "Integration tests failed." # ignore test failures
|
||||||
|
|
||||||
- id: "spanner-admin"
|
|
||||||
name: golang:1
|
|
||||||
waitFor: ["compile-test-binary"]
|
|
||||||
entrypoint: /bin/bash
|
|
||||||
env:
|
|
||||||
- "GOPATH=/gopath"
|
|
||||||
- "SPANNER_PROJECT=$PROJECT_ID"
|
|
||||||
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
|
|
||||||
secretEnv: ["CLIENT_ID"]
|
|
||||||
volumes:
|
|
||||||
- name: "go"
|
|
||||||
path: "/gopath"
|
|
||||||
args:
|
|
||||||
- -c
|
|
||||||
- |
|
|
||||||
.ci/test_with_coverage.sh \
|
|
||||||
"Spanner Admin" \
|
|
||||||
spanneradmin \
|
|
||||||
spanneradmin || echo "Integration tests failed."
|
|
||||||
|
|
||||||
- id: "neo4j"
|
- id: "neo4j"
|
||||||
name: golang:1
|
name: golang:1
|
||||||
waitFor: ["compile-test-binary"]
|
waitFor: ["compile-test-binary"]
|
||||||
|
|||||||
@@ -228,7 +228,6 @@ import (
|
|||||||
_ "github.com/googleapis/genai-toolbox/internal/tools/spanner/spannerlistgraphs"
|
_ "github.com/googleapis/genai-toolbox/internal/tools/spanner/spannerlistgraphs"
|
||||||
_ "github.com/googleapis/genai-toolbox/internal/tools/spanner/spannerlisttables"
|
_ "github.com/googleapis/genai-toolbox/internal/tools/spanner/spannerlisttables"
|
||||||
_ "github.com/googleapis/genai-toolbox/internal/tools/spanner/spannersql"
|
_ "github.com/googleapis/genai-toolbox/internal/tools/spanner/spannersql"
|
||||||
_ "github.com/googleapis/genai-toolbox/internal/tools/spanneradmin/spannercreateinstance"
|
|
||||||
_ "github.com/googleapis/genai-toolbox/internal/tools/sqlite/sqliteexecutesql"
|
_ "github.com/googleapis/genai-toolbox/internal/tools/sqlite/sqliteexecutesql"
|
||||||
_ "github.com/googleapis/genai-toolbox/internal/tools/sqlite/sqlitesql"
|
_ "github.com/googleapis/genai-toolbox/internal/tools/sqlite/sqlitesql"
|
||||||
_ "github.com/googleapis/genai-toolbox/internal/tools/tidb/tidbexecutesql"
|
_ "github.com/googleapis/genai-toolbox/internal/tools/tidb/tidbexecutesql"
|
||||||
@@ -276,7 +275,6 @@ import (
|
|||||||
_ "github.com/googleapis/genai-toolbox/internal/sources/singlestore"
|
_ "github.com/googleapis/genai-toolbox/internal/sources/singlestore"
|
||||||
_ "github.com/googleapis/genai-toolbox/internal/sources/snowflake"
|
_ "github.com/googleapis/genai-toolbox/internal/sources/snowflake"
|
||||||
_ "github.com/googleapis/genai-toolbox/internal/sources/spanner"
|
_ "github.com/googleapis/genai-toolbox/internal/sources/spanner"
|
||||||
_ "github.com/googleapis/genai-toolbox/internal/sources/spanneradmin"
|
|
||||||
_ "github.com/googleapis/genai-toolbox/internal/sources/sqlite"
|
_ "github.com/googleapis/genai-toolbox/internal/sources/sqlite"
|
||||||
_ "github.com/googleapis/genai-toolbox/internal/sources/tidb"
|
_ "github.com/googleapis/genai-toolbox/internal/sources/tidb"
|
||||||
_ "github.com/googleapis/genai-toolbox/internal/sources/trino"
|
_ "github.com/googleapis/genai-toolbox/internal/sources/trino"
|
||||||
|
|||||||
@@ -1,59 +0,0 @@
|
|||||||
# Cloud Spanner Admin MCP Server
|
|
||||||
|
|
||||||
The Cloud Spanner Admin Model Context Protocol (MCP) Server gives AI-powered development tools the ability to manage your Google Cloud Spanner infrastructure. It supports creating instances.
|
|
||||||
|
|
||||||
## Features
|
|
||||||
|
|
||||||
An editor configured to use the Cloud Spanner Admin MCP server can use its AI capabilities to help you:
|
|
||||||
|
|
||||||
- **Provision & Manage Infrastructure** - Create Cloud Spanner instances
|
|
||||||
|
|
||||||
## Prerequisites
|
|
||||||
|
|
||||||
* [Node.js](https://nodejs.org/) installed.
|
|
||||||
* A Google Cloud project with the **Cloud Spanner Admin API** enabled.
|
|
||||||
* Ensure [Application Default Credentials](https://cloud.google.com/docs/authentication/gcloud) are available in your environment.
|
|
||||||
* IAM Permissions:
|
|
||||||
* Cloud Spanner Admin (`roles/spanner.admin`)
|
|
||||||
|
|
||||||
## Install & Configuration
|
|
||||||
|
|
||||||
In the Antigravity MCP Store, click the "Install" button.
|
|
||||||
|
|
||||||
You'll now be able to see all enabled tools in the "Tools" tab.
|
|
||||||
|
|
||||||
> [!NOTE]
|
|
||||||
> If you encounter issues with Windows Defender blocking the execution, you may need to configure an allowlist. See [Configure exclusions for Microsoft Defender Antivirus](https://learn.microsoft.com/en-us/microsoft-365/security/defender-endpoint/configure-exclusions-microsoft-defender-antivirus?view=o365-worldwide) for more details.
|
|
||||||
|
|
||||||
## Usage
|
|
||||||
|
|
||||||
Once configured, the MCP server will automatically provide Cloud Spanner Admin capabilities to your AI assistant. You can:
|
|
||||||
|
|
||||||
* "Create a new Spanner instance named 'my-spanner-instance' in the 'my-gcp-project' project with config 'regional-us-central1', edition 'ENTERPRISE', and 1 node."
|
|
||||||
|
|
||||||
## Server Capabilities
|
|
||||||
|
|
||||||
The Cloud Spanner Admin MCP server provides the following tools:
|
|
||||||
|
|
||||||
| Tool Name | Description |
|
|
||||||
|:------------------|:---------------------------------|
|
|
||||||
| `create_instance` | Create a Cloud Spanner instance. |
|
|
||||||
|
|
||||||
## Custom MCP Server Configuration
|
|
||||||
|
|
||||||
Add the following configuration to your MCP client (e.g., `settings.json` for Gemini CLI, `mcp_config.json` for Antigravity):
|
|
||||||
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"mcpServers": {
|
|
||||||
"spanner-admin": {
|
|
||||||
"command": "npx",
|
|
||||||
"args": ["-y", "@toolbox-sdk/server", "--prebuilt", "spanner-admin", "--stdio"]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## Documentation
|
|
||||||
|
|
||||||
For more information, visit the [Cloud Spanner Admin API documentation](https://cloud.google.com/spanner/docs/reference/rpc/google.spanner.admin.instance.v1).
|
|
||||||
@@ -1,43 +0,0 @@
|
|||||||
---
|
|
||||||
title: Spanner Admin
|
|
||||||
type: docs
|
|
||||||
weight: 1
|
|
||||||
description: "A \"spanner-admin\" source provides a client for the Cloud Spanner Admin API.\n"
|
|
||||||
alias: [/resources/sources/spanner-admin]
|
|
||||||
---
|
|
||||||
|
|
||||||
## About
|
|
||||||
|
|
||||||
The `spanner-admin` source provides a client to interact with the [Google
|
|
||||||
Cloud Spanner Admin API](https://cloud.google.com/spanner/docs/reference/rpc/google.spanner.admin.instance.v1). This
|
|
||||||
allows tools to perform administrative tasks on Spanner instances, such as
|
|
||||||
creating instances.
|
|
||||||
|
|
||||||
Authentication can be handled in two ways:
|
|
||||||
|
|
||||||
1. **Application Default Credentials (ADC):** By default, the source uses ADC
|
|
||||||
to authenticate with the API.
|
|
||||||
2. **Client-side OAuth:** If `useClientOAuth` is set to `true`, the source will
|
|
||||||
expect an OAuth 2.0 access token to be provided by the client (e.g., a web
|
|
||||||
browser) for each request.
|
|
||||||
|
|
||||||
## Example
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
kind: sources
|
|
||||||
name: my-spanner-admin
|
|
||||||
type: spanner-admin
|
|
||||||
---
|
|
||||||
kind: sources
|
|
||||||
name: my-oauth-spanner-admin
|
|
||||||
type: spanner-admin
|
|
||||||
useClientOAuth: true
|
|
||||||
```
|
|
||||||
|
|
||||||
## Reference
|
|
||||||
|
|
||||||
| **field** | **type** | **required** | **description** |
|
|
||||||
| -------------- | :------: | :----------: | ---------------------------------------------------------------------------------------------------------------------------------------------- |
|
|
||||||
| type | string | true | Must be "spanner-admin". |
|
|
||||||
| defaultProject | string | false | The Google Cloud project ID to use for Spanner infrastructure tools. |
|
|
||||||
| useClientOAuth | boolean | false | If true, the source will use client-side OAuth for authorization. Otherwise, it will use Application Default Credentials. Defaults to `false`. |
|
|
||||||
@@ -1,52 +0,0 @@
|
|||||||
---
|
|
||||||
title: spanner-create-instance
|
|
||||||
type: docs
|
|
||||||
weight: 2
|
|
||||||
description: "Create a Cloud Spanner instance."
|
|
||||||
---
|
|
||||||
|
|
||||||
The `spanner-create-instance` tool creates a new Cloud Spanner instance in a
|
|
||||||
specified Google Cloud project.
|
|
||||||
|
|
||||||
{{< notice info >}}
|
|
||||||
This tool uses the `spanner-admin` source.
|
|
||||||
{{< /notice >}}
|
|
||||||
|
|
||||||
## Configuration
|
|
||||||
|
|
||||||
Here is an example of how to configure the `spanner-create-instance` tool in
|
|
||||||
your `tools.yaml` file:
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
kind: sources
|
|
||||||
name: my-spanner-admin-source
|
|
||||||
type: spanner-admin
|
|
||||||
---
|
|
||||||
kind: tools
|
|
||||||
name: create_my_spanner_instance
|
|
||||||
type: spanner-create-instance
|
|
||||||
source: my-spanner-admin-source
|
|
||||||
description: "Creates a Spanner instance."
|
|
||||||
```
|
|
||||||
|
|
||||||
## Parameters
|
|
||||||
|
|
||||||
The `spanner-create-instance` tool has the following parameters:
|
|
||||||
|
|
||||||
| **field** | **type** | **required** | **description** |
|
|
||||||
| --------------- | :------: | :----------: | ------------------------------------------------------------------------------------ |
|
|
||||||
| project | string | true | The Google Cloud project ID. |
|
|
||||||
| instanceId | string | true | The ID of the instance to create. |
|
|
||||||
| displayName | string | true | The display name of the instance. |
|
|
||||||
| config | string | true | The instance configuration (e.g., `regional-us-central1`). |
|
|
||||||
| nodeCount | integer | true | The number of nodes. Mutually exclusive with `processingUnits` (one must be 0). |
|
|
||||||
| processingUnits | integer | true | The number of processing units. Mutually exclusive with `nodeCount` (one must be 0). |
|
|
||||||
| edition | string | false | The edition of the instance (`STANDARD`, `ENTERPRISE`, `ENTERPRISE_PLUS`). |
|
|
||||||
|
|
||||||
## Reference
|
|
||||||
|
|
||||||
| **field** | **type** | **required** | **description** |
|
|
||||||
| ----------- | :------: | :----------: | ------------------------------------------------------------ |
|
|
||||||
| type | string | true | Must be `spanner-create-instance`. |
|
|
||||||
| source | string | true | The name of the `spanner-admin` source to use for this tool. |
|
|
||||||
| description | string | false | A description of the tool that is passed to the agent. |
|
|
||||||
@@ -51,7 +51,6 @@ var expectedToolSources = []string{
|
|||||||
"serverless-spark",
|
"serverless-spark",
|
||||||
"singlestore",
|
"singlestore",
|
||||||
"snowflake",
|
"snowflake",
|
||||||
"spanner-admin",
|
|
||||||
"spanner-postgres",
|
"spanner-postgres",
|
||||||
"spanner",
|
"spanner",
|
||||||
"sqlite",
|
"sqlite",
|
||||||
|
|||||||
@@ -1,27 +0,0 @@
|
|||||||
# Copyright 2026 Google LLC
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
sources:
|
|
||||||
spanner-admin-source:
|
|
||||||
kind: spanner-admin
|
|
||||||
defaultProject: ${SPANNER_PROJECT:}
|
|
||||||
|
|
||||||
tools:
|
|
||||||
create_instance:
|
|
||||||
kind: spanner-create-instance
|
|
||||||
source: spanner-admin-source
|
|
||||||
|
|
||||||
toolsets:
|
|
||||||
spanner_admin_tools:
|
|
||||||
- create_instance
|
|
||||||
@@ -19,7 +19,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
"github.com/go-chi/chi/v5/middleware"
|
"github.com/go-chi/chi/v5/middleware"
|
||||||
@@ -234,10 +233,11 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
params, err := parameters.ParseParams(tool.GetParameters(), data, claimsFromAuth)
|
params, err := parameters.ParseParams(tool.GetParameters(), data, claimsFromAuth)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// If auth error, return 401
|
// If auth error, return 401 or 403
|
||||||
if errors.Is(err, util.ErrUnauthorized) {
|
var clientServerErr *util.ClientServerError
|
||||||
|
if errors.As(err, &clientServerErr) && (clientServerErr.Code == http.StatusUnauthorized || clientServerErr.Code == http.StatusForbidden) {
|
||||||
s.logger.DebugContext(ctx, fmt.Sprintf("error parsing authenticated parameters from ID token: %s", err))
|
s.logger.DebugContext(ctx, fmt.Sprintf("error parsing authenticated parameters from ID token: %s", err))
|
||||||
_ = render.Render(w, r, newErrResponse(err, http.StatusUnauthorized))
|
_ = render.Render(w, r, newErrResponse(err, clientServerErr.Code))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = fmt.Errorf("provided parameters were invalid: %w", err)
|
err = fmt.Errorf("provided parameters were invalid: %w", err)
|
||||||
@@ -259,35 +259,50 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
// Determine what error to return to the users.
|
// Determine what error to return to the users.
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errStr := err.Error()
|
var tbErr util.ToolboxError
|
||||||
var statusCode int
|
|
||||||
|
|
||||||
// Upstream API auth error propagation
|
if errors.As(err, &tbErr) {
|
||||||
switch {
|
switch tbErr.Category() {
|
||||||
case strings.Contains(errStr, "Error 401"):
|
case util.CategoryAgent:
|
||||||
statusCode = http.StatusUnauthorized
|
// Agent Errors -> 200 OK
|
||||||
case strings.Contains(errStr, "Error 403"):
|
s.logger.DebugContext(ctx, fmt.Sprintf("Tool invocation agent error: %v", err))
|
||||||
statusCode = http.StatusForbidden
|
_ = render.Render(w, r, newErrResponse(err, http.StatusOK))
|
||||||
|
return
|
||||||
|
|
||||||
|
case util.CategoryServer:
|
||||||
|
// Server Errors -> Check the specific code inside
|
||||||
|
var clientServerErr *util.ClientServerError
|
||||||
|
statusCode := http.StatusInternalServerError // Default to 500
|
||||||
|
|
||||||
|
if errors.As(err, &clientServerErr) {
|
||||||
|
if clientServerErr.Code != 0 {
|
||||||
|
statusCode = clientServerErr.Code
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Process auth error
|
||||||
if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden {
|
if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden {
|
||||||
if clientAuth {
|
if clientAuth {
|
||||||
// Propagate the original 401/403 error.
|
// Token error, pass through 401/403
|
||||||
s.logger.DebugContext(ctx, fmt.Sprintf("error invoking tool. Client credentials lack authorization to the source: %v", err))
|
s.logger.DebugContext(ctx, fmt.Sprintf("Client credentials lack authorization: %v", err))
|
||||||
_ = render.Render(w, r, newErrResponse(err, statusCode))
|
_ = render.Render(w, r, newErrResponse(err, statusCode))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// ADC lacking permission or credentials configuration error.
|
// ADC/Config error, return 500
|
||||||
internalErr := fmt.Errorf("unexpected auth error occured during Tool invocation: %w", err)
|
statusCode = http.StatusInternalServerError
|
||||||
s.logger.ErrorContext(ctx, internalErr.Error())
|
}
|
||||||
_ = render.Render(w, r, newErrResponse(internalErr, http.StatusInternalServerError))
|
|
||||||
|
s.logger.ErrorContext(ctx, fmt.Sprintf("Tool invocation server error: %v", err))
|
||||||
|
_ = render.Render(w, r, newErrResponse(err, statusCode))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = fmt.Errorf("error while invoking tool: %w", err)
|
} else {
|
||||||
s.logger.DebugContext(ctx, err.Error())
|
// Unknown error -> 500
|
||||||
_ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest))
|
s.logger.ErrorContext(ctx, fmt.Sprintf("Tool invocation unknown error: %v", err))
|
||||||
|
_ = render.Render(w, r, newErrResponse(err, http.StatusInternalServerError))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
resMarshal, err := json.Marshal(res)
|
resMarshal, err := json.Marshal(res)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -444,15 +443,12 @@ func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
|||||||
code := rpcResponse.Error.Code
|
code := rpcResponse.Error.Code
|
||||||
switch code {
|
switch code {
|
||||||
case jsonrpc.INTERNAL_ERROR:
|
case jsonrpc.INTERNAL_ERROR:
|
||||||
|
// Map Internal RPC Error (-32603) to HTTP 500
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
case jsonrpc.INVALID_REQUEST:
|
case jsonrpc.INVALID_REQUEST:
|
||||||
errStr := err.Error()
|
var clientServerErr *util.ClientServerError
|
||||||
if errors.Is(err, util.ErrUnauthorized) {
|
if errors.As(err, &clientServerErr) {
|
||||||
w.WriteHeader(http.StatusUnauthorized)
|
w.WriteHeader(clientServerErr.Code)
|
||||||
} else if strings.Contains(errStr, "Error 401") {
|
|
||||||
w.WriteHeader(http.StatusUnauthorized)
|
|
||||||
} else if strings.Contains(errStr, "Error 403") {
|
|
||||||
w.WriteHeader(http.StatusForbidden)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||||
@@ -124,7 +123,12 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
}
|
}
|
||||||
if clientAuth {
|
if clientAuth {
|
||||||
if accessToken == "" {
|
if accessToken == "" {
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
|
errMsg := "missing access token in the 'Authorization' header"
|
||||||
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, errMsg, nil), util.NewClientServerError(
|
||||||
|
errMsg,
|
||||||
|
http.StatusUnauthorized,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -172,7 +176,11 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
// Check if any of the specified auth services is verified
|
// Check if any of the specified auth services is verified
|
||||||
isAuthorized := tool.Authorized(verifiedAuthServices)
|
isAuthorized := tool.Authorized(verifiedAuthServices)
|
||||||
if !isAuthorized {
|
if !isAuthorized {
|
||||||
err = fmt.Errorf("unauthorized Tool call: Please make sure your specify correct auth headers: %w", util.ErrUnauthorized)
|
err = util.NewClientServerError(
|
||||||
|
"unauthorized Tool call: Please make sure you specify correct auth headers",
|
||||||
|
http.StatusUnauthorized,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
}
|
}
|
||||||
logger.DebugContext(ctx, "tool invocation authorized")
|
logger.DebugContext(ctx, "tool invocation authorized")
|
||||||
@@ -194,21 +202,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
// run tool invocation and generate response.
|
// run tool invocation and generate response.
|
||||||
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errStr := err.Error()
|
var tbErr util.ToolboxError
|
||||||
// Missing authService tokens.
|
|
||||||
if errors.Is(err, util.ErrUnauthorized) {
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
|
||||||
}
|
|
||||||
// Upstream auth error
|
|
||||||
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
|
|
||||||
if clientAuth {
|
|
||||||
// Error with client credentials should pass down to the client
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
|
||||||
}
|
|
||||||
// Auth error with ADC should raise internal 500 error
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
|
||||||
}
|
|
||||||
|
|
||||||
|
if errors.As(err, &tbErr) {
|
||||||
|
switch tbErr.Category() {
|
||||||
|
case util.CategoryAgent:
|
||||||
|
// MCP - Tool execution error
|
||||||
|
// Return SUCCESS but with IsError: true
|
||||||
text := TextContent{
|
text := TextContent{
|
||||||
Type: "text",
|
Type: "text",
|
||||||
Text: err.Error(),
|
Text: err.Error(),
|
||||||
@@ -218,6 +218,28 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
Id: id,
|
Id: id,
|
||||||
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
||||||
}, nil
|
}, nil
|
||||||
|
|
||||||
|
case util.CategoryServer:
|
||||||
|
// MCP Spec - Protocol error
|
||||||
|
// Return JSON-RPC ERROR
|
||||||
|
var clientServerErr *util.ClientServerError
|
||||||
|
rpcCode := jsonrpc.INTERNAL_ERROR // Default to Internal Error (-32603)
|
||||||
|
|
||||||
|
if errors.As(err, &clientServerErr) {
|
||||||
|
if clientServerErr.Code == http.StatusUnauthorized || clientServerErr.Code == http.StatusForbidden {
|
||||||
|
if clientAuth {
|
||||||
|
rpcCode = jsonrpc.INVALID_REQUEST
|
||||||
|
} else {
|
||||||
|
rpcCode = jsonrpc.INTERNAL_ERROR
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return jsonrpc.NewError(id, rpcCode, err.Error(), nil), err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Unknown error -> 500
|
||||||
|
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
content := make([]TextContent, 0)
|
content := make([]TextContent, 0)
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||||
@@ -124,7 +123,12 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
}
|
}
|
||||||
if clientAuth {
|
if clientAuth {
|
||||||
if accessToken == "" {
|
if accessToken == "" {
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
|
errMsg := "missing access token in the 'Authorization' header"
|
||||||
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, errMsg, nil), util.NewClientServerError(
|
||||||
|
errMsg,
|
||||||
|
http.StatusUnauthorized,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -172,7 +176,11 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
// Check if any of the specified auth services is verified
|
// Check if any of the specified auth services is verified
|
||||||
isAuthorized := tool.Authorized(verifiedAuthServices)
|
isAuthorized := tool.Authorized(verifiedAuthServices)
|
||||||
if !isAuthorized {
|
if !isAuthorized {
|
||||||
err = fmt.Errorf("unauthorized Tool call: Please make sure your specify correct auth headers: %w", util.ErrUnauthorized)
|
err = util.NewClientServerError(
|
||||||
|
"unauthorized Tool call: Please make sure you specify correct auth headers",
|
||||||
|
http.StatusUnauthorized,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
}
|
}
|
||||||
logger.DebugContext(ctx, "tool invocation authorized")
|
logger.DebugContext(ctx, "tool invocation authorized")
|
||||||
@@ -194,20 +202,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
// run tool invocation and generate response.
|
// run tool invocation and generate response.
|
||||||
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errStr := err.Error()
|
var tbErr util.ToolboxError
|
||||||
// Missing authService tokens.
|
|
||||||
if errors.Is(err, util.ErrUnauthorized) {
|
if errors.As(err, &tbErr) {
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
switch tbErr.Category() {
|
||||||
}
|
case util.CategoryAgent:
|
||||||
// Upstream auth error
|
// MCP - Tool execution error
|
||||||
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
|
// Return SUCCESS but with IsError: true
|
||||||
if clientAuth {
|
|
||||||
// Error with client credentials should pass down to the client
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
|
||||||
}
|
|
||||||
// Auth error with ADC should raise internal 500 error
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
|
||||||
}
|
|
||||||
text := TextContent{
|
text := TextContent{
|
||||||
Type: "text",
|
Type: "text",
|
||||||
Text: err.Error(),
|
Text: err.Error(),
|
||||||
@@ -217,8 +218,29 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
Id: id,
|
Id: id,
|
||||||
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
|
||||||
|
|
||||||
|
case util.CategoryServer:
|
||||||
|
// MCP Spec - Protocol error
|
||||||
|
// Return JSON-RPC ERROR
|
||||||
|
var clientServerErr *util.ClientServerError
|
||||||
|
rpcCode := jsonrpc.INTERNAL_ERROR // Default to Internal Error (-32603)
|
||||||
|
|
||||||
|
if errors.As(err, &clientServerErr) {
|
||||||
|
if clientServerErr.Code == http.StatusUnauthorized || clientServerErr.Code == http.StatusForbidden {
|
||||||
|
if clientAuth {
|
||||||
|
rpcCode = jsonrpc.INVALID_REQUEST
|
||||||
|
} else {
|
||||||
|
rpcCode = jsonrpc.INTERNAL_ERROR
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return jsonrpc.NewError(id, rpcCode, err.Error(), nil), err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Unknown error -> 500
|
||||||
|
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||||
|
}
|
||||||
|
}
|
||||||
content := make([]TextContent, 0)
|
content := make([]TextContent, 0)
|
||||||
|
|
||||||
sliceRes, ok := results.([]any)
|
sliceRes, ok := results.([]any)
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||||
@@ -117,7 +116,12 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
}
|
}
|
||||||
if clientAuth {
|
if clientAuth {
|
||||||
if accessToken == "" {
|
if accessToken == "" {
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
|
errMsg := "missing access token in the 'Authorization' header"
|
||||||
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, errMsg, nil), util.NewClientServerError(
|
||||||
|
errMsg,
|
||||||
|
http.StatusUnauthorized,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -165,7 +169,11 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
// Check if any of the specified auth services is verified
|
// Check if any of the specified auth services is verified
|
||||||
isAuthorized := tool.Authorized(verifiedAuthServices)
|
isAuthorized := tool.Authorized(verifiedAuthServices)
|
||||||
if !isAuthorized {
|
if !isAuthorized {
|
||||||
err = fmt.Errorf("unauthorized Tool call: Please make sure your specify correct auth headers: %w", util.ErrUnauthorized)
|
err = util.NewClientServerError(
|
||||||
|
"unauthorized Tool call: Please make sure you specify correct auth headers",
|
||||||
|
http.StatusUnauthorized,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
}
|
}
|
||||||
logger.DebugContext(ctx, "tool invocation authorized")
|
logger.DebugContext(ctx, "tool invocation authorized")
|
||||||
@@ -187,20 +195,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
// run tool invocation and generate response.
|
// run tool invocation and generate response.
|
||||||
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errStr := err.Error()
|
var tbErr util.ToolboxError
|
||||||
// Missing authService tokens.
|
|
||||||
if errors.Is(err, util.ErrUnauthorized) {
|
if errors.As(err, &tbErr) {
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
switch tbErr.Category() {
|
||||||
}
|
case util.CategoryAgent:
|
||||||
// Upstream auth error
|
// MCP - Tool execution error
|
||||||
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
|
// Return SUCCESS but with IsError: true
|
||||||
if clientAuth {
|
|
||||||
// Error with client credentials should pass down to the client
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
|
||||||
}
|
|
||||||
// Auth error with ADC should raise internal 500 error
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
|
||||||
}
|
|
||||||
text := TextContent{
|
text := TextContent{
|
||||||
Type: "text",
|
Type: "text",
|
||||||
Text: err.Error(),
|
Text: err.Error(),
|
||||||
@@ -210,6 +211,28 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
Id: id,
|
Id: id,
|
||||||
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
||||||
}, nil
|
}, nil
|
||||||
|
|
||||||
|
case util.CategoryServer:
|
||||||
|
// MCP Spec - Protocol error
|
||||||
|
// Return JSON-RPC ERROR
|
||||||
|
var clientServerErr *util.ClientServerError
|
||||||
|
rpcCode := jsonrpc.INTERNAL_ERROR // Default to Internal Error (-32603)
|
||||||
|
|
||||||
|
if errors.As(err, &clientServerErr) {
|
||||||
|
if clientServerErr.Code == http.StatusUnauthorized || clientServerErr.Code == http.StatusForbidden {
|
||||||
|
if clientAuth {
|
||||||
|
rpcCode = jsonrpc.INVALID_REQUEST
|
||||||
|
} else {
|
||||||
|
rpcCode = jsonrpc.INTERNAL_ERROR
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return jsonrpc.NewError(id, rpcCode, err.Error(), nil), err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Unknown error -> 500
|
||||||
|
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
content := make([]TextContent, 0)
|
content := make([]TextContent, 0)
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||||
@@ -117,7 +116,12 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
}
|
}
|
||||||
if clientAuth {
|
if clientAuth {
|
||||||
if accessToken == "" {
|
if accessToken == "" {
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
|
errMsg := "missing access token in the 'Authorization' header"
|
||||||
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, errMsg, nil), util.NewClientServerError(
|
||||||
|
errMsg,
|
||||||
|
http.StatusUnauthorized,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -165,7 +169,11 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
// Check if any of the specified auth services is verified
|
// Check if any of the specified auth services is verified
|
||||||
isAuthorized := tool.Authorized(verifiedAuthServices)
|
isAuthorized := tool.Authorized(verifiedAuthServices)
|
||||||
if !isAuthorized {
|
if !isAuthorized {
|
||||||
err = fmt.Errorf("unauthorized Tool call: Please make sure your specify correct auth headers: %w", util.ErrUnauthorized)
|
err = util.NewClientServerError(
|
||||||
|
"unauthorized Tool call: Please make sure you specify correct auth headers",
|
||||||
|
http.StatusUnauthorized,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
}
|
}
|
||||||
logger.DebugContext(ctx, "tool invocation authorized")
|
logger.DebugContext(ctx, "tool invocation authorized")
|
||||||
@@ -187,20 +195,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
// run tool invocation and generate response.
|
// run tool invocation and generate response.
|
||||||
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errStr := err.Error()
|
var tbErr util.ToolboxError
|
||||||
// Missing authService tokens.
|
|
||||||
if errors.Is(err, util.ErrUnauthorized) {
|
if errors.As(err, &tbErr) {
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
switch tbErr.Category() {
|
||||||
}
|
case util.CategoryAgent:
|
||||||
// Upstream auth error
|
// MCP - Tool execution error
|
||||||
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
|
// Return SUCCESS but with IsError: true
|
||||||
if clientAuth {
|
|
||||||
// Error with client credentials should pass down to the client
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
|
||||||
}
|
|
||||||
// Auth error with ADC should raise internal 500 error
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
|
||||||
}
|
|
||||||
text := TextContent{
|
text := TextContent{
|
||||||
Type: "text",
|
Type: "text",
|
||||||
Text: err.Error(),
|
Text: err.Error(),
|
||||||
@@ -210,6 +211,28 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
Id: id,
|
Id: id,
|
||||||
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
||||||
}, nil
|
}, nil
|
||||||
|
|
||||||
|
case util.CategoryServer:
|
||||||
|
// MCP Spec - Protocol error
|
||||||
|
// Return JSON-RPC ERROR
|
||||||
|
var clientServerErr *util.ClientServerError
|
||||||
|
rpcCode := jsonrpc.INTERNAL_ERROR // Default to Internal Error (-32603)
|
||||||
|
|
||||||
|
if errors.As(err, &clientServerErr) {
|
||||||
|
if clientServerErr.Code == http.StatusUnauthorized || clientServerErr.Code == http.StatusForbidden {
|
||||||
|
if clientAuth {
|
||||||
|
rpcCode = jsonrpc.INVALID_REQUEST
|
||||||
|
} else {
|
||||||
|
rpcCode = jsonrpc.INTERNAL_ERROR
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return jsonrpc.NewError(id, rpcCode, err.Error(), nil), err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Unknown error -> 500
|
||||||
|
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
content := make([]TextContent, 0)
|
content := make([]TextContent, 0)
|
||||||
|
|||||||
@@ -231,7 +231,7 @@ func TestMcpEndpointWithoutInitialized(t *testing.T) {
|
|||||||
"id": "tools-call-tool4",
|
"id": "tools-call-tool4",
|
||||||
"error": map[string]any{
|
"error": map[string]any{
|
||||||
"code": -32600.0,
|
"code": -32600.0,
|
||||||
"message": "unauthorized Tool call: Please make sure your specify correct auth headers: unauthorized",
|
"message": "unauthorized Tool call: Please make sure you specify correct auth headers: <nil>",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -834,7 +834,7 @@ func TestMcpEndpoint(t *testing.T) {
|
|||||||
"id": "tools-call-tool4",
|
"id": "tools-call-tool4",
|
||||||
"error": map[string]any{
|
"error": map[string]any{
|
||||||
"code": -32600.0,
|
"code": -32600.0,
|
||||||
"message": "unauthorized Tool call: Please make sure your specify correct auth headers: unauthorized",
|
"message": "unauthorized Tool call: Please make sure you specify correct auth headers: <nil>",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -1,120 +0,0 @@
|
|||||||
// Copyright 2026 Google LLC
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package spanneradmin
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
instance "cloud.google.com/go/spanner/admin/instance/apiv1"
|
|
||||||
"github.com/goccy/go-yaml"
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/util"
|
|
||||||
"go.opentelemetry.io/otel/trace"
|
|
||||||
"golang.org/x/oauth2"
|
|
||||||
"google.golang.org/api/option"
|
|
||||||
)
|
|
||||||
|
|
||||||
const SourceType string = "spanner-admin"
|
|
||||||
|
|
||||||
// validate interface
|
|
||||||
var _ sources.SourceConfig = Config{}
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
if !sources.Register(SourceType, newConfig) {
|
|
||||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources.SourceConfig, error) {
|
|
||||||
actual := Config{Name: name}
|
|
||||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return actual, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type Config struct {
|
|
||||||
Name string `yaml:"name" validate:"required"`
|
|
||||||
Type string `yaml:"type" validate:"required"`
|
|
||||||
DefaultProject string `yaml:"defaultProject"`
|
|
||||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r Config) SourceConfigType() string {
|
|
||||||
return SourceType
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize initializes a Spanner Admin Source instance.
|
|
||||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
|
||||||
var client *instance.InstanceAdminClient
|
|
||||||
|
|
||||||
if !r.UseClientOAuth {
|
|
||||||
ua, err := util.UserAgentFromContext(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error in User Agent retrieval: %s", err)
|
|
||||||
}
|
|
||||||
// Use Application Default Credentials
|
|
||||||
client, err = instance.NewInstanceAdminClient(ctx, option.WithUserAgent(ua))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error creating new spanner instance admin client: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
s := &Source{
|
|
||||||
Config: r,
|
|
||||||
Client: client,
|
|
||||||
}
|
|
||||||
return s, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ sources.Source = &Source{}
|
|
||||||
|
|
||||||
type Source struct {
|
|
||||||
Config
|
|
||||||
Client *instance.InstanceAdminClient
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Source) SourceType() string {
|
|
||||||
return SourceType
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Source) ToConfig() sources.SourceConfig {
|
|
||||||
return s.Config
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Source) GetDefaultProject() string {
|
|
||||||
return s.DefaultProject
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Source) GetClient(ctx context.Context, accessToken string) (*instance.InstanceAdminClient, error) {
|
|
||||||
if s.UseClientOAuth {
|
|
||||||
token := &oauth2.Token{AccessToken: accessToken}
|
|
||||||
ua, err := util.UserAgentFromContext(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
client, err := instance.NewInstanceAdminClient(ctx, option.WithTokenSource(oauth2.StaticTokenSource(token)), option.WithUserAgent(ua))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error creating new spanner instance admin client: %w", err)
|
|
||||||
}
|
|
||||||
return client, nil
|
|
||||||
}
|
|
||||||
return s.Client, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Source) UseClientAuthorization() bool {
|
|
||||||
return s.UseClientOAuth
|
|
||||||
}
|
|
||||||
@@ -1,122 +0,0 @@
|
|||||||
// Copyright 2026 Google LLC
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package spanneradmin_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/server"
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources/spanneradmin"
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestParseFromYamlSpannerAdmin(t *testing.T) {
|
|
||||||
tcs := []struct {
|
|
||||||
desc string
|
|
||||||
in string
|
|
||||||
want server.SourceConfigs
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
desc: "basic example",
|
|
||||||
in: `
|
|
||||||
kind: sources
|
|
||||||
name: my-spanner-admin-instance
|
|
||||||
type: spanner-admin
|
|
||||||
`,
|
|
||||||
want: map[string]sources.SourceConfig{
|
|
||||||
"my-spanner-admin-instance": spanneradmin.Config{
|
|
||||||
Name: "my-spanner-admin-instance",
|
|
||||||
Type: spanneradmin.SourceType,
|
|
||||||
UseClientOAuth: false,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "use client auth example",
|
|
||||||
in: `
|
|
||||||
kind: sources
|
|
||||||
name: my-spanner-admin-instance
|
|
||||||
type: spanner-admin
|
|
||||||
useClientOAuth: true
|
|
||||||
`,
|
|
||||||
want: map[string]sources.SourceConfig{
|
|
||||||
"my-spanner-admin-instance": spanneradmin.Config{
|
|
||||||
Name: "my-spanner-admin-instance",
|
|
||||||
Type: spanneradmin.SourceType,
|
|
||||||
UseClientOAuth: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tc := range tcs {
|
|
||||||
t.Run(tc.desc, func(t *testing.T) {
|
|
||||||
got, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to unmarshal: %s", err)
|
|
||||||
}
|
|
||||||
if !cmp.Equal(tc.want, got) {
|
|
||||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFailParseFromYaml(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
tcs := []struct {
|
|
||||||
desc string
|
|
||||||
in string
|
|
||||||
err string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
desc: "extra field",
|
|
||||||
in: `
|
|
||||||
kind: sources
|
|
||||||
name: my-spanner-admin-instance
|
|
||||||
type: spanner-admin
|
|
||||||
project: test-project
|
|
||||||
`,
|
|
||||||
err: `error unmarshaling sources: unable to parse source "my-spanner-admin-instance" as "spanner-admin": [2:1] unknown field "project"
|
|
||||||
1 | name: my-spanner-admin-instance
|
|
||||||
> 2 | project: test-project
|
|
||||||
^
|
|
||||||
3 | type: spanner-admin`,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "missing required field",
|
|
||||||
in: `
|
|
||||||
kind: sources
|
|
||||||
name: my-spanner-admin-instance
|
|
||||||
useClientOAuth: true
|
|
||||||
`,
|
|
||||||
err: "error unmarshaling sources: missing 'type' field or it is not a string",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tc := range tcs {
|
|
||||||
t.Run(tc.desc, func(t *testing.T) {
|
|
||||||
_, _, _, _, _, _, err := server.UnmarshalResourceConfig(context.Background(), testutils.FormatYaml(tc.in))
|
|
||||||
if err == nil {
|
|
||||||
t.Fatalf("expect parsing to fail")
|
|
||||||
}
|
|
||||||
errStr := err.Error()
|
|
||||||
if errStr != tc.err {
|
|
||||||
t.Fatalf("unexpected error: got %q, want %q", errStr, tc.err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -184,7 +184,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
if source.UseClientAuthorization() {
|
if source.UseClientAuthorization() {
|
||||||
// Use client-side access token
|
// Use client-side access token
|
||||||
if accessToken == "" {
|
if accessToken == "" {
|
||||||
return nil, fmt.Errorf("tool is configured for client OAuth but no token was provided in the request header: %w", util.ErrUnauthorized)
|
return nil, util.NewClientServerError("tool is configured for client OAuth but no token was provided in the request header", http.StatusUnauthorized, nil)
|
||||||
}
|
}
|
||||||
tokenStr, err = accessToken.ParseBearerToken()
|
tokenStr, err = accessToken.ParseBearerToken()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -1,237 +0,0 @@
|
|||||||
// Copyright 2026 Google LLC
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package spannercreateinstance
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
instance "cloud.google.com/go/spanner/admin/instance/apiv1"
|
|
||||||
"cloud.google.com/go/spanner/admin/instance/apiv1/instancepb"
|
|
||||||
"github.com/goccy/go-yaml"
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
|
||||||
)
|
|
||||||
|
|
||||||
const resourceType string = "spanner-create-instance"
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
if !tools.Register(resourceType, newConfig) {
|
|
||||||
panic(fmt.Sprintf("tool type %q already registered", resourceType))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
|
||||||
actual := Config{Name: name}
|
|
||||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return actual, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type compatibleSource interface {
|
|
||||||
GetDefaultProject() string
|
|
||||||
GetClient(context.Context, string) (*instance.InstanceAdminClient, error)
|
|
||||||
UseClientAuthorization() bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// Config defines the configuration for the create-instance tool.
|
|
||||||
type Config struct {
|
|
||||||
Name string `yaml:"name" validate:"required"`
|
|
||||||
Type string `yaml:"type" validate:"required"`
|
|
||||||
Description string `yaml:"description"`
|
|
||||||
Source string `yaml:"source" validate:"required"`
|
|
||||||
AuthRequired []string `yaml:"authRequired"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// validate interface
|
|
||||||
var _ tools.ToolConfig = Config{}
|
|
||||||
|
|
||||||
// ToolConfigKind returns the kind of the tool.
|
|
||||||
func (cfg Config) ToolConfigType() string {
|
|
||||||
return resourceType
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize initializes the tool from the configuration.
|
|
||||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
|
||||||
rawS, ok := srcs[cfg.Source]
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
|
||||||
}
|
|
||||||
s, ok := rawS.(compatibleSource)
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", resourceType, cfg.Source)
|
|
||||||
}
|
|
||||||
|
|
||||||
project := s.GetDefaultProject()
|
|
||||||
var projectParam parameters.Parameter
|
|
||||||
if project != "" {
|
|
||||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID.")
|
|
||||||
} else {
|
|
||||||
projectParam = parameters.NewStringParameter("project", "The project ID")
|
|
||||||
}
|
|
||||||
|
|
||||||
allParameters := parameters.Parameters{
|
|
||||||
projectParam,
|
|
||||||
parameters.NewStringParameter("instanceId", "The ID of the instance"),
|
|
||||||
parameters.NewStringParameter("displayName", "The display name of the instance"),
|
|
||||||
parameters.NewStringParameter("config", "The instance configuration (e.g., regional-us-central1)"),
|
|
||||||
parameters.NewIntParameter("nodeCount", "The number of nodes, mutually exclusive with processingUnits (one must be 0)"),
|
|
||||||
parameters.NewIntParameter("processingUnits", "The number of processing units, mutually exclusive with nodeCount (one must be 0)"),
|
|
||||||
parameters.NewStringParameter("edition", "The edition of the instance (STANDARD, ENTERPRISE, ENTERPRISE_PLUS)"),
|
|
||||||
}
|
|
||||||
|
|
||||||
paramManifest := allParameters.Manifest()
|
|
||||||
|
|
||||||
description := cfg.Description
|
|
||||||
if description == "" {
|
|
||||||
description = "Creates a Spanner instance."
|
|
||||||
}
|
|
||||||
mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, allParameters, nil)
|
|
||||||
|
|
||||||
return Tool{
|
|
||||||
Config: cfg,
|
|
||||||
AllParams: allParameters,
|
|
||||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
|
||||||
mcpManifest: mcpManifest,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Tool represents the create-instance tool.
|
|
||||||
type Tool struct {
|
|
||||||
Config
|
|
||||||
AllParams parameters.Parameters
|
|
||||||
manifest tools.Manifest
|
|
||||||
mcpManifest tools.McpManifest
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t Tool) ToConfig() tools.ToolConfig {
|
|
||||||
return t.Config
|
|
||||||
}
|
|
||||||
|
|
||||||
// Invoke executes the tool's logic.
|
|
||||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
|
||||||
paramsMap := params.AsMap()
|
|
||||||
|
|
||||||
project, _ := paramsMap["project"].(string)
|
|
||||||
instanceId, _ := paramsMap["instanceId"].(string)
|
|
||||||
displayName, _ := paramsMap["displayName"].(string)
|
|
||||||
config, _ := paramsMap["config"].(string)
|
|
||||||
nodeCount, _ := paramsMap["nodeCount"].(int)
|
|
||||||
processingUnits, _ := paramsMap["processingUnits"].(int)
|
|
||||||
editionStr, _ := paramsMap["edition"].(string)
|
|
||||||
|
|
||||||
if (nodeCount > 0 && processingUnits > 0) || (nodeCount == 0 && processingUnits == 0) {
|
|
||||||
return nil, fmt.Errorf("one of nodeCount or processingUnits must be positive, and the other must be 0")
|
|
||||||
}
|
|
||||||
|
|
||||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
client, err := source.GetClient(ctx, string(accessToken))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if source.UseClientAuthorization() {
|
|
||||||
defer client.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
parent := fmt.Sprintf("projects/%s", project)
|
|
||||||
instanceConfig := fmt.Sprintf("projects/%s/instanceConfigs/%s", project, config)
|
|
||||||
|
|
||||||
var edition instancepb.Instance_Edition
|
|
||||||
switch editionStr {
|
|
||||||
case "STANDARD":
|
|
||||||
edition = instancepb.Instance_STANDARD
|
|
||||||
case "ENTERPRISE":
|
|
||||||
edition = instancepb.Instance_ENTERPRISE
|
|
||||||
case "ENTERPRISE_PLUS":
|
|
||||||
edition = instancepb.Instance_ENTERPRISE_PLUS
|
|
||||||
default:
|
|
||||||
edition = instancepb.Instance_EDITION_UNSPECIFIED
|
|
||||||
}
|
|
||||||
|
|
||||||
// Construct the instance object
|
|
||||||
instance := &instancepb.Instance{
|
|
||||||
Config: instanceConfig,
|
|
||||||
DisplayName: displayName,
|
|
||||||
Edition: edition,
|
|
||||||
NodeCount: int32(nodeCount),
|
|
||||||
ProcessingUnits: int32(processingUnits),
|
|
||||||
}
|
|
||||||
|
|
||||||
req := &instancepb.CreateInstanceRequest{
|
|
||||||
Parent: parent,
|
|
||||||
InstanceId: instanceId,
|
|
||||||
Instance: instance,
|
|
||||||
}
|
|
||||||
|
|
||||||
op, err := client.CreateInstance(ctx, req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create instance: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for the operation to complete
|
|
||||||
resp, err := op.Wait(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to wait for create instance operation: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return resp, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ParseParams parses the parameters for the tool.
|
|
||||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
|
||||||
return parameters.ParseParams(t.AllParams, data, claims)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
|
||||||
return parameters.EmbedParams(ctx, t.AllParams, paramValues, embeddingModelsMap, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Manifest returns the tool's manifest.
|
|
||||||
func (t Tool) Manifest() tools.Manifest {
|
|
||||||
return t.manifest
|
|
||||||
}
|
|
||||||
|
|
||||||
// McpManifest returns the tool's MCP manifest.
|
|
||||||
func (t Tool) McpManifest() tools.McpManifest {
|
|
||||||
return t.mcpManifest
|
|
||||||
}
|
|
||||||
|
|
||||||
// Authorized checks if the tool is authorized.
|
|
||||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
|
||||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
return source.UseClientAuthorization(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
|
||||||
return "Authorization", nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t Tool) GetParameters() parameters.Parameters {
|
|
||||||
return t.AllParams
|
|
||||||
}
|
|
||||||
@@ -1,106 +0,0 @@
|
|||||||
// Copyright 2026 Google LLC
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package spannercreateinstance_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/server"
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools/spanneradmin/spannercreateinstance"
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestParseFromYaml(t *testing.T) {
|
|
||||||
ctx, err := testutils.ContextWithNewLogger()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error: %s", err)
|
|
||||||
}
|
|
||||||
tcs := []struct {
|
|
||||||
desc string
|
|
||||||
in string
|
|
||||||
want server.ToolConfigs
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
desc: "basic example",
|
|
||||||
in: `
|
|
||||||
kind: tools
|
|
||||||
name: create-instance-tool
|
|
||||||
type: spanner-create-instance
|
|
||||||
description: a test description
|
|
||||||
source: a-source
|
|
||||||
`,
|
|
||||||
want: server.ToolConfigs{
|
|
||||||
"create-instance-tool": spannercreateinstance.Config{
|
|
||||||
Name: "create-instance-tool",
|
|
||||||
Type: "spanner-create-instance",
|
|
||||||
Description: "a test description",
|
|
||||||
Source: "a-source",
|
|
||||||
AuthRequired: []string{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tc := range tcs {
|
|
||||||
t.Run(tc.desc, func(t *testing.T) {
|
|
||||||
// Parse contents
|
|
||||||
_, _, _, got, _, _, err := server.UnmarshalResourceConfig(ctx, testutils.FormatYaml(tc.in))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to unmarshal: %s", err)
|
|
||||||
}
|
|
||||||
if diff := cmp.Diff(tc.want, got); diff != "" {
|
|
||||||
t.Fatalf("incorrect parse: diff %v", diff)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestInvokeNodeCountAndProcessingUnitsValidation(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
testCases := []struct {
|
|
||||||
name string
|
|
||||||
params parameters.ParamValues
|
|
||||||
wantErr string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Both positive",
|
|
||||||
params: parameters.ParamValues{
|
|
||||||
{Name: "nodeCount", Value: 1},
|
|
||||||
{Name: "processingUnits", Value: 1000},
|
|
||||||
},
|
|
||||||
wantErr: "one of nodeCount or processingUnits must be positive, and the other must be 0",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Both zero",
|
|
||||||
params: parameters.ParamValues{
|
|
||||||
{Name: "nodeCount", Value: 0},
|
|
||||||
{Name: "processingUnits", Value: 0},
|
|
||||||
},
|
|
||||||
wantErr: "one of nodeCount or processingUnits must be positive, and the other must be 0",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
tool := spannercreateinstance.Tool{}
|
|
||||||
_, err := tool.Invoke(context.Background(), nil, tc.params, "")
|
|
||||||
if err == nil || err.Error() != tc.wantErr {
|
|
||||||
t.Errorf("Invoke() error = %v, wantErr %v", err, tc.wantErr)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -17,6 +17,7 @@ package tools
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -80,7 +81,7 @@ type AccessToken string
|
|||||||
func (token AccessToken) ParseBearerToken() (string, error) {
|
func (token AccessToken) ParseBearerToken() (string, error) {
|
||||||
headerParts := strings.Split(string(token), " ")
|
headerParts := strings.Split(string(token), " ")
|
||||||
if len(headerParts) != 2 || strings.ToLower(headerParts[0]) != "bearer" {
|
if len(headerParts) != 2 || strings.ToLower(headerParts[0]) != "bearer" {
|
||||||
return "", fmt.Errorf("authorization header must be in the format 'Bearer <token>': %w", util.ErrUnauthorized)
|
return "", util.NewClientServerError("authorization header must be in the format 'Bearer <token>'", http.StatusUnauthorized, nil)
|
||||||
}
|
}
|
||||||
return headerParts[1], nil
|
return headerParts[1], nil
|
||||||
}
|
}
|
||||||
|
|||||||
77
internal/util/errors.go
Normal file
77
internal/util/errors.go
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
// Copyright 2026 Google LLC
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
package util
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
type ErrorCategory string
|
||||||
|
|
||||||
|
const (
|
||||||
|
CategoryAgent ErrorCategory = "AGENT_ERROR"
|
||||||
|
CategoryServer ErrorCategory = "SERVER_ERROR"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ToolboxError is the interface all custom errors must satisfy
|
||||||
|
type ToolboxError interface {
|
||||||
|
error
|
||||||
|
Category() ErrorCategory
|
||||||
|
Error() string
|
||||||
|
Unwrap() error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Agent Errors return 200 to the sender
|
||||||
|
type AgentError struct {
|
||||||
|
Msg string
|
||||||
|
Cause error
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ ToolboxError = &AgentError{}
|
||||||
|
|
||||||
|
func (e *AgentError) Error() string {
|
||||||
|
if e.Cause != nil {
|
||||||
|
return fmt.Sprintf("%s: %v", e.Msg, e.Cause)
|
||||||
|
}
|
||||||
|
return e.Msg
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *AgentError) Category() ErrorCategory { return CategoryAgent }
|
||||||
|
|
||||||
|
func (e *AgentError) Unwrap() error { return e.Cause }
|
||||||
|
|
||||||
|
func NewAgentError(msg string, cause error) *AgentError {
|
||||||
|
return &AgentError{Msg: msg, Cause: cause}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClientServerError returns 4XX/5XX error code
|
||||||
|
type ClientServerError struct {
|
||||||
|
Msg string
|
||||||
|
Code int
|
||||||
|
Cause error
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ ToolboxError = &ClientServerError{}
|
||||||
|
|
||||||
|
func (e *ClientServerError) Error() string {
|
||||||
|
if e.Cause != nil {
|
||||||
|
return fmt.Sprintf("%s: %v", e.Msg, e.Cause)
|
||||||
|
}
|
||||||
|
return e.Msg
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ClientServerError) Category() ErrorCategory { return CategoryServer }
|
||||||
|
|
||||||
|
func (e *ClientServerError) Unwrap() error { return e.Cause }
|
||||||
|
|
||||||
|
func NewClientServerError(msg string, code int, cause error) *ClientServerError {
|
||||||
|
return &ClientServerError{Msg: msg, Code: code, Cause: cause}
|
||||||
|
}
|
||||||
@@ -19,6 +19,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"reflect"
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
"slices"
|
"slices"
|
||||||
@@ -118,7 +119,7 @@ func parseFromAuthService(paramAuthServices []ParamAuthService, claimsMap map[st
|
|||||||
}
|
}
|
||||||
return v, nil
|
return v, nil
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("missing or invalid authentication header: %w", util.ErrUnauthorized)
|
return nil, util.NewClientServerError("missing or invalid authentication header", http.StatusUnauthorized, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CheckParamRequired checks if a parameter is required based on the required and default field.
|
// CheckParamRequired checks if a parameter is required based on the required and default field.
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -188,5 +187,3 @@ func InstrumentationFromContext(ctx context.Context) (*telemetry.Instrumentation
|
|||||||
}
|
}
|
||||||
return nil, fmt.Errorf("unable to retrieve instrumentation")
|
return nil, fmt.Errorf("unable to retrieve instrumentation")
|
||||||
}
|
}
|
||||||
|
|
||||||
var ErrUnauthorized = errors.New("unauthorized")
|
|
||||||
|
|||||||
@@ -1,178 +0,0 @@
|
|||||||
// Copyright 2026 Google LLC
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package spanneradmin
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"regexp"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
instance "cloud.google.com/go/spanner/admin/instance/apiv1"
|
|
||||||
"cloud.google.com/go/spanner/admin/instance/apiv1/instancepb"
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
|
||||||
"github.com/googleapis/genai-toolbox/tests"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
SpannerProject = os.Getenv("SPANNER_PROJECT")
|
|
||||||
)
|
|
||||||
|
|
||||||
func getSpannerAdminVars(t *testing.T) map[string]any {
|
|
||||||
if SpannerProject == "" {
|
|
||||||
t.Fatal("'SPANNER_PROJECT' not set")
|
|
||||||
}
|
|
||||||
|
|
||||||
return map[string]any{
|
|
||||||
"type": "spanner-admin",
|
|
||||||
"defaultProject": SpannerProject,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSpannerAdminCreateInstance(t *testing.T) {
|
|
||||||
sourceConfig := getSpannerAdminVars(t)
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Minute)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
shortUuid := strings.ReplaceAll(uuid.New().String(), "-", "")[:10]
|
|
||||||
instanceId := "test-inst-" + shortUuid
|
|
||||||
|
|
||||||
displayName := "Test Instance " + shortUuid
|
|
||||||
instanceConfig := "regional-us-central1"
|
|
||||||
nodeCount := 1
|
|
||||||
edition := "ENTERPRISE"
|
|
||||||
|
|
||||||
// Setup Admin Client for verification and cleanup
|
|
||||||
adminClient, err := instance.NewInstanceAdminClient(ctx)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create Spanner instance admin client: %s", err)
|
|
||||||
}
|
|
||||||
defer adminClient.Close()
|
|
||||||
|
|
||||||
// Teardown function
|
|
||||||
defer func() {
|
|
||||||
err := adminClient.DeleteInstance(ctx, &instancepb.DeleteInstanceRequest{
|
|
||||||
Name: fmt.Sprintf("projects/%s/instances/%s", SpannerProject, instanceId),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
// If it fails, it might not have been created, log it but don't fail if it's "not found"
|
|
||||||
t.Logf("cleanup: failed to delete instance %s: %s", instanceId, err)
|
|
||||||
} else {
|
|
||||||
t.Logf("cleanup: deleted instance %s", instanceId)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Construct Tools Config
|
|
||||||
|
|
||||||
toolsConfig := map[string]any{
|
|
||||||
"sources": map[string]any{
|
|
||||||
"my-spanner-admin": sourceConfig,
|
|
||||||
},
|
|
||||||
"tools": map[string]any{
|
|
||||||
"create-instance-tool": map[string]any{
|
|
||||||
"type": "spanner-create-instance",
|
|
||||||
"source": "my-spanner-admin",
|
|
||||||
"description": "Creates a Spanner instance.",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start Toolbox Server
|
|
||||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsConfig)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("command initialization returned an error: %s", err)
|
|
||||||
}
|
|
||||||
defer cleanup()
|
|
||||||
|
|
||||||
waitCtx, cancelWait := context.WithTimeout(ctx, 10*time.Second)
|
|
||||||
defer cancelWait()
|
|
||||||
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
|
||||||
if err != nil {
|
|
||||||
t.Logf("toolbox command logs: \n%s", out)
|
|
||||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Prepare Invocation Payload
|
|
||||||
|
|
||||||
payload := map[string]any{
|
|
||||||
"project": SpannerProject,
|
|
||||||
"instanceId": instanceId,
|
|
||||||
"displayName": displayName,
|
|
||||||
"config": instanceConfig,
|
|
||||||
"nodeCount": nodeCount,
|
|
||||||
"edition": edition,
|
|
||||||
"processingUnits": 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
payloadBytes, err := json.Marshal(payload)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to marshal payload: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Invoke Tool
|
|
||||||
invokeUrl := "http://127.0.0.1:5000/api/tool/create-instance-tool/invoke"
|
|
||||||
req, err := http.NewRequest(http.MethodPost, invokeUrl, bytes.NewBuffer(payloadBytes))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create request: %s", err)
|
|
||||||
}
|
|
||||||
req.Header.Add("Content-type", "application/json")
|
|
||||||
|
|
||||||
t.Logf("Invoking create-instance-tool for instance: %s", instanceId)
|
|
||||||
resp, err := http.DefaultClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to send request: %s", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
|
||||||
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check Response
|
|
||||||
var body map[string]interface{}
|
|
||||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("error parsing response body")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify Instance Exists via Admin Client
|
|
||||||
t.Logf("Verifying instance %s exists...", instanceId)
|
|
||||||
instanceName := fmt.Sprintf("projects/%s/instances/%s", SpannerProject, instanceId)
|
|
||||||
gotInstance, err := adminClient.GetInstance(ctx, &instancepb.GetInstanceRequest{
|
|
||||||
Name: instanceName,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to get instance from admin client: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if gotInstance.Name != instanceName {
|
|
||||||
t.Errorf("expected instance name %s, got %s", instanceName, gotInstance.Name)
|
|
||||||
}
|
|
||||||
if gotInstance.DisplayName != displayName {
|
|
||||||
t.Errorf("expected display name %s, got %s", displayName, gotInstance.DisplayName)
|
|
||||||
}
|
|
||||||
if gotInstance.NodeCount != int32(nodeCount) {
|
|
||||||
t.Errorf("expected node count %d, got %d", nodeCount, gotInstance.NodeCount)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user