mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-02-04 20:25:05 -05:00
Compare commits
9 Commits
err-api
...
integratio
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1dfb9a5378 | ||
|
|
7908c2256e | ||
|
|
a18ab29f4a | ||
|
|
feee91a18d | ||
|
|
468470098e | ||
|
|
394c3b78bc | ||
|
|
4874c12d3f | ||
|
|
d0d9b78b3b | ||
|
|
f6587cfaf8 |
4
.github/workflows/deploy_dev_docs.yaml
vendored
4
.github/workflows/deploy_dev_docs.yaml
vendored
@@ -40,7 +40,7 @@ jobs:
|
|||||||
group: docs-deployment
|
group: docs-deployment
|
||||||
cancel-in-progress: false
|
cancel-in-progress: false
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6
|
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0 # Fetch all history for .GitInfo and .Lastmod
|
fetch-depth: 0 # Fetch all history for .GitInfo and .Lastmod
|
||||||
|
|
||||||
@@ -56,7 +56,7 @@ jobs:
|
|||||||
node-version: "22"
|
node-version: "22"
|
||||||
|
|
||||||
- name: Cache dependencies
|
- name: Cache dependencies
|
||||||
uses: actions/cache@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # v5
|
uses: actions/cache@8b402f58fbc84540c8b491a91e594a4576fec3d7 # v5
|
||||||
with:
|
with:
|
||||||
path: ~/.npm
|
path: ~/.npm
|
||||||
key: ${{ runner.os }}-node-${{ hashFiles('**/package-lock.json') }}
|
key: ${{ runner.os }}-node-${{ hashFiles('**/package-lock.json') }}
|
||||||
|
|||||||
@@ -30,14 +30,14 @@ jobs:
|
|||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout main branch (for latest templates and theme)
|
- name: Checkout main branch (for latest templates and theme)
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6
|
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6
|
||||||
with:
|
with:
|
||||||
ref: 'main'
|
ref: 'main'
|
||||||
submodules: 'recursive'
|
submodules: 'recursive'
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
|
|
||||||
- name: Checkout old content from tag into a temporary directory
|
- name: Checkout old content from tag into a temporary directory
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6
|
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6
|
||||||
with:
|
with:
|
||||||
ref: ${{ github.event.inputs.version_tag }}
|
ref: ${{ github.event.inputs.version_tag }}
|
||||||
path: 'old_version_source' # Checkout into a temp subdir
|
path: 'old_version_source' # Checkout into a temp subdir
|
||||||
|
|||||||
2
.github/workflows/deploy_versioned_docs.yaml
vendored
2
.github/workflows/deploy_versioned_docs.yaml
vendored
@@ -30,7 +30,7 @@ jobs:
|
|||||||
cancel-in-progress: false
|
cancel-in-progress: false
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout Code at Tag
|
- name: Checkout Code at Tag
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6
|
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6
|
||||||
with:
|
with:
|
||||||
ref: ${{ github.event.release.tag_name }}
|
ref: ${{ github.event.release.tag_name }}
|
||||||
|
|
||||||
|
|||||||
2
.github/workflows/docs_preview_clean.yaml
vendored
2
.github/workflows/docs_preview_clean.yaml
vendored
@@ -34,7 +34,7 @@ jobs:
|
|||||||
group: "preview-${{ github.event.number }}"
|
group: "preview-${{ github.event.number }}"
|
||||||
cancel-in-progress: true
|
cancel-in-progress: true
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6
|
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6
|
||||||
with:
|
with:
|
||||||
ref: versioned-gh-pages
|
ref: versioned-gh-pages
|
||||||
|
|
||||||
|
|||||||
4
.github/workflows/docs_preview_deploy.yaml
vendored
4
.github/workflows/docs_preview_deploy.yaml
vendored
@@ -49,7 +49,7 @@ jobs:
|
|||||||
group: "preview-${{ github.event.number }}"
|
group: "preview-${{ github.event.number }}"
|
||||||
cancel-in-progress: true
|
cancel-in-progress: true
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6
|
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6
|
||||||
with:
|
with:
|
||||||
# Checkout the PR's HEAD commit (supports forks).
|
# Checkout the PR's HEAD commit (supports forks).
|
||||||
ref: ${{ github.event.pull_request.head.sha }}
|
ref: ${{ github.event.pull_request.head.sha }}
|
||||||
@@ -67,7 +67,7 @@ jobs:
|
|||||||
node-version: "22"
|
node-version: "22"
|
||||||
|
|
||||||
- name: Cache dependencies
|
- name: Cache dependencies
|
||||||
uses: actions/cache@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # v5
|
uses: actions/cache@8b402f58fbc84540c8b491a91e594a4576fec3d7 # v5
|
||||||
with:
|
with:
|
||||||
path: ~/.npm
|
path: ~/.npm
|
||||||
key: ${{ runner.os }}-node-${{ hashFiles('**/package-lock.json') }}
|
key: ${{ runner.os }}-node-${{ hashFiles('**/package-lock.json') }}
|
||||||
|
|||||||
4
.github/workflows/link_checker_workflow.yaml
vendored
4
.github/workflows/link_checker_workflow.yaml
vendored
@@ -22,10 +22,10 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout Repository
|
- name: Checkout Repository
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6
|
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6
|
||||||
|
|
||||||
- name: Restore lychee cache
|
- name: Restore lychee cache
|
||||||
uses: actions/cache@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # v5
|
uses: actions/cache@8b402f58fbc84540c8b491a91e594a4576fec3d7 # v5
|
||||||
with:
|
with:
|
||||||
path: .lycheecache
|
path: .lycheecache
|
||||||
key: cache-lychee-${{ github.sha }}
|
key: cache-lychee-${{ github.sha }}
|
||||||
|
|||||||
2
.github/workflows/publish-mcp.yml
vendored
2
.github/workflows/publish-mcp.yml
vendored
@@ -29,7 +29,7 @@ jobs:
|
|||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6
|
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6
|
||||||
|
|
||||||
- name: Wait for image in Artifact Registry
|
- name: Wait for image in Artifact Registry
|
||||||
shell: bash
|
shell: bash
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ https://dev.mysql.com/doc/refman/8.4/en/sql-prepared-statements.html
|
|||||||
https://dev.mysql.com/doc/refman/8.4/en/user-names.html
|
https://dev.mysql.com/doc/refman/8.4/en/user-names.html
|
||||||
|
|
||||||
# npmjs links can occasionally trigger rate limiting during high-frequency CI builds
|
# npmjs links can occasionally trigger rate limiting during high-frequency CI builds
|
||||||
|
https://www.npmjs.com/package/@toolbox-sdk/server
|
||||||
https://www.npmjs.com/package/@toolbox-sdk/core
|
https://www.npmjs.com/package/@toolbox-sdk/core
|
||||||
https://www.npmjs.com/package/@toolbox-sdk/adk
|
https://www.npmjs.com/package/@toolbox-sdk/adk
|
||||||
https://www.oceanbase.com/
|
https://www.oceanbase.com/
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ export async function main() {
|
|||||||
|
|
||||||
for (const query of queries) {
|
for (const query of queries) {
|
||||||
conversationHistory.push({ role: "user", content: [{ text: query }] });
|
conversationHistory.push({ role: "user", content: [{ text: query }] });
|
||||||
let response = await ai.generate({
|
const response = await ai.generate({
|
||||||
messages: conversationHistory,
|
messages: conversationHistory,
|
||||||
tools: tools,
|
tools: tools,
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -13,12 +13,12 @@ The `invoke` command allows you to invoke tools defined in your configuration di
|
|||||||
|
|
||||||
{{< notice tip >}}
|
{{< notice tip >}}
|
||||||
**Keep configurations minimal:** The `invoke` command initializes *all* resources (sources, tools, etc.) defined in your configuration files during execution. To ensure fast response times, consider using a minimal configuration file containing only the tools you need for the specific invocation.
|
**Keep configurations minimal:** The `invoke` command initializes *all* resources (sources, tools, etc.) defined in your configuration files during execution. To ensure fast response times, consider using a minimal configuration file containing only the tools you need for the specific invocation.
|
||||||
{{< /notice >}}
|
{{< notice tip >}}
|
||||||
|
|
||||||
## Before you begin
|
## Prerequisites
|
||||||
|
|
||||||
1. Make sure you have the `toolbox` binary installed or built.
|
- You have the `toolbox` binary installed or built.
|
||||||
2. Make sure you have a valid tool configuration file (e.g., `tools.yaml`).
|
- You have a valid tool configuration file (e.g., `tools.yaml`).
|
||||||
|
|
||||||
## Basic Usage
|
## Basic Usage
|
||||||
|
|
||||||
|
|||||||
@@ -414,10 +414,10 @@ See [Usage Examples](../reference/cli.md#examples).
|
|||||||
entries.
|
entries.
|
||||||
* **Dataplex Editor** (`roles/dataplex.editor`) to modify entries.
|
* **Dataplex Editor** (`roles/dataplex.editor`) to modify entries.
|
||||||
* **Tools:**
|
* **Tools:**
|
||||||
* `search_entries`: Searches for entries in Dataplex Catalog.
|
* `dataplex_search_entries`: Searches for entries in Dataplex Catalog.
|
||||||
* `lookup_entry`: Retrieves a specific entry from Dataplex
|
* `dataplex_lookup_entry`: Retrieves a specific entry from Dataplex
|
||||||
Catalog.
|
Catalog.
|
||||||
* `search_aspect_types`: Finds aspect types relevant to the
|
* `dataplex_search_aspect_types`: Finds aspect types relevant to the
|
||||||
query.
|
query.
|
||||||
|
|
||||||
## Firestore
|
## Firestore
|
||||||
|
|||||||
@@ -1,8 +0,0 @@
|
|||||||
---
|
|
||||||
title: "Cloud Logging Admin"
|
|
||||||
linkTitle: "Cloud Logging Admin"
|
|
||||||
type: docs
|
|
||||||
weight: 1
|
|
||||||
description: >
|
|
||||||
Tools that work with Cloud Logging Admin Sources.
|
|
||||||
---
|
|
||||||
@@ -19,6 +19,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
"github.com/go-chi/chi/v5/middleware"
|
"github.com/go-chi/chi/v5/middleware"
|
||||||
@@ -233,11 +234,10 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
params, err := parameters.ParseParams(tool.GetParameters(), data, claimsFromAuth)
|
params, err := parameters.ParseParams(tool.GetParameters(), data, claimsFromAuth)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// If auth error, return 401 or 403
|
// If auth error, return 401
|
||||||
var clientServerErr *util.ClientServerError
|
if errors.Is(err, util.ErrUnauthorized) {
|
||||||
if errors.As(err, &clientServerErr) && (clientServerErr.Code == http.StatusUnauthorized || clientServerErr.Code == http.StatusForbidden) {
|
|
||||||
s.logger.DebugContext(ctx, fmt.Sprintf("error parsing authenticated parameters from ID token: %s", err))
|
s.logger.DebugContext(ctx, fmt.Sprintf("error parsing authenticated parameters from ID token: %s", err))
|
||||||
_ = render.Render(w, r, newErrResponse(err, clientServerErr.Code))
|
_ = render.Render(w, r, newErrResponse(err, http.StatusUnauthorized))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = fmt.Errorf("provided parameters were invalid: %w", err)
|
err = fmt.Errorf("provided parameters were invalid: %w", err)
|
||||||
@@ -259,49 +259,34 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
// Determine what error to return to the users.
|
// Determine what error to return to the users.
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var tbErr util.ToolboxError
|
errStr := err.Error()
|
||||||
|
var statusCode int
|
||||||
|
|
||||||
if errors.As(err, &tbErr) {
|
// Upstream API auth error propagation
|
||||||
switch tbErr.Category() {
|
switch {
|
||||||
case util.CategoryAgent:
|
case strings.Contains(errStr, "Error 401"):
|
||||||
// Agent Errors -> 200 OK
|
statusCode = http.StatusUnauthorized
|
||||||
s.logger.DebugContext(ctx, fmt.Sprintf("Tool invocation agent error: %v", err))
|
case strings.Contains(errStr, "Error 403"):
|
||||||
_ = render.Render(w, r, newErrResponse(err, http.StatusOK))
|
statusCode = http.StatusForbidden
|
||||||
return
|
}
|
||||||
|
|
||||||
case util.CategoryServer:
|
if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden {
|
||||||
// Server Errors -> Check the specific code inside
|
if clientAuth {
|
||||||
var clientServerErr *util.ClientServerError
|
// Propagate the original 401/403 error.
|
||||||
statusCode := http.StatusInternalServerError // Default to 500
|
s.logger.DebugContext(ctx, fmt.Sprintf("error invoking tool. Client credentials lack authorization to the source: %v", err))
|
||||||
|
|
||||||
if errors.As(err, &clientServerErr) {
|
|
||||||
if clientServerErr.Code != 0 {
|
|
||||||
statusCode = clientServerErr.Code
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process auth error
|
|
||||||
if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden {
|
|
||||||
if clientAuth {
|
|
||||||
// Token error, pass through 401/403
|
|
||||||
s.logger.DebugContext(ctx, fmt.Sprintf("Client credentials lack authorization: %v", err))
|
|
||||||
_ = render.Render(w, r, newErrResponse(err, statusCode))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// ADC/Config error, return 500
|
|
||||||
statusCode = http.StatusInternalServerError
|
|
||||||
}
|
|
||||||
|
|
||||||
s.logger.ErrorContext(ctx, fmt.Sprintf("Tool invocation server error: %v", err))
|
|
||||||
_ = render.Render(w, r, newErrResponse(err, statusCode))
|
_ = render.Render(w, r, newErrResponse(err, statusCode))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
// ADC lacking permission or credentials configuration error.
|
||||||
// Unknown error -> 500
|
internalErr := fmt.Errorf("unexpected auth error occured during Tool invocation: %w", err)
|
||||||
s.logger.ErrorContext(ctx, fmt.Sprintf("Tool invocation unknown error: %v", err))
|
s.logger.ErrorContext(ctx, internalErr.Error())
|
||||||
_ = render.Render(w, r, newErrResponse(err, http.StatusInternalServerError))
|
_ = render.Render(w, r, newErrResponse(internalErr, http.StatusInternalServerError))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
err = fmt.Errorf("error while invoking tool: %w", err)
|
||||||
|
s.logger.DebugContext(ctx, err.Error())
|
||||||
|
_ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest))
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
resMarshal, err := json.Marshal(res)
|
resMarshal, err := json.Marshal(res)
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -443,12 +444,15 @@ func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
|||||||
code := rpcResponse.Error.Code
|
code := rpcResponse.Error.Code
|
||||||
switch code {
|
switch code {
|
||||||
case jsonrpc.INTERNAL_ERROR:
|
case jsonrpc.INTERNAL_ERROR:
|
||||||
// Map Internal RPC Error (-32603) to HTTP 500
|
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
case jsonrpc.INVALID_REQUEST:
|
case jsonrpc.INVALID_REQUEST:
|
||||||
var clientServerErr *util.ClientServerError
|
errStr := err.Error()
|
||||||
if errors.As(err, &clientServerErr) {
|
if errors.Is(err, util.ErrUnauthorized) {
|
||||||
w.WriteHeader(clientServerErr.Code)
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
} else if strings.Contains(errStr, "Error 401") {
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
} else if strings.Contains(errStr, "Error 403") {
|
||||||
|
w.WriteHeader(http.StatusForbidden)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||||
@@ -123,12 +124,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
}
|
}
|
||||||
if clientAuth {
|
if clientAuth {
|
||||||
if accessToken == "" {
|
if accessToken == "" {
|
||||||
errMsg := "missing access token in the 'Authorization' header"
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, errMsg, nil), util.NewClientServerError(
|
|
||||||
errMsg,
|
|
||||||
http.StatusUnauthorized,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -176,11 +172,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
// Check if any of the specified auth services is verified
|
// Check if any of the specified auth services is verified
|
||||||
isAuthorized := tool.Authorized(verifiedAuthServices)
|
isAuthorized := tool.Authorized(verifiedAuthServices)
|
||||||
if !isAuthorized {
|
if !isAuthorized {
|
||||||
err = util.NewClientServerError(
|
err = fmt.Errorf("unauthorized Tool call: Please make sure your specify correct auth headers: %w", util.ErrUnauthorized)
|
||||||
"unauthorized Tool call: Please make sure you specify correct auth headers",
|
|
||||||
http.StatusUnauthorized,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
}
|
}
|
||||||
logger.DebugContext(ctx, "tool invocation authorized")
|
logger.DebugContext(ctx, "tool invocation authorized")
|
||||||
@@ -202,44 +194,30 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
// run tool invocation and generate response.
|
// run tool invocation and generate response.
|
||||||
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var tbErr util.ToolboxError
|
errStr := err.Error()
|
||||||
|
// Missing authService tokens.
|
||||||
if errors.As(err, &tbErr) {
|
if errors.Is(err, util.ErrUnauthorized) {
|
||||||
switch tbErr.Category() {
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
case util.CategoryAgent:
|
}
|
||||||
// MCP - Tool execution error
|
// Upstream auth error
|
||||||
// Return SUCCESS but with IsError: true
|
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
|
||||||
text := TextContent{
|
if clientAuth {
|
||||||
Type: "text",
|
// Error with client credentials should pass down to the client
|
||||||
Text: err.Error(),
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
}
|
|
||||||
return jsonrpc.JSONRPCResponse{
|
|
||||||
Jsonrpc: jsonrpc.JSONRPC_VERSION,
|
|
||||||
Id: id,
|
|
||||||
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
|
||||||
}, nil
|
|
||||||
|
|
||||||
case util.CategoryServer:
|
|
||||||
// MCP Spec - Protocol error
|
|
||||||
// Return JSON-RPC ERROR
|
|
||||||
var clientServerErr *util.ClientServerError
|
|
||||||
rpcCode := jsonrpc.INTERNAL_ERROR // Default to Internal Error (-32603)
|
|
||||||
|
|
||||||
if errors.As(err, &clientServerErr) {
|
|
||||||
if clientServerErr.Code == http.StatusUnauthorized || clientServerErr.Code == http.StatusForbidden {
|
|
||||||
if clientAuth {
|
|
||||||
rpcCode = jsonrpc.INVALID_REQUEST
|
|
||||||
} else {
|
|
||||||
rpcCode = jsonrpc.INTERNAL_ERROR
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return jsonrpc.NewError(id, rpcCode, err.Error(), nil), err
|
|
||||||
}
|
}
|
||||||
} else {
|
// Auth error with ADC should raise internal 500 error
|
||||||
// Unknown error -> 500
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
text := TextContent{
|
||||||
|
Type: "text",
|
||||||
|
Text: err.Error(),
|
||||||
|
}
|
||||||
|
return jsonrpc.JSONRPCResponse{
|
||||||
|
Jsonrpc: jsonrpc.JSONRPC_VERSION,
|
||||||
|
Id: id,
|
||||||
|
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
content := make([]TextContent, 0)
|
content := make([]TextContent, 0)
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||||
@@ -123,12 +124,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
}
|
}
|
||||||
if clientAuth {
|
if clientAuth {
|
||||||
if accessToken == "" {
|
if accessToken == "" {
|
||||||
errMsg := "missing access token in the 'Authorization' header"
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, errMsg, nil), util.NewClientServerError(
|
|
||||||
errMsg,
|
|
||||||
http.StatusUnauthorized,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -176,11 +172,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
// Check if any of the specified auth services is verified
|
// Check if any of the specified auth services is verified
|
||||||
isAuthorized := tool.Authorized(verifiedAuthServices)
|
isAuthorized := tool.Authorized(verifiedAuthServices)
|
||||||
if !isAuthorized {
|
if !isAuthorized {
|
||||||
err = util.NewClientServerError(
|
err = fmt.Errorf("unauthorized Tool call: Please make sure your specify correct auth headers: %w", util.ErrUnauthorized)
|
||||||
"unauthorized Tool call: Please make sure you specify correct auth headers",
|
|
||||||
http.StatusUnauthorized,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
}
|
}
|
||||||
logger.DebugContext(ctx, "tool invocation authorized")
|
logger.DebugContext(ctx, "tool invocation authorized")
|
||||||
@@ -202,45 +194,31 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
// run tool invocation and generate response.
|
// run tool invocation and generate response.
|
||||||
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var tbErr util.ToolboxError
|
errStr := err.Error()
|
||||||
|
// Missing authService tokens.
|
||||||
if errors.As(err, &tbErr) {
|
if errors.Is(err, util.ErrUnauthorized) {
|
||||||
switch tbErr.Category() {
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
case util.CategoryAgent:
|
}
|
||||||
// MCP - Tool execution error
|
// Upstream auth error
|
||||||
// Return SUCCESS but with IsError: true
|
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
|
||||||
text := TextContent{
|
if clientAuth {
|
||||||
Type: "text",
|
// Error with client credentials should pass down to the client
|
||||||
Text: err.Error(),
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
}
|
|
||||||
return jsonrpc.JSONRPCResponse{
|
|
||||||
Jsonrpc: jsonrpc.JSONRPC_VERSION,
|
|
||||||
Id: id,
|
|
||||||
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
|
||||||
}, nil
|
|
||||||
|
|
||||||
case util.CategoryServer:
|
|
||||||
// MCP Spec - Protocol error
|
|
||||||
// Return JSON-RPC ERROR
|
|
||||||
var clientServerErr *util.ClientServerError
|
|
||||||
rpcCode := jsonrpc.INTERNAL_ERROR // Default to Internal Error (-32603)
|
|
||||||
|
|
||||||
if errors.As(err, &clientServerErr) {
|
|
||||||
if clientServerErr.Code == http.StatusUnauthorized || clientServerErr.Code == http.StatusForbidden {
|
|
||||||
if clientAuth {
|
|
||||||
rpcCode = jsonrpc.INVALID_REQUEST
|
|
||||||
} else {
|
|
||||||
rpcCode = jsonrpc.INTERNAL_ERROR
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return jsonrpc.NewError(id, rpcCode, err.Error(), nil), err
|
|
||||||
}
|
}
|
||||||
} else {
|
// Auth error with ADC should raise internal 500 error
|
||||||
// Unknown error -> 500
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||||
}
|
}
|
||||||
|
text := TextContent{
|
||||||
|
Type: "text",
|
||||||
|
Text: err.Error(),
|
||||||
|
}
|
||||||
|
return jsonrpc.JSONRPCResponse{
|
||||||
|
Jsonrpc: jsonrpc.JSONRPC_VERSION,
|
||||||
|
Id: id,
|
||||||
|
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
content := make([]TextContent, 0)
|
content := make([]TextContent, 0)
|
||||||
|
|
||||||
sliceRes, ok := results.([]any)
|
sliceRes, ok := results.([]any)
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||||
@@ -116,12 +117,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
}
|
}
|
||||||
if clientAuth {
|
if clientAuth {
|
||||||
if accessToken == "" {
|
if accessToken == "" {
|
||||||
errMsg := "missing access token in the 'Authorization' header"
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, errMsg, nil), util.NewClientServerError(
|
|
||||||
errMsg,
|
|
||||||
http.StatusUnauthorized,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -169,11 +165,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
// Check if any of the specified auth services is verified
|
// Check if any of the specified auth services is verified
|
||||||
isAuthorized := tool.Authorized(verifiedAuthServices)
|
isAuthorized := tool.Authorized(verifiedAuthServices)
|
||||||
if !isAuthorized {
|
if !isAuthorized {
|
||||||
err = util.NewClientServerError(
|
err = fmt.Errorf("unauthorized Tool call: Please make sure your specify correct auth headers: %w", util.ErrUnauthorized)
|
||||||
"unauthorized Tool call: Please make sure you specify correct auth headers",
|
|
||||||
http.StatusUnauthorized,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
}
|
}
|
||||||
logger.DebugContext(ctx, "tool invocation authorized")
|
logger.DebugContext(ctx, "tool invocation authorized")
|
||||||
@@ -195,44 +187,29 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
// run tool invocation and generate response.
|
// run tool invocation and generate response.
|
||||||
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var tbErr util.ToolboxError
|
errStr := err.Error()
|
||||||
|
// Missing authService tokens.
|
||||||
if errors.As(err, &tbErr) {
|
if errors.Is(err, util.ErrUnauthorized) {
|
||||||
switch tbErr.Category() {
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
case util.CategoryAgent:
|
}
|
||||||
// MCP - Tool execution error
|
// Upstream auth error
|
||||||
// Return SUCCESS but with IsError: true
|
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
|
||||||
text := TextContent{
|
if clientAuth {
|
||||||
Type: "text",
|
// Error with client credentials should pass down to the client
|
||||||
Text: err.Error(),
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
}
|
|
||||||
return jsonrpc.JSONRPCResponse{
|
|
||||||
Jsonrpc: jsonrpc.JSONRPC_VERSION,
|
|
||||||
Id: id,
|
|
||||||
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
|
||||||
}, nil
|
|
||||||
|
|
||||||
case util.CategoryServer:
|
|
||||||
// MCP Spec - Protocol error
|
|
||||||
// Return JSON-RPC ERROR
|
|
||||||
var clientServerErr *util.ClientServerError
|
|
||||||
rpcCode := jsonrpc.INTERNAL_ERROR // Default to Internal Error (-32603)
|
|
||||||
|
|
||||||
if errors.As(err, &clientServerErr) {
|
|
||||||
if clientServerErr.Code == http.StatusUnauthorized || clientServerErr.Code == http.StatusForbidden {
|
|
||||||
if clientAuth {
|
|
||||||
rpcCode = jsonrpc.INVALID_REQUEST
|
|
||||||
} else {
|
|
||||||
rpcCode = jsonrpc.INTERNAL_ERROR
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return jsonrpc.NewError(id, rpcCode, err.Error(), nil), err
|
|
||||||
}
|
}
|
||||||
} else {
|
// Auth error with ADC should raise internal 500 error
|
||||||
// Unknown error -> 500
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||||
}
|
}
|
||||||
|
text := TextContent{
|
||||||
|
Type: "text",
|
||||||
|
Text: err.Error(),
|
||||||
|
}
|
||||||
|
return jsonrpc.JSONRPCResponse{
|
||||||
|
Jsonrpc: jsonrpc.JSONRPC_VERSION,
|
||||||
|
Id: id,
|
||||||
|
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
content := make([]TextContent, 0)
|
content := make([]TextContent, 0)
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||||
@@ -116,12 +117,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
}
|
}
|
||||||
if clientAuth {
|
if clientAuth {
|
||||||
if accessToken == "" {
|
if accessToken == "" {
|
||||||
errMsg := "missing access token in the 'Authorization' header"
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, errMsg, nil), util.NewClientServerError(
|
|
||||||
errMsg,
|
|
||||||
http.StatusUnauthorized,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -169,11 +165,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
// Check if any of the specified auth services is verified
|
// Check if any of the specified auth services is verified
|
||||||
isAuthorized := tool.Authorized(verifiedAuthServices)
|
isAuthorized := tool.Authorized(verifiedAuthServices)
|
||||||
if !isAuthorized {
|
if !isAuthorized {
|
||||||
err = util.NewClientServerError(
|
err = fmt.Errorf("unauthorized Tool call: Please make sure your specify correct auth headers: %w", util.ErrUnauthorized)
|
||||||
"unauthorized Tool call: Please make sure you specify correct auth headers",
|
|
||||||
http.StatusUnauthorized,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
}
|
}
|
||||||
logger.DebugContext(ctx, "tool invocation authorized")
|
logger.DebugContext(ctx, "tool invocation authorized")
|
||||||
@@ -195,44 +187,29 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
// run tool invocation and generate response.
|
// run tool invocation and generate response.
|
||||||
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var tbErr util.ToolboxError
|
errStr := err.Error()
|
||||||
|
// Missing authService tokens.
|
||||||
if errors.As(err, &tbErr) {
|
if errors.Is(err, util.ErrUnauthorized) {
|
||||||
switch tbErr.Category() {
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
case util.CategoryAgent:
|
}
|
||||||
// MCP - Tool execution error
|
// Upstream auth error
|
||||||
// Return SUCCESS but with IsError: true
|
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
|
||||||
text := TextContent{
|
if clientAuth {
|
||||||
Type: "text",
|
// Error with client credentials should pass down to the client
|
||||||
Text: err.Error(),
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
}
|
|
||||||
return jsonrpc.JSONRPCResponse{
|
|
||||||
Jsonrpc: jsonrpc.JSONRPC_VERSION,
|
|
||||||
Id: id,
|
|
||||||
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
|
||||||
}, nil
|
|
||||||
|
|
||||||
case util.CategoryServer:
|
|
||||||
// MCP Spec - Protocol error
|
|
||||||
// Return JSON-RPC ERROR
|
|
||||||
var clientServerErr *util.ClientServerError
|
|
||||||
rpcCode := jsonrpc.INTERNAL_ERROR // Default to Internal Error (-32603)
|
|
||||||
|
|
||||||
if errors.As(err, &clientServerErr) {
|
|
||||||
if clientServerErr.Code == http.StatusUnauthorized || clientServerErr.Code == http.StatusForbidden {
|
|
||||||
if clientAuth {
|
|
||||||
rpcCode = jsonrpc.INVALID_REQUEST
|
|
||||||
} else {
|
|
||||||
rpcCode = jsonrpc.INTERNAL_ERROR
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return jsonrpc.NewError(id, rpcCode, err.Error(), nil), err
|
|
||||||
}
|
}
|
||||||
} else {
|
// Auth error with ADC should raise internal 500 error
|
||||||
// Unknown error -> 500
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||||
}
|
}
|
||||||
|
text := TextContent{
|
||||||
|
Type: "text",
|
||||||
|
Text: err.Error(),
|
||||||
|
}
|
||||||
|
return jsonrpc.JSONRPCResponse{
|
||||||
|
Jsonrpc: jsonrpc.JSONRPC_VERSION,
|
||||||
|
Id: id,
|
||||||
|
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
content := make([]TextContent, 0)
|
content := make([]TextContent, 0)
|
||||||
|
|||||||
@@ -231,7 +231,7 @@ func TestMcpEndpointWithoutInitialized(t *testing.T) {
|
|||||||
"id": "tools-call-tool4",
|
"id": "tools-call-tool4",
|
||||||
"error": map[string]any{
|
"error": map[string]any{
|
||||||
"code": -32600.0,
|
"code": -32600.0,
|
||||||
"message": "unauthorized Tool call: Please make sure you specify correct auth headers: <nil>",
|
"message": "unauthorized Tool call: Please make sure your specify correct auth headers: unauthorized",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -834,7 +834,7 @@ func TestMcpEndpoint(t *testing.T) {
|
|||||||
"id": "tools-call-tool4",
|
"id": "tools-call-tool4",
|
||||||
"error": map[string]any{
|
"error": map[string]any{
|
||||||
"code": -32600.0,
|
"code": -32600.0,
|
||||||
"message": "unauthorized Tool call: Please make sure you specify correct auth headers: <nil>",
|
"message": "unauthorized Tool call: Please make sure your specify correct auth headers: unauthorized",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -184,7 +184,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
if source.UseClientAuthorization() {
|
if source.UseClientAuthorization() {
|
||||||
// Use client-side access token
|
// Use client-side access token
|
||||||
if accessToken == "" {
|
if accessToken == "" {
|
||||||
return nil, util.NewClientServerError("tool is configured for client OAuth but no token was provided in the request header", http.StatusUnauthorized, nil)
|
return nil, fmt.Errorf("tool is configured for client OAuth but no token was provided in the request header: %w", util.ErrUnauthorized)
|
||||||
}
|
}
|
||||||
tokenStr, err = accessToken.ParseBearerToken()
|
tokenStr, err = accessToken.ParseBearerToken()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ package tools
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -81,7 +80,7 @@ type AccessToken string
|
|||||||
func (token AccessToken) ParseBearerToken() (string, error) {
|
func (token AccessToken) ParseBearerToken() (string, error) {
|
||||||
headerParts := strings.Split(string(token), " ")
|
headerParts := strings.Split(string(token), " ")
|
||||||
if len(headerParts) != 2 || strings.ToLower(headerParts[0]) != "bearer" {
|
if len(headerParts) != 2 || strings.ToLower(headerParts[0]) != "bearer" {
|
||||||
return "", util.NewClientServerError("authorization header must be in the format 'Bearer <token>'", http.StatusUnauthorized, nil)
|
return "", fmt.Errorf("authorization header must be in the format 'Bearer <token>': %w", util.ErrUnauthorized)
|
||||||
}
|
}
|
||||||
return headerParts[1], nil
|
return headerParts[1], nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,77 +0,0 @@
|
|||||||
// Copyright 2026 Google LLC
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
package util
|
|
||||||
|
|
||||||
import "fmt"
|
|
||||||
|
|
||||||
type ErrorCategory string
|
|
||||||
|
|
||||||
const (
|
|
||||||
CategoryAgent ErrorCategory = "AGENT_ERROR"
|
|
||||||
CategoryServer ErrorCategory = "SERVER_ERROR"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ToolboxError is the interface all custom errors must satisfy
|
|
||||||
type ToolboxError interface {
|
|
||||||
error
|
|
||||||
Category() ErrorCategory
|
|
||||||
Error() string
|
|
||||||
Unwrap() error
|
|
||||||
}
|
|
||||||
|
|
||||||
// Agent Errors return 200 to the sender
|
|
||||||
type AgentError struct {
|
|
||||||
Msg string
|
|
||||||
Cause error
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ ToolboxError = &AgentError{}
|
|
||||||
|
|
||||||
func (e *AgentError) Error() string {
|
|
||||||
if e.Cause != nil {
|
|
||||||
return fmt.Sprintf("%s: %v", e.Msg, e.Cause)
|
|
||||||
}
|
|
||||||
return e.Msg
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *AgentError) Category() ErrorCategory { return CategoryAgent }
|
|
||||||
|
|
||||||
func (e *AgentError) Unwrap() error { return e.Cause }
|
|
||||||
|
|
||||||
func NewAgentError(msg string, cause error) *AgentError {
|
|
||||||
return &AgentError{Msg: msg, Cause: cause}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClientServerError returns 4XX/5XX error code
|
|
||||||
type ClientServerError struct {
|
|
||||||
Msg string
|
|
||||||
Code int
|
|
||||||
Cause error
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ ToolboxError = &ClientServerError{}
|
|
||||||
|
|
||||||
func (e *ClientServerError) Error() string {
|
|
||||||
if e.Cause != nil {
|
|
||||||
return fmt.Sprintf("%s: %v", e.Msg, e.Cause)
|
|
||||||
}
|
|
||||||
return e.Msg
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *ClientServerError) Category() ErrorCategory { return CategoryServer }
|
|
||||||
|
|
||||||
func (e *ClientServerError) Unwrap() error { return e.Cause }
|
|
||||||
|
|
||||||
func NewClientServerError(msg string, code int, cause error) *ClientServerError {
|
|
||||||
return &ClientServerError{Msg: msg, Code: code, Cause: cause}
|
|
||||||
}
|
|
||||||
@@ -19,7 +19,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
"slices"
|
"slices"
|
||||||
@@ -119,7 +118,7 @@ func parseFromAuthService(paramAuthServices []ParamAuthService, claimsMap map[st
|
|||||||
}
|
}
|
||||||
return v, nil
|
return v, nil
|
||||||
}
|
}
|
||||||
return nil, util.NewClientServerError("missing or invalid authentication header", http.StatusUnauthorized, nil)
|
return nil, fmt.Errorf("missing or invalid authentication header: %w", util.ErrUnauthorized)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CheckParamRequired checks if a parameter is required based on the required and default field.
|
// CheckParamRequired checks if a parameter is required based on the required and default field.
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -187,3 +188,5 @@ func InstrumentationFromContext(ctx context.Context) (*telemetry.Instrumentation
|
|||||||
}
|
}
|
||||||
return nil, fmt.Errorf("unable to retrieve instrumentation")
|
return nil, fmt.Errorf("unable to retrieve instrumentation")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var ErrUnauthorized = errors.New("unauthorized")
|
||||||
|
|||||||
@@ -139,13 +139,24 @@ func TestAlloyDBPgToolEndpoints(t *testing.T) {
|
|||||||
|
|
||||||
// set up data for param tool
|
// set up data for param tool
|
||||||
createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := tests.GetPostgresSQLParamToolInfo(tableNameParam)
|
createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := tests.GetPostgresSQLParamToolInfo(tableNameParam)
|
||||||
teardownTable1 := tests.SetupPostgresSQLTable(t, ctx, pool, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams)
|
teardownTable1, err := tests.SetupPostgresSQLTable(t, ctx, pool, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams)
|
||||||
defer teardownTable1(t)
|
if teardownTable1 != nil {
|
||||||
|
defer teardownTable1(t)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Setup failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
// set up data for auth tool
|
// set up data for auth tool
|
||||||
createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := tests.GetPostgresSQLAuthToolInfo(tableNameAuth)
|
createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := tests.GetPostgresSQLAuthToolInfo(tableNameAuth)
|
||||||
teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
|
|
||||||
defer teardownTable2(t)
|
teardownTable2, err := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
|
||||||
|
if teardownTable2 != nil {
|
||||||
|
defer teardownTable2(t)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Setup failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Set up table for semanti search
|
// Set up table for semanti search
|
||||||
vectorTableName, tearDownVectorTable := tests.SetupPostgresVectorTable(t, ctx, pool)
|
vectorTableName, tearDownVectorTable := tests.SetupPostgresVectorTable(t, ctx, pool)
|
||||||
|
|||||||
@@ -84,7 +84,6 @@ func TestBigQueryToolEndpoints(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to create Cloud SQL connection pool: %s", err)
|
t.Fatalf("unable to create Cloud SQL connection pool: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// create table name with UUID
|
// create table name with UUID
|
||||||
datasetName := fmt.Sprintf("temp_toolbox_test_%s", strings.ReplaceAll(uuid.New().String(), "-", ""))
|
datasetName := fmt.Sprintf("temp_toolbox_test_%s", strings.ReplaceAll(uuid.New().String(), "-", ""))
|
||||||
tableName := fmt.Sprintf("param_table_%s", strings.ReplaceAll(uuid.New().String(), "-", ""))
|
tableName := fmt.Sprintf("param_table_%s", strings.ReplaceAll(uuid.New().String(), "-", ""))
|
||||||
@@ -122,27 +121,42 @@ func TestBigQueryToolEndpoints(t *testing.T) {
|
|||||||
|
|
||||||
// set up data for param tool
|
// set up data for param tool
|
||||||
createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := getBigQueryParamToolInfo(tableNameParam)
|
createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := getBigQueryParamToolInfo(tableNameParam)
|
||||||
teardownTable1 := setupBigQueryTable(t, ctx, client, createParamTableStmt, insertParamTableStmt, datasetName, tableNameParam, paramTestParams)
|
teardownTable1, err := setupBigQueryTable(t, ctx, client, createParamTableStmt, insertParamTableStmt, datasetName, tableNameParam, paramTestParams)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to setup param table: %s", err)
|
||||||
|
}
|
||||||
defer teardownTable1(t)
|
defer teardownTable1(t)
|
||||||
|
|
||||||
// set up data for auth tool
|
// set up data for auth tool
|
||||||
createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := getBigQueryAuthToolInfo(tableNameAuth)
|
createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := getBigQueryAuthToolInfo(tableNameAuth)
|
||||||
teardownTable2 := setupBigQueryTable(t, ctx, client, createAuthTableStmt, insertAuthTableStmt, datasetName, tableNameAuth, authTestParams)
|
teardownTable2, err := setupBigQueryTable(t, ctx, client, createAuthTableStmt, insertAuthTableStmt, datasetName, tableNameAuth, authTestParams)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to setup auth table: %s", err)
|
||||||
|
}
|
||||||
defer teardownTable2(t)
|
defer teardownTable2(t)
|
||||||
|
|
||||||
// set up data for data type test tool
|
// set up data for data type test tool
|
||||||
createDataTypeTableStmt, insertDataTypeTableStmt, dataTypeToolStmt, arrayDataTypeToolStmt, dataTypeTestParams := getBigQueryDataTypeTestInfo(tableNameDataType)
|
createDataTypeTableStmt, insertDataTypeTableStmt, dataTypeToolStmt, arrayDataTypeToolStmt, dataTypeTestParams := getBigQueryDataTypeTestInfo(tableNameDataType)
|
||||||
teardownTable3 := setupBigQueryTable(t, ctx, client, createDataTypeTableStmt, insertDataTypeTableStmt, datasetName, tableNameDataType, dataTypeTestParams)
|
teardownTable3, err := setupBigQueryTable(t, ctx, client, createDataTypeTableStmt, insertDataTypeTableStmt, datasetName, tableNameDataType, dataTypeTestParams)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to setup data type table: %s", err)
|
||||||
|
}
|
||||||
defer teardownTable3(t)
|
defer teardownTable3(t)
|
||||||
|
|
||||||
// set up data for forecast tool
|
// set up data for forecast tool
|
||||||
createForecastTableStmt, insertForecastTableStmt, forecastTestParams := getBigQueryForecastToolInfo(tableNameForecast)
|
createForecastTableStmt, insertForecastTableStmt, forecastTestParams := getBigQueryForecastToolInfo(tableNameForecast)
|
||||||
teardownTable4 := setupBigQueryTable(t, ctx, client, createForecastTableStmt, insertForecastTableStmt, datasetName, tableNameForecast, forecastTestParams)
|
teardownTable4, err := setupBigQueryTable(t, ctx, client, createForecastTableStmt, insertForecastTableStmt, datasetName, tableNameForecast, forecastTestParams)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to setup forecast table: %s", err)
|
||||||
|
}
|
||||||
defer teardownTable4(t)
|
defer teardownTable4(t)
|
||||||
|
|
||||||
// set up data for analyze contribution tool
|
// set up data for analyze contribution tool
|
||||||
createAnalyzeContributionTableStmt, insertAnalyzeContributionTableStmt, analyzeContributionTestParams := getBigQueryAnalyzeContributionToolInfo(tableNameAnalyzeContribution)
|
createAnalyzeContributionTableStmt, insertAnalyzeContributionTableStmt, analyzeContributionTestParams := getBigQueryAnalyzeContributionToolInfo(tableNameAnalyzeContribution)
|
||||||
teardownTable5 := setupBigQueryTable(t, ctx, client, createAnalyzeContributionTableStmt, insertAnalyzeContributionTableStmt, datasetName, tableNameAnalyzeContribution, analyzeContributionTestParams)
|
teardownTable5, err := setupBigQueryTable(t, ctx, client, createAnalyzeContributionTableStmt, insertAnalyzeContributionTableStmt, datasetName, tableNameAnalyzeContribution, analyzeContributionTestParams)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to setup analyze contribution table: %s", err)
|
||||||
|
}
|
||||||
defer teardownTable5(t)
|
defer teardownTable5(t)
|
||||||
|
|
||||||
// Write config into a file and pass it to command
|
// Write config into a file and pass it to command
|
||||||
@@ -231,52 +245,79 @@ func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
|
|||||||
// Setup allowed table
|
// Setup allowed table
|
||||||
allowedTableNameParam1 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName1, allowedTableName1)
|
allowedTableNameParam1 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName1, allowedTableName1)
|
||||||
createAllowedTableStmt1 := fmt.Sprintf("CREATE TABLE %s (id INT64)", allowedTableNameParam1)
|
createAllowedTableStmt1 := fmt.Sprintf("CREATE TABLE %s (id INT64)", allowedTableNameParam1)
|
||||||
teardownAllowed1 := setupBigQueryTable(t, ctx, client, createAllowedTableStmt1, "", allowedDatasetName1, allowedTableNameParam1, nil)
|
teardownAllowed1, err:= setupBigQueryTable(t, ctx, client, createAllowedTableStmt1, "", allowedDatasetName1, allowedTableNameParam1, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to setup allowed table 1: %s", err)
|
||||||
|
}
|
||||||
defer teardownAllowed1(t)
|
defer teardownAllowed1(t)
|
||||||
|
|
||||||
allowedTableNameParam2 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName2, allowedTableName2)
|
allowedTableNameParam2 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName2, allowedTableName2)
|
||||||
createAllowedTableStmt2 := fmt.Sprintf("CREATE TABLE %s (id INT64)", allowedTableNameParam2)
|
createAllowedTableStmt2 := fmt.Sprintf("CREATE TABLE %s (id INT64)", allowedTableNameParam2)
|
||||||
teardownAllowed2 := setupBigQueryTable(t, ctx, client, createAllowedTableStmt2, "", allowedDatasetName2, allowedTableNameParam2, nil)
|
teardownAllowed2, err:= setupBigQueryTable(t, ctx, client, createAllowedTableStmt2, "", allowedDatasetName2, allowedTableNameParam2, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to setup allowed table 2: %s", err)
|
||||||
|
}
|
||||||
defer teardownAllowed2(t)
|
defer teardownAllowed2(t)
|
||||||
|
|
||||||
// Setup allowed forecast table
|
// Setup allowed forecast table
|
||||||
allowedForecastTableFullName1 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName1, allowedForecastTableName1)
|
allowedForecastTableFullName1 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName1, allowedForecastTableName1)
|
||||||
createForecastStmt1, insertForecastStmt1, forecastParams1 := getBigQueryForecastToolInfo(allowedForecastTableFullName1)
|
createForecastStmt1, insertForecastStmt1, forecastParams1 := getBigQueryForecastToolInfo(allowedForecastTableFullName1)
|
||||||
teardownAllowedForecast1 := setupBigQueryTable(t, ctx, client, createForecastStmt1, insertForecastStmt1, allowedDatasetName1, allowedForecastTableFullName1, forecastParams1)
|
teardownAllowedForecast1, err:= setupBigQueryTable(t, ctx, client, createForecastStmt1, insertForecastStmt1, allowedDatasetName1, allowedForecastTableFullName1, forecastParams1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to setup allowed forecast table 1: %s", err)
|
||||||
|
}
|
||||||
defer teardownAllowedForecast1(t)
|
defer teardownAllowedForecast1(t)
|
||||||
|
|
||||||
allowedForecastTableFullName2 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName2, allowedForecastTableName2)
|
allowedForecastTableFullName2 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName2, allowedForecastTableName2)
|
||||||
createForecastStmt2, insertForecastStmt2, forecastParams2 := getBigQueryForecastToolInfo(allowedForecastTableFullName2)
|
createForecastStmt2, insertForecastStmt2, forecastParams2 := getBigQueryForecastToolInfo(allowedForecastTableFullName2)
|
||||||
teardownAllowedForecast2 := setupBigQueryTable(t, ctx, client, createForecastStmt2, insertForecastStmt2, allowedDatasetName2, allowedForecastTableFullName2, forecastParams2)
|
teardownAllowedForecast2, err:= setupBigQueryTable(t, ctx, client, createForecastStmt2, insertForecastStmt2, allowedDatasetName2, allowedForecastTableFullName2, forecastParams2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to setup allowed forecast table 2: %s", err)
|
||||||
|
}
|
||||||
defer teardownAllowedForecast2(t)
|
defer teardownAllowedForecast2(t)
|
||||||
|
|
||||||
// Setup disallowed table
|
// Setup disallowed table
|
||||||
disallowedTableNameParam := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, disallowedDatasetName, disallowedTableName)
|
disallowedTableNameParam := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, disallowedDatasetName, disallowedTableName)
|
||||||
createDisallowedTableStmt := fmt.Sprintf("CREATE TABLE %s (id INT64)", disallowedTableNameParam)
|
createDisallowedTableStmt := fmt.Sprintf("CREATE TABLE %s (id INT64)", disallowedTableNameParam)
|
||||||
teardownDisallowed := setupBigQueryTable(t, ctx, client, createDisallowedTableStmt, "", disallowedDatasetName, disallowedTableNameParam, nil)
|
teardownDisallowed, err:= setupBigQueryTable(t, ctx, client, createDisallowedTableStmt, "", disallowedDatasetName, disallowedTableNameParam, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to setup disallowed table: %s", err)
|
||||||
|
}
|
||||||
defer teardownDisallowed(t)
|
defer teardownDisallowed(t)
|
||||||
|
|
||||||
// Setup disallowed forecast table
|
// Setup disallowed forecast table
|
||||||
disallowedForecastTableFullName := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, disallowedDatasetName, disallowedForecastTableName)
|
disallowedForecastTableFullName := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, disallowedDatasetName, disallowedForecastTableName)
|
||||||
createDisallowedForecastStmt, insertDisallowedForecastStmt, disallowedForecastParams := getBigQueryForecastToolInfo(disallowedForecastTableFullName)
|
createDisallowedForecastStmt, insertDisallowedForecastStmt, disallowedForecastParams := getBigQueryForecastToolInfo(disallowedForecastTableFullName)
|
||||||
teardownDisallowedForecast := setupBigQueryTable(t, ctx, client, createDisallowedForecastStmt, insertDisallowedForecastStmt, disallowedDatasetName, disallowedForecastTableFullName, disallowedForecastParams)
|
teardownDisallowedForecast, err:= setupBigQueryTable(t, ctx, client, createDisallowedForecastStmt, insertDisallowedForecastStmt, disallowedDatasetName, disallowedForecastTableFullName, disallowedForecastParams)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to setup disallowed forecast table: %s", err)
|
||||||
|
}
|
||||||
defer teardownDisallowedForecast(t)
|
defer teardownDisallowedForecast(t)
|
||||||
|
|
||||||
// Setup allowed analyze contribution table
|
// Setup allowed analyze contribution table
|
||||||
allowedAnalyzeContributionTableFullName1 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName1, allowedAnalyzeContributionTableName1)
|
allowedAnalyzeContributionTableFullName1 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName1, allowedAnalyzeContributionTableName1)
|
||||||
createAnalyzeContributionStmt1, insertAnalyzeContributionStmt1, analyzeContributionParams1 := getBigQueryAnalyzeContributionToolInfo(allowedAnalyzeContributionTableFullName1)
|
createAnalyzeContributionStmt1, insertAnalyzeContributionStmt1, analyzeContributionParams1 := getBigQueryAnalyzeContributionToolInfo(allowedAnalyzeContributionTableFullName1)
|
||||||
teardownAllowedAnalyzeContribution1 := setupBigQueryTable(t, ctx, client, createAnalyzeContributionStmt1, insertAnalyzeContributionStmt1, allowedDatasetName1, allowedAnalyzeContributionTableFullName1, analyzeContributionParams1)
|
teardownAllowedAnalyzeContribution1, err:= setupBigQueryTable(t, ctx, client, createAnalyzeContributionStmt1, insertAnalyzeContributionStmt1, allowedDatasetName1, allowedAnalyzeContributionTableFullName1, analyzeContributionParams1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to setup allowed analyze contribution table 1: %s", err)
|
||||||
|
}
|
||||||
defer teardownAllowedAnalyzeContribution1(t)
|
defer teardownAllowedAnalyzeContribution1(t)
|
||||||
|
|
||||||
allowedAnalyzeContributionTableFullName2 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName2, allowedAnalyzeContributionTableName2)
|
allowedAnalyzeContributionTableFullName2 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName2, allowedAnalyzeContributionTableName2)
|
||||||
createAnalyzeContributionStmt2, insertAnalyzeContributionStmt2, analyzeContributionParams2 := getBigQueryAnalyzeContributionToolInfo(allowedAnalyzeContributionTableFullName2)
|
createAnalyzeContributionStmt2, insertAnalyzeContributionStmt2, analyzeContributionParams2 := getBigQueryAnalyzeContributionToolInfo(allowedAnalyzeContributionTableFullName2)
|
||||||
teardownAllowedAnalyzeContribution2 := setupBigQueryTable(t, ctx, client, createAnalyzeContributionStmt2, insertAnalyzeContributionStmt2, allowedDatasetName2, allowedAnalyzeContributionTableFullName2, analyzeContributionParams2)
|
teardownAllowedAnalyzeContribution2, err:= setupBigQueryTable(t, ctx, client, createAnalyzeContributionStmt2, insertAnalyzeContributionStmt2, allowedDatasetName2, allowedAnalyzeContributionTableFullName2, analyzeContributionParams2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to setup allowed analyze contribution table 2: %s", err)
|
||||||
|
}
|
||||||
defer teardownAllowedAnalyzeContribution2(t)
|
defer teardownAllowedAnalyzeContribution2(t)
|
||||||
|
|
||||||
// Setup disallowed analyze contribution table
|
// Setup disallowed analyze contribution table
|
||||||
disallowedAnalyzeContributionTableFullName := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, disallowedDatasetName, disallowedAnalyzeContributionTableName)
|
disallowedAnalyzeContributionTableFullName := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, disallowedDatasetName, disallowedAnalyzeContributionTableName)
|
||||||
createDisallowedAnalyzeContributionStmt, insertDisallowedAnalyzeContributionStmt, disallowedAnalyzeContributionParams := getBigQueryAnalyzeContributionToolInfo(disallowedAnalyzeContributionTableFullName)
|
createDisallowedAnalyzeContributionStmt, insertDisallowedAnalyzeContributionStmt, disallowedAnalyzeContributionParams := getBigQueryAnalyzeContributionToolInfo(disallowedAnalyzeContributionTableFullName)
|
||||||
teardownDisallowedAnalyzeContribution := setupBigQueryTable(t, ctx, client, createDisallowedAnalyzeContributionStmt, insertDisallowedAnalyzeContributionStmt, disallowedDatasetName, disallowedAnalyzeContributionTableFullName, disallowedAnalyzeContributionParams)
|
teardownDisallowedAnalyzeContribution, err:= setupBigQueryTable(t, ctx, client, createDisallowedAnalyzeContributionStmt, insertDisallowedAnalyzeContributionStmt, disallowedDatasetName, disallowedAnalyzeContributionTableFullName, disallowedAnalyzeContributionParams)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to setup disallowed analyze contribution table: %s", err)
|
||||||
|
}
|
||||||
defer teardownDisallowedAnalyzeContribution(t)
|
defer teardownDisallowedAnalyzeContribution(t)
|
||||||
|
|
||||||
// Configure source with dataset restriction.
|
// Configure source with dataset restriction.
|
||||||
@@ -438,7 +479,10 @@ func TestBigQueryWriteModeBlocked(t *testing.T) {
|
|||||||
t.Fatalf("unable to create BigQuery connection: %s", err)
|
t.Fatalf("unable to create BigQuery connection: %s", err)
|
||||||
}
|
}
|
||||||
createParamTableStmt, insertParamTableStmt, _, _, _, _, paramTestParams := getBigQueryParamToolInfo(tableNameParam)
|
createParamTableStmt, insertParamTableStmt, _, _, _, _, paramTestParams := getBigQueryParamToolInfo(tableNameParam)
|
||||||
teardownTable := setupBigQueryTable(t, ctx, client, createParamTableStmt, insertParamTableStmt, datasetName, tableNameParam, paramTestParams)
|
teardownTable ,err:= setupBigQueryTable(t, ctx, client, createParamTableStmt, insertParamTableStmt, datasetName, tableNameParam, paramTestParams)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to setup BigQuery table: %s", err)
|
||||||
|
}
|
||||||
defer teardownTable(t)
|
defer teardownTable(t)
|
||||||
|
|
||||||
toolsFile := map[string]any{
|
toolsFile := map[string]any{
|
||||||
@@ -623,7 +667,7 @@ func getBigQueryTmplToolStatement() (string, string) {
|
|||||||
return tmplSelectCombined, tmplSelectFilterCombined
|
return tmplSelectCombined, tmplSelectFilterCombined
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupBigQueryTable(t *testing.T, ctx context.Context, client *bigqueryapi.Client, createStatement, insertStatement, datasetName string, tableName string, params []bigqueryapi.QueryParameter) func(*testing.T) {
|
func setupBigQueryTable(t *testing.T, ctx context.Context, client *bigqueryapi.Client, createStatement, insertStatement, datasetName string, tableName string, params []bigqueryapi.QueryParameter) (func(*testing.T), error) {
|
||||||
// Create dataset
|
// Create dataset
|
||||||
dataset := client.Dataset(datasetName)
|
dataset := client.Dataset(datasetName)
|
||||||
_, err := dataset.Metadata(ctx)
|
_, err := dataset.Metadata(ctx)
|
||||||
@@ -699,7 +743,7 @@ func setupBigQueryTable(t *testing.T, ctx context.Context, client *bigqueryapi.C
|
|||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
t.Errorf("Failed to list tables in dataset %s to check emptiness: %v.", datasetName, err)
|
t.Errorf("Failed to list tables in dataset %s to check emptiness: %v.", datasetName, err)
|
||||||
}
|
}
|
||||||
}
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func addBigQueryPrebuiltToolsConfig(t *testing.T, config map[string]any) map[string]any {
|
func addBigQueryPrebuiltToolsConfig(t *testing.T, config map[string]any) map[string]any {
|
||||||
|
|||||||
@@ -124,13 +124,23 @@ func TestCloudSQLPgSimpleToolEndpoints(t *testing.T) {
|
|||||||
|
|
||||||
// set up data for param tool
|
// set up data for param tool
|
||||||
createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := tests.GetPostgresSQLParamToolInfo(tableNameParam)
|
createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := tests.GetPostgresSQLParamToolInfo(tableNameParam)
|
||||||
teardownTable1 := tests.SetupPostgresSQLTable(t, ctx, pool, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams)
|
teardownTable1, err := tests.SetupPostgresSQLTable(t, ctx, pool, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams)
|
||||||
defer teardownTable1(t)
|
if teardownTable1 != nil {
|
||||||
|
defer teardownTable1(t)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Setup failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
// set up data for auth tool
|
// set up data for auth tool
|
||||||
createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := tests.GetPostgresSQLAuthToolInfo(tableNameAuth)
|
createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := tests.GetPostgresSQLAuthToolInfo(tableNameAuth)
|
||||||
teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
|
teardownTable2, err := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
|
||||||
defer teardownTable2(t)
|
if teardownTable2 != nil {
|
||||||
|
defer teardownTable2(t)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Setup failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Set up table for semantic search
|
// Set up table for semantic search
|
||||||
vectorTableName, tearDownVectorTable := tests.SetupPostgresVectorTable(t, ctx, pool)
|
vectorTableName, tearDownVectorTable := tests.SetupPostgresVectorTable(t, ctx, pool)
|
||||||
|
|||||||
@@ -613,31 +613,36 @@ func GetMySQLWants() (string, string, string, string) {
|
|||||||
|
|
||||||
// SetupPostgresSQLTable creates and inserts data into a table of tool
|
// SetupPostgresSQLTable creates and inserts data into a table of tool
|
||||||
// compatible with postgres-sql tool
|
// compatible with postgres-sql tool
|
||||||
func SetupPostgresSQLTable(t *testing.T, ctx context.Context, pool *pgxpool.Pool, createStatement, insertStatement, tableName string, params []any) func(*testing.T) {
|
func SetupPostgresSQLTable(t *testing.T, ctx context.Context, pool *pgxpool.Pool, createStatement, insertStatement, tableName string, params []any) (func(*testing.T), error) {
|
||||||
err := pool.Ping(ctx)
|
err := pool.Ping(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to connect to test database: %s", err)
|
// Return nil for the function and the error itself
|
||||||
|
return nil, fmt.Errorf("unable to connect to test database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create table
|
// Create table
|
||||||
_, err = pool.Query(ctx, createStatement)
|
_, err = pool.Exec(ctx, createStatement)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to create test table %s: %s", tableName, err)
|
return nil, fmt.Errorf("unable to create test table %s: %w", tableName, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Insert test data
|
// Insert test data
|
||||||
_, err = pool.Query(ctx, insertStatement, params...)
|
_, err = pool.Exec(ctx, insertStatement, params...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to insert test data: %s", err)
|
// partially cleanup if insert fails
|
||||||
|
teardown := func(t *testing.T) {
|
||||||
|
_, _ = pool.Exec(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s;", tableName))
|
||||||
|
}
|
||||||
|
return teardown, fmt.Errorf("unable to insert test data: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Return the cleanup function and nil for error
|
||||||
return func(t *testing.T) {
|
return func(t *testing.T) {
|
||||||
// tear down test
|
_, err = pool.Exec(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s;", tableName))
|
||||||
_, err = pool.Exec(ctx, fmt.Sprintf("DROP TABLE %s;", tableName))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Teardown failed: %s", err)
|
t.Errorf("Teardown failed: %s", err)
|
||||||
}
|
}
|
||||||
}
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetupMsSQLTable creates and inserts data into a table of tool
|
// SetupMsSQLTable creates and inserts data into a table of tool
|
||||||
|
|||||||
@@ -89,12 +89,18 @@ func TestOracleSimpleToolEndpoints(t *testing.T) {
|
|||||||
|
|
||||||
// set up data for param tool
|
// set up data for param tool
|
||||||
createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := getOracleParamToolInfo(tableNameParam)
|
createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := getOracleParamToolInfo(tableNameParam)
|
||||||
teardownTable1 := setupOracleTable(t, ctx, db, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams)
|
teardownTable1, err := setupOracleTable(t, ctx, db, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to setup Oracle table %s: %v", tableNameParam, err)
|
||||||
|
}
|
||||||
defer teardownTable1(t)
|
defer teardownTable1(t)
|
||||||
|
|
||||||
// set up data for auth tool
|
// set up data for auth tool
|
||||||
createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := getOracleAuthToolInfo(tableNameAuth)
|
createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := getOracleAuthToolInfo(tableNameAuth)
|
||||||
teardownTable2 := setupOracleTable(t, ctx, db, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
|
teardownTable2, err := setupOracleTable(t, ctx, db, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to setup Oracle table %s: %v", tableNameAuth, err)
|
||||||
|
}
|
||||||
defer teardownTable2(t)
|
defer teardownTable2(t)
|
||||||
|
|
||||||
// Write config into a file and pass it to command
|
// Write config into a file and pass it to command
|
||||||
@@ -135,31 +141,31 @@ func TestOracleSimpleToolEndpoints(t *testing.T) {
|
|||||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam)
|
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam)
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupOracleTable(t *testing.T, ctx context.Context, pool *sql.DB, createStatement, insertStatement, tableName string, params []any) func(*testing.T) {
|
func setupOracleTable(t *testing.T, ctx context.Context, pool *sql.DB, createStatement, insertStatement, tableName string, params []any) (func(*testing.T), error) {
|
||||||
err := pool.PingContext(ctx)
|
err := pool.PingContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to connect to test database: %s", err)
|
return nil, fmt.Errorf("unable to connect to test database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create table
|
// Create table
|
||||||
_, err = pool.QueryContext(ctx, createStatement)
|
_, err = pool.QueryContext(ctx, createStatement)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to create test table %s: %s", tableName, err)
|
return nil, fmt.Errorf("unable to create test table %s: %w", tableName, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Insert test data
|
// Insert test data
|
||||||
_, err = pool.QueryContext(ctx, insertStatement, params...)
|
_, err = pool.QueryContext(ctx, insertStatement, params...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to insert test data: %s", err)
|
return nil, fmt.Errorf("unable to insert test data: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return func(t *testing.T) {
|
return func(t *testing.T) {
|
||||||
// tear down test
|
// tear down test
|
||||||
_, err = pool.ExecContext(ctx, fmt.Sprintf("DROP TABLE %s", tableName))
|
_, err = pool.ExecContext(ctx, fmt.Sprintf("DROP TABLE %s CASCADE CONSTRAINTS", tableName))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Teardown failed: %s", err)
|
t.Errorf("Teardown failed: %s", err)
|
||||||
}
|
}
|
||||||
}
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getOracleParamToolInfo(tableName string) (string, string, string, string, string, string, []any) {
|
func getOracleParamToolInfo(tableName string) (string, string, string, string, string, string, []any) {
|
||||||
|
|||||||
@@ -103,13 +103,24 @@ func TestPostgres(t *testing.T) {
|
|||||||
|
|
||||||
// set up data for param tool
|
// set up data for param tool
|
||||||
createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := tests.GetPostgresSQLParamToolInfo(tableNameParam)
|
createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := tests.GetPostgresSQLParamToolInfo(tableNameParam)
|
||||||
teardownTable1 := tests.SetupPostgresSQLTable(t, ctx, pool, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams)
|
// teardownTable1 := tests.SetupPostgresSQLTable(t, ctx, pool, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams)
|
||||||
defer teardownTable1(t)
|
teardownTable1, err := tests.SetupPostgresSQLTable(t, ctx, pool, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams)
|
||||||
|
if teardownTable1 != nil {
|
||||||
|
defer teardownTable1(t)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Setup failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
// set up data for auth tool
|
// set up data for auth tool
|
||||||
createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := tests.GetPostgresSQLAuthToolInfo(tableNameAuth)
|
createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := tests.GetPostgresSQLAuthToolInfo(tableNameAuth)
|
||||||
teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
|
teardownTable2, err := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
|
||||||
defer teardownTable2(t)
|
if teardownTable2 != nil {
|
||||||
|
defer teardownTable2(t)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Setup failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Set up table for semantic search
|
// Set up table for semantic search
|
||||||
vectorTableName, tearDownVectorTable := tests.SetupPostgresVectorTable(t, ctx, pool)
|
vectorTableName, tearDownVectorTable := tests.SetupPostgresVectorTable(t, ctx, pool)
|
||||||
|
|||||||
@@ -115,23 +115,35 @@ func TestSpannerToolEndpoints(t *testing.T) {
|
|||||||
SpannerInstance,
|
SpannerInstance,
|
||||||
SpannerDatabase,
|
SpannerDatabase,
|
||||||
)
|
)
|
||||||
teardownTable1 := setupSpannerTable(t, ctx, adminClient, dataClient, createParamTableStmt, insertParamTableStmt, tableNameParam, dbString, paramTestParams)
|
teardownTable1, err := setupSpannerTable(t, ctx, adminClient, dataClient, createParamTableStmt, insertParamTableStmt, tableNameParam, dbString, paramTestParams)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to setup Spanner table %s: %v", tableNameParam, err)
|
||||||
|
}
|
||||||
defer teardownTable1(t)
|
defer teardownTable1(t)
|
||||||
|
|
||||||
// set up data for auth tool
|
// set up data for auth tool
|
||||||
createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := getSpannerAuthToolInfo(tableNameAuth)
|
createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := getSpannerAuthToolInfo(tableNameAuth)
|
||||||
teardownTable2 := setupSpannerTable(t, ctx, adminClient, dataClient, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, dbString, authTestParams)
|
teardownTable2, err := setupSpannerTable(t, ctx, adminClient, dataClient, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, dbString, authTestParams)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to setup Spanner table %s: %v", tableNameAuth, err)
|
||||||
|
}
|
||||||
defer teardownTable2(t)
|
defer teardownTable2(t)
|
||||||
|
|
||||||
// set up data for template param tool
|
// set up data for template param tool
|
||||||
createStatementTmpl := fmt.Sprintf("CREATE TABLE %s (id INT64, name STRING(MAX), age INT64) PRIMARY KEY (id)", tableNameTemplateParam)
|
createStatementTmpl := fmt.Sprintf("CREATE TABLE %s (id INT64, name STRING(MAX), age INT64) PRIMARY KEY (id)", tableNameTemplateParam)
|
||||||
teardownTableTmpl := setupSpannerTable(t, ctx, adminClient, dataClient, createStatementTmpl, "", tableNameTemplateParam, dbString, nil)
|
teardownTableTmpl, err := setupSpannerTable(t, ctx, adminClient, dataClient, createStatementTmpl, "", tableNameTemplateParam, dbString, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to setup Spanner table %s: %v", tableNameTemplateParam, err)
|
||||||
|
}
|
||||||
defer teardownTableTmpl(t)
|
defer teardownTableTmpl(t)
|
||||||
|
|
||||||
// set up for graph tool
|
// set up for graph tool
|
||||||
nodeTableName := "node_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
nodeTableName := "node_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||||
createNodeStatementTmpl := fmt.Sprintf("CREATE TABLE %s (id INT64 NOT NULL) PRIMARY KEY (id)", nodeTableName)
|
createNodeStatementTmpl := fmt.Sprintf("CREATE TABLE %s (id INT64 NOT NULL) PRIMARY KEY (id)", nodeTableName)
|
||||||
teardownNodeTableTmpl := setupSpannerTable(t, ctx, adminClient, dataClient, createNodeStatementTmpl, "", nodeTableName, dbString, nil)
|
teardownNodeTableTmpl, err := setupSpannerTable(t, ctx, adminClient, dataClient, createNodeStatementTmpl, "", nodeTableName, dbString, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to setup Spanner table %s: %v", nodeTableName, err)
|
||||||
|
}
|
||||||
defer teardownNodeTableTmpl(t)
|
defer teardownNodeTableTmpl(t)
|
||||||
|
|
||||||
edgeTableName := "edge_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
edgeTableName := "edge_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||||
@@ -143,7 +155,10 @@ func TestSpannerToolEndpoints(t *testing.T) {
|
|||||||
) PRIMARY KEY (id, target_id),
|
) PRIMARY KEY (id, target_id),
|
||||||
INTERLEAVE IN PARENT %[2]s ON DELETE CASCADE
|
INTERLEAVE IN PARENT %[2]s ON DELETE CASCADE
|
||||||
`, edgeTableName, nodeTableName)
|
`, edgeTableName, nodeTableName)
|
||||||
teardownEdgeTableTmpl := setupSpannerTable(t, ctx, adminClient, dataClient, createEdgeStatementTmpl, "", edgeTableName, dbString, nil)
|
teardownEdgeTableTmpl, err := setupSpannerTable(t, ctx, adminClient, dataClient, createEdgeStatementTmpl, "", edgeTableName, dbString, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to setup Spanner table %s: %v", edgeTableName, err)
|
||||||
|
}
|
||||||
defer teardownEdgeTableTmpl(t)
|
defer teardownEdgeTableTmpl(t)
|
||||||
|
|
||||||
graphName := "graph_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
graphName := "graph_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||||
@@ -243,7 +258,7 @@ func getSpannerAuthToolInfo(tableName string) (string, string, string, map[strin
|
|||||||
|
|
||||||
// setupSpannerTable creates and inserts data into a table of tool
|
// setupSpannerTable creates and inserts data into a table of tool
|
||||||
// compatible with spanner-sql tool
|
// compatible with spanner-sql tool
|
||||||
func setupSpannerTable(t *testing.T, ctx context.Context, adminClient *database.DatabaseAdminClient, dataClient *spanner.Client, createStatement, insertStatement, tableName, dbString string, params map[string]any) func(*testing.T) {
|
func setupSpannerTable(t *testing.T, ctx context.Context, adminClient *database.DatabaseAdminClient, dataClient *spanner.Client, createStatement, insertStatement, tableName, dbString string, params map[string]any) (func(*testing.T), error) {
|
||||||
|
|
||||||
// Create table
|
// Create table
|
||||||
op, err := adminClient.UpdateDatabaseDdl(ctx, &databasepb.UpdateDatabaseDdlRequest{
|
op, err := adminClient.UpdateDatabaseDdl(ctx, &databasepb.UpdateDatabaseDdlRequest{
|
||||||
@@ -251,11 +266,11 @@ func setupSpannerTable(t *testing.T, ctx context.Context, adminClient *database.
|
|||||||
Statements: []string{createStatement},
|
Statements: []string{createStatement},
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to start create table operation %s: %s", tableName, err)
|
return nil, fmt.Errorf("unable to start create table operation %s: %w", tableName, err)
|
||||||
}
|
}
|
||||||
err = op.Wait(ctx)
|
err = op.Wait(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to create test table %s: %s", tableName, err)
|
return nil, fmt.Errorf("unable to create test table %s: %w", tableName, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Insert test data
|
// Insert test data
|
||||||
@@ -269,7 +284,7 @@ func setupSpannerTable(t *testing.T, ctx context.Context, adminClient *database.
|
|||||||
return err
|
return err
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to insert test data: %s", err)
|
return nil, fmt.Errorf("unable to insert test data: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -288,7 +303,7 @@ func setupSpannerTable(t *testing.T, ctx context.Context, adminClient *database.
|
|||||||
if opErr != nil {
|
if opErr != nil {
|
||||||
t.Errorf("Teardown failed: %s", opErr)
|
t.Errorf("Teardown failed: %s", opErr)
|
||||||
}
|
}
|
||||||
}
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// setupSpannerGraph creates a graph and inserts data into it.
|
// setupSpannerGraph creates a graph and inserts data into it.
|
||||||
|
|||||||
Reference in New Issue
Block a user