mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-12 00:49:08 -05:00
Compare commits
1 Commits
invoke-too
...
debug-opti
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4416b14ee6 |
1
.github/workflows/docs_preview_deploy.yaml
vendored
1
.github/workflows/docs_preview_deploy.yaml
vendored
@@ -51,7 +51,6 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.sha }}
|
||||
fetch-depth: 0 # Fetch all history for .GitInfo and .Lastmod
|
||||
|
||||
- name: Setup Hugo
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
@import 'td/code-dark';
|
||||
@@ -48,11 +48,7 @@ enableRobotsTXT = true
|
||||
pre = "<i class='fa-brands fa-github'></i>"
|
||||
|
||||
[markup.goldmark.renderer]
|
||||
unsafe= true
|
||||
|
||||
[markup.highlight]
|
||||
noClasses = false
|
||||
style = "tango"
|
||||
unsafe= true
|
||||
|
||||
[outputFormats]
|
||||
[outputFormats.LLMS]
|
||||
|
||||
32
CHANGELOG.md
32
CHANGELOG.md
@@ -1,37 +1,5 @@
|
||||
# Changelog
|
||||
|
||||
## [0.8.0](https://github.com/googleapis/genai-toolbox/compare/v0.7.0...v0.8.0) (2025-07-02)
|
||||
|
||||
|
||||
### ⚠ BREAKING CHANGES
|
||||
|
||||
* **postgres,mssql,cloudsqlmssql:** encode source connection url for sources ([#727](https://github.com/googleapis/genai-toolbox/issues/727))
|
||||
|
||||
### Features
|
||||
|
||||
* Add support for multiple YAML configuration files ([#760](https://github.com/googleapis/genai-toolbox/issues/760)) ([40679d7](https://github.com/googleapis/genai-toolbox/commit/40679d700eded50d19569923e2a71c51e907a8bf))
|
||||
* Add support for optional parameters ([#617](https://github.com/googleapis/genai-toolbox/issues/617)) ([4827771](https://github.com/googleapis/genai-toolbox/commit/4827771b78dee9a1284a898b749509b472061527)), closes [#475](https://github.com/googleapis/genai-toolbox/issues/475)
|
||||
* **mcp:** Support MCP version 2025-03-26 ([#755](https://github.com/googleapis/genai-toolbox/issues/755)) ([474df57](https://github.com/googleapis/genai-toolbox/commit/474df57d62de683079f8d12c31db53396a545fd1))
|
||||
* **sources/http:** Support disable SSL verification for HTTP Source ([#674](https://github.com/googleapis/genai-toolbox/issues/674)) ([4055b0c](https://github.com/googleapis/genai-toolbox/commit/4055b0c3569c527560d7ad34262963b3dd4e282d))
|
||||
* **tools/bigquery:** Add templateParameters field for bigquery ([#699](https://github.com/googleapis/genai-toolbox/issues/699)) ([f5f771b](https://github.com/googleapis/genai-toolbox/commit/f5f771b0f3d159630ff602ff55c6c66b61981446))
|
||||
* **tools/bigtable:** Add templateParameters field for bigtable ([#692](https://github.com/googleapis/genai-toolbox/issues/692)) ([1c06771](https://github.com/googleapis/genai-toolbox/commit/1c067715fac06479eb0060d7067b73dba099ed92))
|
||||
* **tools/couchbase:** Add templateParameters field for couchbase ([#723](https://github.com/googleapis/genai-toolbox/issues/723)) ([9197186](https://github.com/googleapis/genai-toolbox/commit/9197186b8bea1ac4ec1b39c9c5c110807c8b2ba9))
|
||||
* **tools/http:** Add support for HTTP Tool pathParams ([#726](https://github.com/googleapis/genai-toolbox/issues/726)) ([fd300dc](https://github.com/googleapis/genai-toolbox/commit/fd300dc606d88bf9f7bba689e2cee4e3565537dd))
|
||||
* **tools/redis:** Add Redis Source and Tool ([#519](https://github.com/googleapis/genai-toolbox/issues/519)) ([f0aef29](https://github.com/googleapis/genai-toolbox/commit/f0aef29b0c2563e2a00277fbe2784f39f16d2835))
|
||||
* **tools/spanner:** Add templateParameters field for spanner ([#691](https://github.com/googleapis/genai-toolbox/issues/691)) ([075dfa4](https://github.com/googleapis/genai-toolbox/commit/075dfa47e1fd92be4847bd0aec63296146b66455))
|
||||
* **tools/sqlitesql:** Add templateParameters field for sqlitesql ([#687](https://github.com/googleapis/genai-toolbox/issues/687)) ([75e254c](https://github.com/googleapis/genai-toolbox/commit/75e254c0a4ce690ca5fa4d1741550ce54734b226))
|
||||
* **tools/valkey:** Add Valkey Source and Tool ([#532](https://github.com/googleapis/genai-toolbox/issues/532)) ([054ec19](https://github.com/googleapis/genai-toolbox/commit/054ec198b97ba9f36f67dd12b2eff0cc6bc4d080))
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* **bigquery,mssql:** Fix panic on tools with array param ([#722](https://github.com/googleapis/genai-toolbox/issues/722)) ([7a6644c](https://github.com/googleapis/genai-toolbox/commit/7a6644cf0c5413e5c803955c88a2cfd0a2233ed3))
|
||||
* **postgres,mssql,cloudsqlmssql:** Encode source connection url for sources ([#727](https://github.com/googleapis/genai-toolbox/issues/727)) ([67964d9](https://github.com/googleapis/genai-toolbox/commit/67964d939f27320b63b5759f4b3f3fdaa0c76fbf)), closes [#717](https://github.com/googleapis/genai-toolbox/issues/717)
|
||||
* Set default value to field's type during unmarshalling ([#774](https://github.com/googleapis/genai-toolbox/issues/774)) ([fafed24](https://github.com/googleapis/genai-toolbox/commit/fafed2485839cf1acc1350e8a24103d2e6356ee0)), closes [#771](https://github.com/googleapis/genai-toolbox/issues/771)
|
||||
* **server/mcp:** Do not listen from port for stdio ([#719](https://github.com/googleapis/genai-toolbox/issues/719)) ([d51dbc7](https://github.com/googleapis/genai-toolbox/commit/d51dbc759ba493021d3ec6f5417fc04c21f7044f)), closes [#711](https://github.com/googleapis/genai-toolbox/issues/711)
|
||||
* **tools/mysqlexecutesql:** Handle nil panic and connection leak in Invoke ([#757](https://github.com/googleapis/genai-toolbox/issues/757)) ([7badba4](https://github.com/googleapis/genai-toolbox/commit/7badba42eefb34252be77b852a57d6bd78dd267d))
|
||||
* **tools/mysqlsql:** Handle nil panic and connection leak in invoke ([#758](https://github.com/googleapis/genai-toolbox/issues/758)) ([cbb4a33](https://github.com/googleapis/genai-toolbox/commit/cbb4a333517313744800d148840312e56340f3fd))
|
||||
|
||||
## [0.7.0](https://github.com/googleapis/genai-toolbox/compare/v0.6.0...v0.7.0) (2025-06-10)
|
||||
|
||||
|
||||
|
||||
161
README.md
161
README.md
@@ -111,7 +111,7 @@ To install Toolbox as a binary:
|
||||
<!-- {x-release-please-start-version} -->
|
||||
```sh
|
||||
# see releases page for other versions
|
||||
export VERSION=0.8.0
|
||||
export VERSION=0.7.0
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v$VERSION/linux/amd64/toolbox
|
||||
chmod +x toolbox
|
||||
```
|
||||
@@ -124,7 +124,7 @@ You can also install Toolbox as a container:
|
||||
|
||||
```sh
|
||||
# see releases page for other versions
|
||||
export VERSION=0.8.0
|
||||
export VERSION=0.7.0
|
||||
docker pull us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:$VERSION
|
||||
```
|
||||
|
||||
@@ -137,7 +137,7 @@ To install from source, ensure you have the latest version of
|
||||
[Go installed](https://go.dev/doc/install), and then run the following command:
|
||||
|
||||
```sh
|
||||
go install github.com/googleapis/genai-toolbox@v0.8.0
|
||||
go install github.com/googleapis/genai-toolbox@v0.7.0
|
||||
```
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
@@ -151,8 +151,6 @@ execute `toolbox` to start the server:
|
||||
```sh
|
||||
./toolbox --tools-file "tools.yaml"
|
||||
```
|
||||
> [!NOTE]
|
||||
> Toolbox enables dynamic reloading by default. To disable, use the `--disable-reload` flag.
|
||||
|
||||
You can use `toolbox help` for a full list of flags! To stop the server, send a
|
||||
terminate signal (`ctrl+c` on most platforms).
|
||||
@@ -167,12 +165,7 @@ Once your server is up and running, you can load the tools into your
|
||||
application. See below the list of Client SDKs for using various frameworks:
|
||||
|
||||
<details open>
|
||||
<summary>Python (<a href="https://github.com/googleapis/mcp-toolbox-sdk-python">Github</a>)</summary>
|
||||
<br>
|
||||
<blockquote>
|
||||
|
||||
<details open>
|
||||
<summary>Core</summary>
|
||||
<summary>Core</summary>
|
||||
|
||||
1. Install [Toolbox Core SDK][toolbox-core]:
|
||||
|
||||
@@ -198,9 +191,9 @@ For more detailed instructions on using the Toolbox Core SDK, see the
|
||||
[toolbox-core]: https://pypi.org/project/toolbox-core/
|
||||
[toolbox-core-readme]: https://github.com/googleapis/mcp-toolbox-sdk-python/tree/main/packages/toolbox-core/README.md
|
||||
|
||||
</details>
|
||||
<details>
|
||||
<summary>LangChain / LangGraph</summary>
|
||||
</details>
|
||||
<details>
|
||||
<summary>LangChain / LangGraph</summary>
|
||||
|
||||
1. Install [Toolbox LangChain SDK][toolbox-langchain]:
|
||||
|
||||
@@ -220,15 +213,16 @@ For more detailed instructions on using the Toolbox Core SDK, see the
|
||||
tools = client.load_toolset()
|
||||
```
|
||||
|
||||
For more detailed instructions on using the Toolbox LangChain SDK, see the
|
||||
[project's README][toolbox-langchain-readme].
|
||||
For more detailed instructions on using the Toolbox LangChain SDK, see the
|
||||
[project's README][toolbox-langchain-readme].
|
||||
|
||||
[toolbox-langchain]: https://pypi.org/project/toolbox-langchain/
|
||||
[toolbox-langchain-readme]: https://github.com/googleapis/mcp-toolbox-sdk-python/blob/main/packages/toolbox-langchain/README.md
|
||||
[toolbox-langchain]: https://pypi.org/project/toolbox-langchain/
|
||||
[toolbox-langchain-readme]: https://github.com/googleapis/mcp-toolbox-sdk-python/blob/main/packages/toolbox-langchain/README.md
|
||||
|
||||
</details>
|
||||
<details>
|
||||
<summary>LlamaIndex</summary>
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>LlamaIndex</summary>
|
||||
|
||||
1. Install [Toolbox Llamaindex SDK][toolbox-llamaindex]:
|
||||
|
||||
@@ -248,129 +242,12 @@ For more detailed instructions on using the Toolbox Core SDK, see the
|
||||
tools = client.load_toolset()
|
||||
```
|
||||
|
||||
For more detailed instructions on using the Toolbox Llamaindex SDK, see the
|
||||
[project's README][toolbox-llamaindex-readme].
|
||||
For more detailed instructions on using the Toolbox Llamaindex SDK, see the
|
||||
[project's README][toolbox-llamaindex-readme].
|
||||
|
||||
[toolbox-llamaindex]: https://pypi.org/project/toolbox-llamaindex/
|
||||
[toolbox-llamaindex-readme]: https://github.com/googleapis/genai-toolbox-llamaindex-python/blob/main/README.md
|
||||
[toolbox-llamaindex]: https://pypi.org/project/toolbox-llamaindex/
|
||||
[toolbox-llamaindex-readme]: https://github.com/googleapis/genai-toolbox-llamaindex-python/blob/main/README.md
|
||||
|
||||
</details>
|
||||
</details>
|
||||
</blockquote>
|
||||
<details>
|
||||
<summary>Javascript/Typescript (<a href="https://github.com/googleapis/mcp-toolbox-sdk-js">Github</a>)</summary>
|
||||
<br>
|
||||
<blockquote>
|
||||
|
||||
<details open>
|
||||
<summary>Core</summary>
|
||||
|
||||
1. Install [Toolbox Core SDK][toolbox-core-js]:
|
||||
|
||||
```bash
|
||||
npm install @toolbox-sdk/core
|
||||
```
|
||||
|
||||
1. Load tools:
|
||||
|
||||
```javascript
|
||||
import { ToolboxClient } from '@toolbox-sdk/core';
|
||||
|
||||
// update the url to point to your server
|
||||
const URL = 'http://127.0.0.1:5000';
|
||||
let client = new ToolboxClient(URL);
|
||||
|
||||
// these tools can be passed to your application!
|
||||
const tools = await client.loadToolset('toolsetName');
|
||||
```
|
||||
|
||||
For more detailed instructions on using the Toolbox Core SDK, see the
|
||||
[project's README][toolbox-core-js-readme].
|
||||
|
||||
[toolbox-core-js]: https://www.npmjs.com/package/@toolbox-sdk/core
|
||||
[toolbox-core-js-readme]: https://github.com/googleapis/mcp-toolbox-sdk-js/blob/main/packages/toolbox-core/README.md
|
||||
|
||||
</details>
|
||||
<details>
|
||||
<summary>LangChain / LangGraph</summary>
|
||||
|
||||
1. Install [Toolbox Core SDK][toolbox-core-js]:
|
||||
|
||||
```bash
|
||||
npm install @toolbox-sdk/core
|
||||
```
|
||||
|
||||
2. Load tools:
|
||||
|
||||
```javascript
|
||||
import { ToolboxClient } from '@toolbox-sdk/core';
|
||||
|
||||
// update the url to point to your server
|
||||
const URL = 'http://127.0.0.1:5000';
|
||||
let client = new ToolboxClient(URL);
|
||||
|
||||
// these tools can be passed to your application!
|
||||
const toolboxTools = await client.loadToolset('toolsetName');
|
||||
|
||||
// Define the basics of the tool: name, description, schema and core logic
|
||||
const getTool = (toolboxTool) => tool(currTool, {
|
||||
name: toolboxTool.getName(),
|
||||
description: toolboxTool.getDescription(),
|
||||
schema: toolboxTool.getParamSchema()
|
||||
});
|
||||
|
||||
// Use these tools in your Langchain/Langraph applications
|
||||
const tools = toolboxTools.map(getTool);
|
||||
```
|
||||
|
||||
</details>
|
||||
<details>
|
||||
<summary>Genkit</summary>
|
||||
|
||||
1. Install [Toolbox Core SDK][toolbox-core-js]:
|
||||
|
||||
```bash
|
||||
npm install @toolbox-sdk/core
|
||||
```
|
||||
|
||||
2. Load tools:
|
||||
|
||||
```javascript
|
||||
import { ToolboxClient } from '@toolbox-sdk/core';
|
||||
import { genkit } from 'genkit';
|
||||
|
||||
// Initialise genkit
|
||||
const ai = genkit({
|
||||
plugins: [
|
||||
googleAI({
|
||||
apiKey: process.env.GEMINI_API_KEY || process.env.GOOGLE_API_KEY
|
||||
})
|
||||
],
|
||||
model: googleAI.model('gemini-2.0-flash'),
|
||||
});
|
||||
|
||||
// update the url to point to your server
|
||||
const URL = 'http://127.0.0.1:5000';
|
||||
let client = new ToolboxClient(URL);
|
||||
|
||||
// these tools can be passed to your application!
|
||||
const toolboxTools = await client.loadToolset('toolsetName');
|
||||
|
||||
// Define the basics of the tool: name, description, schema and core logic
|
||||
const getTool = (toolboxTool) => ai.defineTool({
|
||||
name: toolboxTool.getName(),
|
||||
description: toolboxTool.getDescription(),
|
||||
schema: toolboxTool.getParamSchema()
|
||||
}, toolboxTool)
|
||||
|
||||
// Use these tools in your Genkit applications
|
||||
const tools = toolboxTools.map(getTool);
|
||||
```
|
||||
|
||||
</details>
|
||||
</details>
|
||||
</blockquote>
|
||||
|
||||
</details>
|
||||
|
||||
## Configuration
|
||||
|
||||
237
cmd/root.go
237
cmd/root.go
@@ -19,26 +19,20 @@ import (
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"io"
|
||||
"maps"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/auth"
|
||||
"github.com/googleapis/genai-toolbox/internal/log"
|
||||
"github.com/googleapis/genai-toolbox/internal/prebuiltconfigs"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/telemetry"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
|
||||
// Import tool packages for side effect of registration
|
||||
@@ -184,8 +178,6 @@ func NewCommand(opts ...Option) *Command {
|
||||
flags.StringVar(&cmd.cfg.TelemetryServiceName, "telemetry-service-name", "toolbox", "Sets the value of the service.name resource attribute for telemetry data.")
|
||||
flags.StringVar(&cmd.prebuiltConfig, "prebuilt", "", "Use a prebuilt tool configuration by source type. Cannot be used with --tools-file. Allowed: 'alloydb-postgres', 'bigquery', 'cloud-sql-mysql', 'cloud-sql-postgres', 'cloud-sql-mssql', 'postgres', 'spanner', 'spanner-postgres'.")
|
||||
flags.BoolVar(&cmd.cfg.Stdio, "stdio", false, "Listens via MCP STDIO instead of acting as a remote HTTP server.")
|
||||
flags.BoolVar(&cmd.cfg.DisableReload, "disable-reload", false, "Disables dynamic reloading of tools file.")
|
||||
flags.BoolVar(&cmd.cfg.UI, "ui", false, "Launches the Toolbox UI web server.")
|
||||
|
||||
// wrap RunE command so that we have access to original Command object
|
||||
cmd.RunE = func(*cobra.Command, []string) error { return run(cmd) }
|
||||
@@ -355,7 +347,7 @@ func loadAndMergeToolsFolder(ctx context.Context, folderPath string) (ToolsFile,
|
||||
|
||||
// Combine both file lists
|
||||
allFiles := append(yamlFiles, ymlFiles...)
|
||||
|
||||
|
||||
if len(allFiles) == 0 {
|
||||
return ToolsFile{}, fmt.Errorf("no YAML files found in directory %q", folderPath)
|
||||
}
|
||||
@@ -364,177 +356,6 @@ func loadAndMergeToolsFolder(ctx context.Context, folderPath string) (ToolsFile,
|
||||
return loadAndMergeToolsFiles(ctx, allFiles)
|
||||
}
|
||||
|
||||
func handleDynamicReload(ctx context.Context, toolsFile ToolsFile, s *server.Server) error {
|
||||
logger, err := util.LoggerFromContext(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
sourcesMap, authServicesMap, toolsMap, toolsetsMap, err := validateReloadEdits(ctx, toolsFile)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("unable to validate reloaded edits: %w", err)
|
||||
logger.WarnContext(ctx, errMsg.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
s.ResourceMgr.SetResources(sourcesMap, authServicesMap, toolsMap, toolsetsMap)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateReloadEdits checks that the reloaded tools file configs can initialized without failing
|
||||
func validateReloadEdits(
|
||||
ctx context.Context, toolsFile ToolsFile,
|
||||
) (map[string]sources.Source, map[string]auth.AuthService, map[string]tools.Tool, map[string]tools.Toolset, error,
|
||||
) {
|
||||
logger, err := util.LoggerFromContext(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
instrumentation, err := util.InstrumentationFromContext(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
logger.DebugContext(ctx, "Attempting to parse and validate reloaded tools file.")
|
||||
|
||||
ctx, span := instrumentation.Tracer.Start(ctx, "toolbox/server/reload")
|
||||
defer span.End()
|
||||
|
||||
reloadedConfig := server.ServerConfig{
|
||||
Version: versionString,
|
||||
SourceConfigs: toolsFile.Sources,
|
||||
AuthServiceConfigs: toolsFile.AuthServices,
|
||||
ToolConfigs: toolsFile.Tools,
|
||||
ToolsetConfigs: toolsFile.Toolsets,
|
||||
}
|
||||
|
||||
sourcesMap, authServicesMap, toolsMap, toolsetsMap, err := server.InitializeConfigs(ctx, reloadedConfig)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("unable to initialize reloaded configs: %w", err)
|
||||
logger.WarnContext(ctx, errMsg.Error())
|
||||
return nil, nil, nil, nil, err
|
||||
}
|
||||
|
||||
return sourcesMap, authServicesMap, toolsMap, toolsetsMap, nil
|
||||
}
|
||||
|
||||
// watchChanges checks for changes in the provided yaml tools file(s) or folder.
|
||||
func watchChanges(ctx context.Context, watchDirs map[string]bool, watchedFiles map[string]bool, s *server.Server) {
|
||||
logger, err := util.LoggerFromContext(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
w, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
logger.WarnContext(ctx, "error setting up new watcher %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
defer w.Close()
|
||||
|
||||
watchingFolder := false
|
||||
var folderToWatch string
|
||||
|
||||
// if watchedFiles is empty, indicates that user passed entire folder instead
|
||||
if len(watchedFiles) == 0 {
|
||||
watchingFolder = true
|
||||
|
||||
// validate that watchDirs only has single element
|
||||
if len(watchDirs) > 1 {
|
||||
logger.WarnContext(ctx, "error setting watcher, expected single tools folder if no file(s) are defined.")
|
||||
return
|
||||
}
|
||||
|
||||
for onlyKey := range watchDirs {
|
||||
folderToWatch = onlyKey
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
for dir := range watchDirs {
|
||||
err := w.Add(dir)
|
||||
if err != nil {
|
||||
logger.WarnContext(ctx, fmt.Sprintf("Error adding path %s to watcher: %s", dir, err))
|
||||
break
|
||||
}
|
||||
logger.DebugContext(ctx, fmt.Sprintf("Added directory %s to watcher.", dir))
|
||||
}
|
||||
|
||||
// debounce timer is used to prevent multiple writes triggering multiple reloads
|
||||
debounceDelay := 100 * time.Millisecond
|
||||
debounce := time.NewTimer(1 * time.Minute)
|
||||
debounce.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.DebugContext(ctx, "file watcher context cancelled")
|
||||
return
|
||||
case err, ok := <-w.Errors:
|
||||
if !ok {
|
||||
logger.WarnContext(ctx, "file watcher was closed unexpectedly")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
logger.WarnContext(ctx, "file watcher error %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
case e, ok := <-w.Events:
|
||||
if !ok {
|
||||
logger.WarnContext(ctx, "file watcher already closed")
|
||||
return
|
||||
}
|
||||
|
||||
// only check for write events which indicate user saved a new tools file
|
||||
if !e.Has(fsnotify.Write) {
|
||||
continue
|
||||
}
|
||||
|
||||
cleanedFilename := filepath.Clean(e.Name)
|
||||
logger.DebugContext(ctx, fmt.Sprintf("WRITE event detected in %s", cleanedFilename))
|
||||
|
||||
folderChanged := watchingFolder &&
|
||||
(strings.HasSuffix(cleanedFilename, ".yaml") || strings.HasSuffix(cleanedFilename, ".yml"))
|
||||
|
||||
if folderChanged || watchedFiles[cleanedFilename] {
|
||||
// indicates the write event is on a relevant file
|
||||
debounce.Reset(debounceDelay)
|
||||
}
|
||||
|
||||
case <-debounce.C:
|
||||
debounce.Stop()
|
||||
var reloadedToolsFile ToolsFile
|
||||
|
||||
if watchingFolder {
|
||||
logger.DebugContext(ctx, "Reloading tools folder.")
|
||||
reloadedToolsFile, err = loadAndMergeToolsFolder(ctx, folderToWatch)
|
||||
if err != nil {
|
||||
logger.WarnContext(ctx, "error loading tools folder %s", err)
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
logger.DebugContext(ctx, "Reloading tools file(s).")
|
||||
reloadedToolsFile, err = loadAndMergeToolsFiles(ctx, slices.Collect(maps.Keys(watchedFiles)))
|
||||
if err != nil {
|
||||
logger.WarnContext(ctx, "error loading tools files %s", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
err = handleDynamicReload(ctx, reloadedToolsFile, s)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("unable to parse reloaded tools file at %q: %w", reloadedToolsFile, err)
|
||||
logger.WarnContext(ctx, errMsg.Error())
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// updateLogLevel checks if Toolbox have to update the existing log level set by users.
|
||||
// stdio doesn't support "debug" and "info" logs.
|
||||
func updateLogLevel(stdio bool, logLevel string) bool {
|
||||
@@ -549,33 +370,6 @@ func updateLogLevel(stdio bool, logLevel string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func resolveWatcherInputs(toolsFile string, toolsFiles []string, toolsFolder string) (map[string]bool, map[string]bool) {
|
||||
var relevantFiles []string
|
||||
|
||||
// map for efficiently checking if a file is relevant
|
||||
watchedFiles := make(map[string]bool)
|
||||
|
||||
// dirs that will be added to watcher (fsnotify prefers watching directory then filtering for file)
|
||||
watchDirs := make(map[string]bool)
|
||||
|
||||
if len(toolsFiles) > 0 {
|
||||
relevantFiles = toolsFiles
|
||||
} else if toolsFolder != "" {
|
||||
watchDirs[filepath.Clean(toolsFolder)] = true
|
||||
} else {
|
||||
relevantFiles = []string{toolsFile}
|
||||
}
|
||||
|
||||
// extract parent dir for relevant files and dedup
|
||||
for _, f := range relevantFiles {
|
||||
cleanFile := filepath.Clean(f)
|
||||
watchedFiles[cleanFile] = true
|
||||
watchDirs[filepath.Dir(cleanFile)] = true
|
||||
}
|
||||
|
||||
return watchDirs, watchedFiles
|
||||
}
|
||||
|
||||
func run(cmd *Command) error {
|
||||
if updateLogLevel(cmd.cfg.Stdio, cmd.cfg.LogLevel.String()) {
|
||||
cmd.cfg.LogLevel = server.StringLevel(log.Warn)
|
||||
@@ -672,7 +466,6 @@ func run(cmd *Command) error {
|
||||
cmd.logger.ErrorContext(ctx, errMsg.Error())
|
||||
return errMsg
|
||||
}
|
||||
|
||||
// Use multiple tools files
|
||||
cmd.logger.InfoContext(ctx, fmt.Sprintf("Loading and merging %d tool configuration files", len(cmd.tools_files)))
|
||||
var err error
|
||||
@@ -688,7 +481,6 @@ func run(cmd *Command) error {
|
||||
cmd.logger.ErrorContext(ctx, errMsg.Error())
|
||||
return errMsg
|
||||
}
|
||||
|
||||
// Use tools folder
|
||||
cmd.logger.InfoContext(ctx, fmt.Sprintf("Loading and merging all YAML files from directory: %s", cmd.tools_folder))
|
||||
var err error
|
||||
@@ -702,7 +494,6 @@ func run(cmd *Command) error {
|
||||
if cmd.tools_file == "" {
|
||||
cmd.tools_file = "tools.yaml"
|
||||
}
|
||||
|
||||
// Read single tool file contents
|
||||
buf, err := os.ReadFile(cmd.tools_file)
|
||||
if err != nil {
|
||||
@@ -725,23 +516,9 @@ func run(cmd *Command) error {
|
||||
cmd.logger.WarnContext(ctx, "`authSources` is deprecated, use `authServices` instead")
|
||||
cmd.cfg.AuthServiceConfigs = authSourceConfigs
|
||||
}
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("unable to parse tool file at %q: %w", cmd.tools_file, err)
|
||||
cmd.logger.ErrorContext(ctx, errMsg.Error())
|
||||
return errMsg
|
||||
}
|
||||
|
||||
instrumentation, err := telemetry.CreateTelemetryInstrumentation(versionString)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("unable to create telemetry instrumentation: %w", err)
|
||||
cmd.logger.ErrorContext(ctx, errMsg.Error())
|
||||
return errMsg
|
||||
}
|
||||
|
||||
ctx = util.WithInstrumentation(ctx, instrumentation)
|
||||
|
||||
// start server
|
||||
s, err := server.NewServer(ctx, cmd.cfg)
|
||||
s, err := server.NewServer(ctx, cmd.cfg, cmd.logger)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("toolbox failed to initialize: %w", err)
|
||||
cmd.logger.ErrorContext(ctx, errMsg.Error())
|
||||
@@ -766,9 +543,6 @@ func run(cmd *Command) error {
|
||||
return errMsg
|
||||
}
|
||||
cmd.logger.InfoContext(ctx, "Server ready to serve!")
|
||||
if cmd.cfg.UI {
|
||||
cmd.logger.InfoContext(ctx, "Toolbox UI is up and running at: http://localhost:5000/ui")
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer close(srvErr)
|
||||
@@ -779,13 +553,6 @@ func run(cmd *Command) error {
|
||||
}()
|
||||
}
|
||||
|
||||
watchDirs, watchedFiles := resolveWatcherInputs(cmd.tools_file, cmd.tools_files, cmd.tools_folder)
|
||||
|
||||
if !cmd.cfg.DisableReload {
|
||||
// start watching the file(s) or folder for changes to trigger dynamic reloading
|
||||
go watchChanges(ctx, watchDirs, watchedFiles, s)
|
||||
}
|
||||
|
||||
// wait for either the server to error out or the command's context to be canceled
|
||||
select {
|
||||
case err := <-srvErr:
|
||||
|
||||
194
cmd/root_test.go
194
cmd/root_test.go
@@ -16,33 +16,23 @@ package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/auth/google"
|
||||
"github.com/googleapis/genai-toolbox/internal/log"
|
||||
"github.com/googleapis/genai-toolbox/internal/prebuiltconfigs"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
cloudsqlpgsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg"
|
||||
httpsrc "github.com/googleapis/genai-toolbox/internal/sources/http"
|
||||
"github.com/googleapis/genai-toolbox/internal/telemetry"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/http"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/postgres/postgressql"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
@@ -184,13 +174,6 @@ func TestServerConfigFlags(t *testing.T) {
|
||||
Stdio: true,
|
||||
}),
|
||||
},
|
||||
{
|
||||
desc: "disable reload",
|
||||
args: []string{"--disable-reload"},
|
||||
want: withDefaults(server.ServerConfig{
|
||||
DisableReload: true,
|
||||
}),
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
@@ -982,183 +965,6 @@ func TestEnvVarReplacement(t *testing.T) {
|
||||
|
||||
}
|
||||
|
||||
// normalizeFilepaths is a helper function to allow same filepath formats for Mac and Windows.
|
||||
// this prevents needing multiple "want" cases for TestResolveWatcherInputs
|
||||
func normalizeFilepaths(m map[string]bool) map[string]bool {
|
||||
newMap := make(map[string]bool)
|
||||
for k, v := range m {
|
||||
newMap[filepath.ToSlash(k)] = v
|
||||
}
|
||||
return newMap
|
||||
}
|
||||
|
||||
func TestResolveWatcherInputs(t *testing.T) {
|
||||
tcs := []struct {
|
||||
description string
|
||||
toolsFile string
|
||||
toolsFiles []string
|
||||
toolsFolder string
|
||||
wantWatchDirs map[string]bool
|
||||
wantWatchedFiles map[string]bool
|
||||
}{
|
||||
{
|
||||
description: "single tools file",
|
||||
toolsFile: "tools_folder/example_tools.yaml",
|
||||
toolsFiles: []string{},
|
||||
toolsFolder: "",
|
||||
wantWatchDirs: map[string]bool{"tools_folder": true},
|
||||
wantWatchedFiles: map[string]bool{"tools_folder/example_tools.yaml": true},
|
||||
},
|
||||
{
|
||||
description: "default tools file (root dir)",
|
||||
toolsFile: "tools.yaml",
|
||||
toolsFiles: []string{},
|
||||
toolsFolder: "",
|
||||
wantWatchDirs: map[string]bool{".": true},
|
||||
wantWatchedFiles: map[string]bool{"tools.yaml": true},
|
||||
},
|
||||
{
|
||||
description: "multiple files in different folders",
|
||||
toolsFile: "",
|
||||
toolsFiles: []string{"tools_folder/example_tools.yaml", "tools_folder2/example_tools.yaml"},
|
||||
toolsFolder: "",
|
||||
wantWatchDirs: map[string]bool{"tools_folder": true, "tools_folder2": true},
|
||||
wantWatchedFiles: map[string]bool{
|
||||
"tools_folder/example_tools.yaml": true,
|
||||
"tools_folder2/example_tools.yaml": true,
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "multiple files in same folder",
|
||||
toolsFile: "",
|
||||
toolsFiles: []string{"tools_folder/example_tools.yaml", "tools_folder/example_tools2.yaml"},
|
||||
toolsFolder: "",
|
||||
wantWatchDirs: map[string]bool{"tools_folder": true},
|
||||
wantWatchedFiles: map[string]bool{
|
||||
"tools_folder/example_tools.yaml": true,
|
||||
"tools_folder/example_tools2.yaml": true,
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "multiple files in different levels",
|
||||
toolsFile: "",
|
||||
toolsFiles: []string{
|
||||
"tools_folder/example_tools.yaml",
|
||||
"tools_folder/special_tools/example_tools2.yaml"},
|
||||
toolsFolder: "",
|
||||
wantWatchDirs: map[string]bool{"tools_folder": true, "tools_folder/special_tools": true},
|
||||
wantWatchedFiles: map[string]bool{
|
||||
"tools_folder/example_tools.yaml": true,
|
||||
"tools_folder/special_tools/example_tools2.yaml": true,
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "tools folder",
|
||||
toolsFile: "",
|
||||
toolsFiles: []string{},
|
||||
toolsFolder: "tools_folder",
|
||||
wantWatchDirs: map[string]bool{"tools_folder": true},
|
||||
wantWatchedFiles: map[string]bool{},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.description, func(t *testing.T) {
|
||||
gotWatchDirs, gotWatchedFiles := resolveWatcherInputs(tc.toolsFile, tc.toolsFiles, tc.toolsFolder)
|
||||
|
||||
normalizedGotWatchDirs := normalizeFilepaths(gotWatchDirs)
|
||||
normalizedGotWatchedFiles := normalizeFilepaths(gotWatchedFiles)
|
||||
|
||||
if diff := cmp.Diff(tc.wantWatchDirs, normalizedGotWatchDirs); diff != "" {
|
||||
t.Errorf("incorrect watchDirs: diff %v", diff)
|
||||
}
|
||||
if diff := cmp.Diff(tc.wantWatchedFiles, normalizedGotWatchedFiles); diff != "" {
|
||||
t.Errorf("incorrect watchedFiles: diff %v", diff)
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// helper function for testing file detection in dynamic reloading
|
||||
func tmpFileWithCleanup(content []byte) (string, func(), error) {
|
||||
f, err := os.CreateTemp("", "*")
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
cleanup := func() { os.Remove(f.Name()) }
|
||||
|
||||
if _, err := f.Write(content); err != nil {
|
||||
cleanup()
|
||||
return "", nil, err
|
||||
}
|
||||
if err := f.Close(); err != nil {
|
||||
cleanup()
|
||||
return "", nil, err
|
||||
}
|
||||
return f.Name(), cleanup, err
|
||||
}
|
||||
|
||||
func TestSingleEdit(t *testing.T) {
|
||||
ctx, cancelCtx := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancelCtx()
|
||||
|
||||
pr, pw := io.Pipe()
|
||||
defer pw.Close()
|
||||
defer pr.Close()
|
||||
|
||||
fileToWatch, cleanup, err := tmpFileWithCleanup([]byte("initial content"))
|
||||
if err != nil {
|
||||
t.Fatalf("error editing tools file %s", err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
logger, err := log.NewStdLogger(pw, pw, "DEBUG")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to setup logger %s", err)
|
||||
}
|
||||
ctx = util.WithLogger(ctx, logger)
|
||||
|
||||
instrumentation, err := telemetry.CreateTelemetryInstrumentation(versionString)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to setup instrumentation %s", err)
|
||||
}
|
||||
ctx = util.WithInstrumentation(ctx, instrumentation)
|
||||
|
||||
mockServer := &server.Server{}
|
||||
|
||||
cleanFileToWatch := filepath.Clean(fileToWatch)
|
||||
watchDir := filepath.Dir(cleanFileToWatch)
|
||||
|
||||
watchedFiles := map[string]bool{cleanFileToWatch: true}
|
||||
watchDirs := map[string]bool{watchDir: true}
|
||||
|
||||
go watchChanges(ctx, watchDirs, watchedFiles, mockServer)
|
||||
|
||||
// escape backslash so regex doesn't fail on windows filepaths
|
||||
regexEscapedPathFile := strings.ReplaceAll(cleanFileToWatch, `\`, `\\\\*\\`)
|
||||
regexEscapedPathFile = path.Clean(regexEscapedPathFile)
|
||||
|
||||
regexEscapedPathDir := strings.ReplaceAll(watchDir, `\`, `\\\\*\\`)
|
||||
regexEscapedPathDir = path.Clean(regexEscapedPathDir)
|
||||
|
||||
begunWatchingDir := regexp.MustCompile(fmt.Sprintf(`DEBUG "Added directory %s to watcher."`, regexEscapedPathDir))
|
||||
_, err = testutils.WaitForString(ctx, begunWatchingDir, pr)
|
||||
if err != nil {
|
||||
t.Fatalf("timeout or error waiting for watcher to start: %s", err)
|
||||
}
|
||||
|
||||
err = os.WriteFile(fileToWatch, []byte("modification"), 0777)
|
||||
if err != nil {
|
||||
t.Fatalf("error writing to file: %v", err)
|
||||
}
|
||||
|
||||
detectedFileChange := regexp.MustCompile(fmt.Sprintf(`DEBUG "WRITE event detected in %s"`, regexEscapedPathFile))
|
||||
_, err = testutils.WaitForString(ctx, detectedFileChange, pr)
|
||||
if err != nil {
|
||||
t.Fatalf("timeout or error waiting for file to detect write: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrebuiltTools(t *testing.T) {
|
||||
alloydb_config, _ := prebuiltconfigs.Get("alloydb-postgres")
|
||||
bigquery_config, _ := prebuiltconfigs.Get("bigquery")
|
||||
|
||||
@@ -1 +1 @@
|
||||
0.8.0
|
||||
0.7.0
|
||||
|
||||
@@ -222,7 +222,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"version = \"0.8.0\" # x-release-please-version\n",
|
||||
"version = \"0.7.0\" # x-release-please-version\n",
|
||||
"! curl -O https://storage.googleapis.com/genai-toolbox/v{version}/linux/amd64/toolbox\n",
|
||||
"\n",
|
||||
"# Make the binary executable\n",
|
||||
|
||||
@@ -86,7 +86,7 @@ To install Toolbox as a binary:
|
||||
|
||||
```sh
|
||||
# see releases page for other versions
|
||||
export VERSION=0.8.0
|
||||
export VERSION=0.7.0
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v$VERSION/linux/amd64/toolbox
|
||||
chmod +x toolbox
|
||||
```
|
||||
@@ -97,7 +97,7 @@ You can also install Toolbox as a container:
|
||||
|
||||
```sh
|
||||
# see releases page for other versions
|
||||
export VERSION=0.8.0
|
||||
export VERSION=0.7.0
|
||||
docker pull us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:$VERSION
|
||||
```
|
||||
|
||||
@@ -108,7 +108,7 @@ To install from source, ensure you have the latest version of
|
||||
[Go installed](https://go.dev/doc/install), and then run the following command:
|
||||
|
||||
```sh
|
||||
go install github.com/googleapis/genai-toolbox@v0.8.0
|
||||
go install github.com/googleapis/genai-toolbox@v0.7.0
|
||||
```
|
||||
|
||||
{{% /tab %}}
|
||||
@@ -123,9 +123,6 @@ execute `toolbox` to start the server:
|
||||
```sh
|
||||
./toolbox --tools-file "tools.yaml"
|
||||
```
|
||||
{{< notice note >}}
|
||||
Toolbox enables dynamic reloading by default. To disable, use the `--disable-reload` flag.
|
||||
{{< /notice >}}
|
||||
|
||||
You can use `toolbox help` for a full list of flags! To stop the server, send a
|
||||
terminate signal (`ctrl+c` on most platforms).
|
||||
@@ -138,7 +135,6 @@ out the resources in the [How-to section](../../how-to/_index.md)
|
||||
Once your server is up and running, you can load the tools into your
|
||||
application. See below the list of Client SDKs for using various frameworks:
|
||||
|
||||
#### Python
|
||||
{{< tabpane text=true persist=header >}}
|
||||
{{% tab header="Core" lang="en" %}}
|
||||
|
||||
@@ -205,115 +201,3 @@ For more detailed instructions on using the Toolbox Llamaindex SDK, see the
|
||||
|
||||
{{% /tab %}}
|
||||
{{< /tabpane >}}
|
||||
|
||||
#### Javascript/Typescript
|
||||
|
||||
Once you've installed the [Toolbox Core
|
||||
SDK](https://www.npmjs.com/package/@toolbox-sdk/core), you can load
|
||||
tools:
|
||||
|
||||
{{< tabpane text=true persist=header >}}
|
||||
{{% tab header="Core" lang="en" %}}
|
||||
|
||||
{{< highlight javascript >}}
|
||||
import { ToolboxClient } from '@toolbox-sdk/core';
|
||||
|
||||
// update the url to point to your server
|
||||
const URL = 'http://127.0.0.1:5000';
|
||||
let client = new ToolboxClient(URL);
|
||||
|
||||
// these tools can be passed to your application!
|
||||
const toolboxTools = await client.loadToolset('toolsetName');
|
||||
{{< /highlight >}}
|
||||
|
||||
{{% /tab %}}
|
||||
{{% tab header="LangChain/Langraph" lang="en" %}}
|
||||
|
||||
{{< highlight javascript >}}
|
||||
import { ToolboxClient } from '@toolbox-sdk/core';
|
||||
|
||||
// update the url to point to your server
|
||||
const URL = 'http://127.0.0.1:5000';
|
||||
let client = new ToolboxClient(URL);
|
||||
|
||||
// these tools can be passed to your application!
|
||||
const toolboxTools = await client.loadToolset('toolsetName');
|
||||
|
||||
// Define the basics of the tool: name, description, schema and core logic
|
||||
const getTool = (toolboxTool) => tool(currTool, {
|
||||
name: toolboxTool.getName(),
|
||||
description: toolboxTool.getDescription(),
|
||||
schema: toolboxTool.getParamSchema()
|
||||
});
|
||||
|
||||
// Use these tools in your Langchain/Langraph applications
|
||||
const tools = toolboxTools.map(getTool);
|
||||
{{< /highlight >}}
|
||||
|
||||
{{% /tab %}}
|
||||
{{% tab header="Genkit" lang="en" %}}
|
||||
|
||||
{{< highlight javascript >}}
|
||||
import { ToolboxClient } from '@toolbox-sdk/core';
|
||||
import { genkit } from 'genkit';
|
||||
|
||||
// Initialise genkit
|
||||
const ai = genkit({
|
||||
plugins: [
|
||||
googleAI({
|
||||
apiKey: process.env.GEMINI_API_KEY || process.env.GOOGLE_API_KEY
|
||||
})
|
||||
],
|
||||
model: googleAI.model('gemini-2.0-flash'),
|
||||
});
|
||||
|
||||
// update the url to point to your server
|
||||
const URL = 'http://127.0.0.1:5000';
|
||||
let client = new ToolboxClient(URL);
|
||||
|
||||
// these tools can be passed to your application!
|
||||
const toolboxTools = await client.loadToolset('toolsetName');
|
||||
|
||||
// Define the basics of the tool: name, description, schema and core logic
|
||||
const getTool = (toolboxTool) => ai.defineTool({
|
||||
name: toolboxTool.getName(),
|
||||
description: toolboxTool.getDescription(),
|
||||
schema: toolboxTool.getParamSchema()
|
||||
}, toolboxTool)
|
||||
|
||||
// Use these tools in your Genkit applications
|
||||
const tools = toolboxTools.map(getTool);
|
||||
{{< /highlight >}}
|
||||
|
||||
{{% /tab %}}
|
||||
{{% tab header="LlamaIndex" lang="en" %}}
|
||||
|
||||
{{< highlight javascript >}}
|
||||
import { ToolboxClient } from '@toolbox-sdk/core';
|
||||
import { tool } from "llamaindex";
|
||||
|
||||
// update the url to point to your server
|
||||
const URL = 'http://127.0.0.1:5000';
|
||||
let client = new ToolboxClient(URL);
|
||||
|
||||
// these tools can be passed to your application!
|
||||
const toolboxTools = await client.loadToolset('toolsetName');
|
||||
|
||||
// Define the basics of the tool: name, description, schema and core logic
|
||||
const getTool = (toolboxTool) => tool({
|
||||
name: toolboxTool.getName(),
|
||||
description: toolboxTool.getDescription(),
|
||||
parameters: toolboxTool.getParams(),
|
||||
execute: toolboxTool
|
||||
});;
|
||||
|
||||
// Use these tools in your LlamaIndex applications
|
||||
const tools = toolboxTools.map(getTool);
|
||||
|
||||
{{< /highlight >}}
|
||||
|
||||
{{% /tab %}}
|
||||
{{< /tabpane >}}
|
||||
|
||||
For more detailed instructions on using the Toolbox Core SDK, see the
|
||||
[project's README](https://github.com/googleapis/mcp-toolbox-sdk-js/blob/main/packages/toolbox-core/README.md).
|
||||
@@ -156,7 +156,7 @@ In this section, we will download Toolbox, configure our tools in a
|
||||
<!-- {x-release-please-start-version} -->
|
||||
```bash
|
||||
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.8.0/$OS/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.7.0/$OS/toolbox
|
||||
```
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
@@ -257,9 +257,6 @@ In this section, we will download Toolbox, configure our tools in a
|
||||
```bash
|
||||
./toolbox --tools-file "tools.yaml"
|
||||
```
|
||||
{{< notice note >}}
|
||||
Toolbox enables dynamic reloading by default. To disable, use the `--disable-reload` flag.
|
||||
{{< /notice >}}
|
||||
|
||||
## Step 3: Connect your agent to Toolbox
|
||||
|
||||
|
||||
@@ -105,7 +105,7 @@ In this section, we will download Toolbox, configure our tools in a
|
||||
<!-- {x-release-please-start-version} -->
|
||||
```bash
|
||||
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.8.0/$OS/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.7.0/$OS/toolbox
|
||||
```
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
|
||||
@@ -52,19 +52,19 @@ Omni](https://cloud.google.com/alloydb/omni/current/docs/overview).
|
||||
<!-- {x-release-please-start-version} -->
|
||||
{{< tabpane persist=header >}}
|
||||
{{< tab header="linux/amd64" lang="bash" >}}
|
||||
curl -O <https://storage.googleapis.com/genai-toolbox/v0.8.0/linux/amd64/toolbox>
|
||||
curl -O <https://storage.googleapis.com/genai-toolbox/v0.7.0/linux/amd64/toolbox>
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="darwin/arm64" lang="bash" >}}
|
||||
curl -O <https://storage.googleapis.com/genai-toolbox/v0.8.0/darwin/arm64/toolbox>
|
||||
curl -O <https://storage.googleapis.com/genai-toolbox/v0.7.0/darwin/arm64/toolbox>
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="darwin/amd64" lang="bash" >}}
|
||||
curl -O <https://storage.googleapis.com/genai-toolbox/v0.8.0/darwin/amd64/toolbox>
|
||||
curl -O <https://storage.googleapis.com/genai-toolbox/v0.7.0/darwin/amd64/toolbox>
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="windows/amd64" lang="bash" >}}
|
||||
curl -O <https://storage.googleapis.com/genai-toolbox/v0.8.0/windows/amd64/toolbox>
|
||||
curl -O <https://storage.googleapis.com/genai-toolbox/v0.7.0/windows/amd64/toolbox>
|
||||
{{< /tab >}}
|
||||
{{< /tabpane >}}
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
@@ -63,10 +63,6 @@ When running with stdio, Toolbox will listen via stdio instead of acting as a
|
||||
remote HTTP server. Logs will be set to the `warn` level by default. `debug` and
|
||||
`info` logs are not supported with stdio.
|
||||
|
||||
{{< notice note >}}
|
||||
Toolbox enables dynamic reloading by default. To disable, use the `--disable-reload` flag.
|
||||
{{< /notice >}}
|
||||
|
||||
### Connecting via HTTP
|
||||
|
||||
Toolbox supports the HTTP transport protocol with and without SSE.
|
||||
|
||||
@@ -79,7 +79,7 @@ database are in the same VPC network.
|
||||
|
||||
Create a `tools.yaml` file that contains your configuration for Toolbox. For
|
||||
details, see the
|
||||
[configuration](https://googleapis.github.io/genai-toolbox/resources/sources/)
|
||||
[configuration](https://github.com/googleapis/genai-toolbox/blob/main/README.md#configuration)
|
||||
section.
|
||||
|
||||
## Deploy to Cloud Run
|
||||
|
||||
@@ -15,10 +15,8 @@ It's compatible with the following sources:
|
||||
|
||||
- [bigquery](../sources/bigquery.md)
|
||||
|
||||
`bigquery-get-dataset-info` takes a `dataset` parameter to specify the dataset
|
||||
on the given source. It also optionally accepts a `project` parameter to
|
||||
define the Google Cloud project ID. If the `project` parameter is not provided,
|
||||
the tool defaults to using the project defined in the source configuration.
|
||||
bigquery-get-dataset-info takes a dataset parameter to specify the dataset
|
||||
on the given source.
|
||||
|
||||
## Example
|
||||
|
||||
|
||||
@@ -15,10 +15,8 @@ It's compatible with the following sources:
|
||||
|
||||
- [bigquery](../sources/bigquery.md)
|
||||
|
||||
`bigquery-get-table-info` takes `dataset` and `table` parameters to specify
|
||||
the target table. It also optionally accepts a `project` parameter to define
|
||||
the Google Cloud project ID. If the `project` parameter is not provided, the
|
||||
tool defaults to using the project defined in the source configuration.
|
||||
bigquery-get-table-info takes dataset and table parameters to specify
|
||||
the target table.
|
||||
|
||||
## Example
|
||||
|
||||
|
||||
@@ -15,9 +15,8 @@ It's compatible with the following sources:
|
||||
|
||||
- [bigquery](../sources/bigquery.md)
|
||||
|
||||
`bigquery-list-dataset-ids` optionally accepts a `project` parameter to define
|
||||
the Google Cloud project ID. If the `project` parameter is not provided, the
|
||||
tool defaults to using the project defined in the source configuration.
|
||||
bigquery-list-dataset-ids requires no input parameters beyond the configured
|
||||
source.
|
||||
|
||||
## Example
|
||||
|
||||
|
||||
@@ -15,10 +15,8 @@ It's compatible with the following sources:
|
||||
|
||||
- [bigquery](../sources/bigquery.md)
|
||||
|
||||
`bigquery-get-dataset-info` takes a required `dataset` parameter to specify the dataset
|
||||
from which to list table IDs. It also optionally accepts a `project` parameter to
|
||||
define the Google Cloud project ID. If the `project` parameter is not provided, the
|
||||
tool defaults to using the project defined in the source configuration.
|
||||
bigquery-get-dataset-info takes a dataset parameter to specify the dataset
|
||||
from which to list table IDs.
|
||||
|
||||
## Example
|
||||
|
||||
|
||||
@@ -220,7 +220,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"version = \"0.8.0\" # x-release-please-version\n",
|
||||
"version = \"0.7.0\" # x-release-please-version\n",
|
||||
"! curl -O https://storage.googleapis.com/genai-toolbox/v{version}/linux/amd64/toolbox\n",
|
||||
"\n",
|
||||
"# Make the binary executable\n",
|
||||
|
||||
@@ -179,7 +179,7 @@ to use BigQuery, and then run the Toolbox server.
|
||||
<!-- {x-release-please-start-version} -->
|
||||
```bash
|
||||
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.8.0/$OS/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.7.0/$OS/toolbox
|
||||
```
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
@@ -292,9 +292,6 @@ to use BigQuery, and then run the Toolbox server.
|
||||
```bash
|
||||
./toolbox --tools-file "tools.yaml"
|
||||
```
|
||||
{{< notice note >}}
|
||||
Toolbox enables dynamic reloading by default. To disable, use the `--disable-reload` flag.
|
||||
{{< /notice >}}
|
||||
|
||||
## Step 3: Connect your agent to Toolbox
|
||||
|
||||
|
||||
@@ -98,7 +98,7 @@ In this section, we will download Toolbox, configure our tools in a
|
||||
<!-- {x-release-please-start-version} -->
|
||||
```bash
|
||||
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.8.0/$OS/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.7.0/$OS/toolbox
|
||||
```
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
|
||||
10
go.mod
10
go.mod
@@ -7,19 +7,17 @@ toolchain go1.24.4
|
||||
require (
|
||||
cloud.google.com/go/alloydbconn v1.15.3
|
||||
cloud.google.com/go/bigquery v1.69.0
|
||||
cloud.google.com/go/bigtable v1.38.0
|
||||
cloud.google.com/go/bigtable v1.37.0
|
||||
cloud.google.com/go/cloudsqlconn v1.17.2
|
||||
cloud.google.com/go/spanner v1.83.0
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.53.0
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/trace v1.29.0
|
||||
github.com/couchbase/gocb/v2 v2.10.0
|
||||
github.com/couchbase/tools-common/http v1.0.9
|
||||
github.com/fsnotify/fsnotify v1.9.0
|
||||
github.com/go-chi/chi/v5 v5.2.2
|
||||
github.com/go-chi/httplog/v2 v2.1.1
|
||||
github.com/go-chi/render v1.0.3
|
||||
github.com/go-goquery/goquery v1.0.1
|
||||
github.com/go-playground/validator/v10 v10.27.0
|
||||
github.com/go-playground/validator/v10 v10.26.0
|
||||
github.com/go-sql-driver/mysql v1.9.3
|
||||
github.com/goccy/go-yaml v1.18.0
|
||||
github.com/google/go-cmp v0.7.0
|
||||
@@ -40,7 +38,7 @@ require (
|
||||
go.opentelemetry.io/otel/sdk/metric v1.36.0
|
||||
go.opentelemetry.io/otel/trace v1.36.0
|
||||
golang.org/x/oauth2 v0.30.0
|
||||
google.golang.org/api v0.240.0
|
||||
google.golang.org/api v0.239.0
|
||||
modernc.org/sqlite v1.38.0
|
||||
)
|
||||
|
||||
@@ -61,9 +59,7 @@ require (
|
||||
github.com/GoogleCloudPlatform/grpc-gcp-go/grpcgcp v1.5.3 // indirect
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.27.0 // indirect
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.53.0 // indirect
|
||||
github.com/PuerkitoBio/goquery v1.10.3 // indirect
|
||||
github.com/ajg/form v1.5.1 // indirect
|
||||
github.com/andybalholm/cascadia v1.3.3 // indirect
|
||||
github.com/apache/arrow/go/v15 v15.0.2 // indirect
|
||||
github.com/cenkalti/backoff/v5 v5.0.2 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
|
||||
54
go.sum
54
go.sum
@@ -139,8 +139,8 @@ cloud.google.com/go/bigquery v1.49.0/go.mod h1:Sv8hMmTFFYBlt/ftw2uN6dFdQPzBlREY9
|
||||
cloud.google.com/go/bigquery v1.50.0/go.mod h1:YrleYEh2pSEbgTBZYMJ5SuSr0ML3ypjRB1zgf7pvQLU=
|
||||
cloud.google.com/go/bigquery v1.69.0 h1:rZvHnjSUs5sHK3F9awiuFk2PeOaB8suqNuim21GbaTc=
|
||||
cloud.google.com/go/bigquery v1.69.0/go.mod h1:TdGLquA3h/mGg+McX+GsqG9afAzTAcldMjqhdjHTLew=
|
||||
cloud.google.com/go/bigtable v1.38.0 h1:L/PnUXRtAzFfa7qMULJHt4cXa/O2dqPJEkzYNGA4hfo=
|
||||
cloud.google.com/go/bigtable v1.38.0/go.mod h1:o/lntJarF3Y5C0XYLMJLjLYwxaRbcrtM0BiV57ymXbI=
|
||||
cloud.google.com/go/bigtable v1.37.0 h1:Q+x7y04lQ0B+WXp03wc1/FLhFt4CwcQdkwWT0M4Jp3w=
|
||||
cloud.google.com/go/bigtable v1.37.0/go.mod h1:HXqddP6hduwzrtiTCqZPpj9ij4hGZb4Zy1WF/dT+yaU=
|
||||
cloud.google.com/go/billing v1.4.0/go.mod h1:g9IdKBEFlItS8bTtlrZdVLWSSdSyFUZKXNS02zKMOZY=
|
||||
cloud.google.com/go/billing v1.5.0/go.mod h1:mztb1tBc3QekhjSgmpf/CV4LzWXLzCArwpLmP2Gm88s=
|
||||
cloud.google.com/go/billing v1.6.0/go.mod h1:WoXzguj+BeHXPbKfNWkqVtDdzORazmCjraY+vrxcyvI=
|
||||
@@ -661,8 +661,6 @@ github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapp
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.53.0/go.mod h1:cSgYe11MCNYunTnRXrKiR/tHc0eoKjICUuWpNZoVCOo=
|
||||
github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c/go.mod h1:X0CRv0ky0k6m906ixxpzmDRLvX58TFUKS2eePweuyxk=
|
||||
github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU=
|
||||
github.com/PuerkitoBio/goquery v1.10.3 h1:pFYcNSqHxBD06Fpj/KsbStFRsgRATgnf3LeXiUkhzPo=
|
||||
github.com/PuerkitoBio/goquery v1.10.3/go.mod h1:tMUX0zDMHXYlAQk6p35XxQMqMweEKB7iK7iLNd4RH4Y=
|
||||
github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU=
|
||||
github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY=
|
||||
github.com/ajstarks/deck v0.0.0-20200831202436-30c9fc6549a9/go.mod h1:JynElWSGnm/4RlzPXRlREEwqTHAN3T56Bv2ITsFT3gY=
|
||||
@@ -670,8 +668,6 @@ github.com/ajstarks/deck/generate v0.0.0-20210309230005-c3f852c02e19/go.mod h1:T
|
||||
github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw=
|
||||
github.com/ajstarks/svgo v0.0.0-20211024235047-1546f124cd8b/go.mod h1:1KcenG0jGWcpt8ov532z81sp/kMMUG485J2InIOyADM=
|
||||
github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
|
||||
github.com/andybalholm/cascadia v1.3.3 h1:AG2YHrzJIm4BZ19iwJ/DAua6Btl3IwJX+VI4kktS1LM=
|
||||
github.com/andybalholm/cascadia v1.3.3/go.mod h1:xNd9bqTn98Ln4DwST8/nG+H0yuB8Hmgu1YHNnWw0GeA=
|
||||
github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY=
|
||||
github.com/apache/arrow/go/v10 v10.0.1/go.mod h1:YvhnlEePVnBS4+0z3fhPfUy7W1Ikj0Ih0vcRo/gZ1M0=
|
||||
github.com/apache/arrow/go/v11 v11.0.0/go.mod h1:Eg5OsL5H+e299f7u5ssuXsuHQVEGC4xei5aX110hRiI=
|
||||
@@ -769,8 +765,6 @@ github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2
|
||||
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
||||
github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
|
||||
github.com/fogleman/gg v1.3.0/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
|
||||
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
||||
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||
github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM=
|
||||
github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8=
|
||||
github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
|
||||
@@ -788,8 +782,6 @@ github.com/go-fonts/stix v0.1.0/go.mod h1:w/c1f0ldAUlJmLBvlbkvVXLAD+tAMqobIIQpmn
|
||||
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
|
||||
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
|
||||
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
|
||||
github.com/go-goquery/goquery v1.0.1 h1:kpchVA1LdOFWdRpkDPESVdlb1JQI6ixsJ5MiNUITO7U=
|
||||
github.com/go-goquery/goquery v1.0.1/go.mod h1:W5s8OWbqWf6lG0LkXWBeh7U1Y/X5XTI0Br65MHF8uJk=
|
||||
github.com/go-jose/go-jose/v4 v4.0.5 h1:M6T8+mKZl/+fNNuFHvGIzDz7BTLQPIounk/b9dw3AaE=
|
||||
github.com/go-jose/go-jose/v4 v4.0.5/go.mod h1:s3P1lRrkT8igV8D9OjyL4WRyHvjB6a4JSllnOrmmBOA=
|
||||
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
|
||||
@@ -809,8 +801,8 @@ github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/o
|
||||
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
|
||||
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
|
||||
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
|
||||
github.com/go-playground/validator/v10 v10.27.0 h1:w8+XrWVMhGkxOaaowyKH35gFydVHOvC0/uWoy2Fzwn4=
|
||||
github.com/go-playground/validator/v10 v10.27.0/go.mod h1:I5QpIEbmr8On7W0TktmJAumgzX4CA1XNl4ZmDuVHKKo=
|
||||
github.com/go-playground/validator/v10 v10.26.0 h1:SP05Nqhjcvz81uJaRfEV0YBSSSGMc/iMaVtFbr3Sw2k=
|
||||
github.com/go-playground/validator/v10 v10.26.0/go.mod h1:I5QpIEbmr8On7W0TktmJAumgzX4CA1XNl4ZmDuVHKKo=
|
||||
github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo=
|
||||
github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU=
|
||||
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
|
||||
@@ -891,7 +883,6 @@ github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
|
||||
github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE=
|
||||
github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
@@ -1185,10 +1176,6 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh
|
||||
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=
|
||||
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
|
||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
||||
golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM=
|
||||
golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U=
|
||||
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
@@ -1250,9 +1237,6 @@ golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91
|
||||
golang.org/x/mod v0.7.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.9.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w=
|
||||
golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
|
||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
@@ -1312,11 +1296,6 @@ golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||
golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||
golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
|
||||
golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
|
||||
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
|
||||
golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
|
||||
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
|
||||
golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw=
|
||||
golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
@@ -1366,10 +1345,6 @@ golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJ
|
||||
golang.org/x/sync v0.0.0-20220819030929-7fc1605a5dde/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220929204114-8fcdb60fdcc0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
|
||||
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8=
|
||||
golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
@@ -1451,14 +1426,8 @@ golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
|
||||
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc=
|
||||
@@ -1467,11 +1436,6 @@ golang.org/x/term v0.4.0/go.mod h1:9P2UbLfCdcvo3p/nzKvsmas4TnlujnuoV9hGgYzW1lQ=
|
||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
||||
golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U=
|
||||
golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY=
|
||||
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
|
||||
golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU=
|
||||
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
|
||||
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
|
||||
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
|
||||
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
@@ -1488,10 +1452,6 @@ golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
||||
golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M=
|
||||
golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA=
|
||||
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
@@ -1566,8 +1526,6 @@ golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc
|
||||
golang.org/x/tools v0.3.0/go.mod h1:/rWhSS2+zyEVwoJf8YAX6L2f0ntZ7Kn/mGgAWcipA5k=
|
||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||
golang.org/x/tools v0.7.0/go.mod h1:4pg6aUX35JBAogB10C9AtvVL+qowtN4pT3CGSQex14s=
|
||||
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
|
||||
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
|
||||
golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc=
|
||||
golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
@@ -1647,8 +1605,8 @@ google.golang.org/api v0.108.0/go.mod h1:2Ts0XTHNVWxypznxWOYUeI4g3WdP9Pk2Qk58+a/
|
||||
google.golang.org/api v0.110.0/go.mod h1:7FC4Vvx1Mooxh8C5HWjzZHcavuS2f6pmJpZx60ca7iI=
|
||||
google.golang.org/api v0.111.0/go.mod h1:qtFHvU9mhgTJegR31csQ+rwxyUTHOKFqCKWp1J0fdw0=
|
||||
google.golang.org/api v0.114.0/go.mod h1:ifYI2ZsFK6/uGddGfAD5BMxlnkBqCmqHSDUVi45N5Yg=
|
||||
google.golang.org/api v0.240.0 h1:PxG3AA2UIqT1ofIzWV2COM3j3JagKTKSwy7L6RHNXNU=
|
||||
google.golang.org/api v0.240.0/go.mod h1:cOVEm2TpdAGHL2z+UwyS+kmlGr3bVWQQ6sYEqkKje50=
|
||||
google.golang.org/api v0.239.0 h1:2hZKUnFZEy81eugPs4e2XzIJ5SOwQg0G82bpXD65Puo=
|
||||
google.golang.org/api v0.239.0/go.mod h1:cOVEm2TpdAGHL2z+UwyS+kmlGr3bVWQQ6sYEqkKje50=
|
||||
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
|
||||
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
||||
google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
||||
|
||||
@@ -75,7 +75,7 @@ func toolsetHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
)
|
||||
}()
|
||||
|
||||
toolset, ok := s.ResourceMgr.GetToolset(toolsetName)
|
||||
toolset, ok := s.toolsets[toolsetName]
|
||||
if !ok {
|
||||
err = fmt.Errorf("toolset %q does not exist", toolsetName)
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
@@ -111,7 +111,7 @@ func toolGetHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
metric.WithAttributes(attribute.String("toolbox.operation.status", status)),
|
||||
)
|
||||
}()
|
||||
tool, ok := s.ResourceMgr.GetTool(toolName)
|
||||
tool, ok := s.tools[toolName]
|
||||
if !ok {
|
||||
err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
@@ -156,7 +156,7 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
)
|
||||
}()
|
||||
|
||||
tool, ok := s.ResourceMgr.GetTool(toolName)
|
||||
tool, ok := s.tools[toolName]
|
||||
if !ok {
|
||||
err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
@@ -167,7 +167,7 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
// Tool authentication
|
||||
// claimsFromAuth maps the name of the authservice to the claims retrieved from it.
|
||||
claimsFromAuth := make(map[string]map[string]any)
|
||||
for _, aS := range s.ResourceMgr.GetAuthServiceMap() {
|
||||
for _, aS := range s.authServices {
|
||||
claims, err := aS.GetClaimsFromHeader(ctx, r.Header)
|
||||
if err != nil {
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
|
||||
@@ -147,23 +147,14 @@ func setUpServer(t *testing.T, router string, tools map[string]tools.Tool, tools
|
||||
t.Fatalf("unable to setup otel: %s", err)
|
||||
}
|
||||
|
||||
instrumentation, err := telemetry.CreateTelemetryInstrumentation(fakeVersionString)
|
||||
instrumentation, err := CreateTelemetryInstrumentation(fakeVersionString)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create custom metrics: %s", err)
|
||||
}
|
||||
|
||||
sseManager := newSseManager(ctx)
|
||||
|
||||
resourceManager := NewResourceManager(nil, nil, tools, toolsets)
|
||||
|
||||
server := Server{
|
||||
version: fakeVersionString,
|
||||
logger: testLogger,
|
||||
instrumentation: instrumentation,
|
||||
sseManager: sseManager,
|
||||
ResourceMgr: resourceManager,
|
||||
}
|
||||
|
||||
server := Server{version: fakeVersionString, logger: testLogger, instrumentation: instrumentation, sseManager: sseManager, tools: tools, toolsets: toolsets}
|
||||
var r chi.Router
|
||||
switch router {
|
||||
case "api":
|
||||
|
||||
@@ -53,10 +53,6 @@ type ServerConfig struct {
|
||||
TelemetryServiceName string
|
||||
// Stdio indicates if Toolbox is listening via MCP stdio.
|
||||
Stdio bool
|
||||
// DisableReload indicates if the user has disabled dynamic reloading for Toolbox.
|
||||
DisableReload bool
|
||||
// UI indicates if Toolbox UI endpoints (/ui) are available
|
||||
UI bool
|
||||
}
|
||||
|
||||
type logFormat string
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package telemetry
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -479,12 +479,12 @@ func processMcpMessage(ctx context.Context, body []byte, s *Server, protocolVers
|
||||
}
|
||||
return v, res, err
|
||||
default:
|
||||
toolset, ok := s.ResourceMgr.GetToolset(toolsetName)
|
||||
toolset, ok := s.toolsets[toolsetName]
|
||||
if !ok {
|
||||
err = fmt.Errorf("toolset does not exist")
|
||||
return "", jsonrpc.NewError(baseMessage.Id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
res, err := mcp.ProcessMethod(ctx, protocolVersion, baseMessage.Id, baseMessage.Method, toolset, s.ResourceMgr.GetToolsMap(), body)
|
||||
res, err := mcp.ProcessMethod(ctx, protocolVersion, baseMessage.Id, baseMessage.Method, toolset, s.tools, body)
|
||||
return "", res, err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -693,22 +693,14 @@ func TestStdioSession(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
instrumentation, err := telemetry.CreateTelemetryInstrumentation(fakeVersionString)
|
||||
instrumentation, err := CreateTelemetryInstrumentation(fakeVersionString)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create custom metrics: %s", err)
|
||||
}
|
||||
|
||||
sseManager := newSseManager(ctx)
|
||||
|
||||
resourceManager := NewResourceManager(nil, nil, toolsMap, toolsets)
|
||||
|
||||
server := &Server{
|
||||
version: fakeVersionString,
|
||||
logger: testLogger,
|
||||
instrumentation: instrumentation,
|
||||
sseManager: sseManager,
|
||||
ResourceMgr: resourceManager,
|
||||
}
|
||||
server := &Server{version: fakeVersionString, logger: testLogger, instrumentation: instrumentation, sseManager: sseManager, tools: toolsMap, toolsets: toolsets}
|
||||
|
||||
in := bufio.NewReader(pr)
|
||||
stdioSession := NewStdioSession(server, in, pw)
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
@@ -30,7 +29,6 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/auth"
|
||||
"github.com/googleapis/genai-toolbox/internal/log"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/telemetry"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
@@ -44,225 +42,26 @@ type Server struct {
|
||||
listener net.Listener
|
||||
root chi.Router
|
||||
logger log.Logger
|
||||
instrumentation *telemetry.Instrumentation
|
||||
instrumentation *Instrumentation
|
||||
sseManager *sseManager
|
||||
ResourceMgr *ResourceManager
|
||||
}
|
||||
|
||||
// ResourceManager contains available resources for the server. Should be initialized with NewResourceManager().
|
||||
type ResourceManager struct {
|
||||
mu sync.RWMutex
|
||||
sources map[string]sources.Source
|
||||
authServices map[string]auth.AuthService
|
||||
tools map[string]tools.Tool
|
||||
toolsets map[string]tools.Toolset
|
||||
}
|
||||
|
||||
func NewResourceManager(
|
||||
sourcesMap map[string]sources.Source,
|
||||
authServicesMap map[string]auth.AuthService,
|
||||
toolsMap map[string]tools.Tool, toolsetsMap map[string]tools.Toolset,
|
||||
) *ResourceManager {
|
||||
resourceMgr := &ResourceManager{
|
||||
mu: sync.RWMutex{},
|
||||
sources: sourcesMap,
|
||||
authServices: authServicesMap,
|
||||
tools: toolsMap,
|
||||
toolsets: toolsetsMap,
|
||||
}
|
||||
|
||||
return resourceMgr
|
||||
}
|
||||
|
||||
func (r *ResourceManager) GetSource(sourceName string) (sources.Source, bool) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
source, ok := r.sources[sourceName]
|
||||
return source, ok
|
||||
}
|
||||
|
||||
func (r *ResourceManager) GetAuthService(authServiceName string) (auth.AuthService, bool) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
authService, ok := r.authServices[authServiceName]
|
||||
return authService, ok
|
||||
}
|
||||
|
||||
func (r *ResourceManager) GetTool(toolName string) (tools.Tool, bool) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
tool, ok := r.tools[toolName]
|
||||
return tool, ok
|
||||
}
|
||||
|
||||
func (r *ResourceManager) GetToolset(toolsetName string) (tools.Toolset, bool) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
toolset, ok := r.toolsets[toolsetName]
|
||||
return toolset, ok
|
||||
}
|
||||
|
||||
func (r *ResourceManager) SetResources(sourcesMap map[string]sources.Source, authServicesMap map[string]auth.AuthService, toolsMap map[string]tools.Tool, toolsetsMap map[string]tools.Toolset) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.sources = sourcesMap
|
||||
r.authServices = authServicesMap
|
||||
r.tools = toolsMap
|
||||
r.toolsets = toolsetsMap
|
||||
}
|
||||
|
||||
func (r *ResourceManager) GetAuthServiceMap() map[string]auth.AuthService {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return r.authServices
|
||||
}
|
||||
|
||||
func (r *ResourceManager) GetToolsMap() map[string]tools.Tool {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return r.tools
|
||||
}
|
||||
|
||||
func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
|
||||
map[string]sources.Source,
|
||||
map[string]auth.AuthService,
|
||||
map[string]tools.Tool,
|
||||
map[string]tools.Toolset,
|
||||
error,
|
||||
) {
|
||||
ctx = util.WithUserAgent(ctx, cfg.Version)
|
||||
instrumentation, err := util.InstrumentationFromContext(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
l, err := util.LoggerFromContext(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// initialize and validate the sources from configs
|
||||
sourcesMap := make(map[string]sources.Source)
|
||||
for name, sc := range cfg.SourceConfigs {
|
||||
s, err := func() (sources.Source, error) {
|
||||
childCtx, span := instrumentation.Tracer.Start(
|
||||
ctx,
|
||||
"toolbox/server/source/init",
|
||||
trace.WithAttributes(attribute.String("source_kind", sc.SourceConfigKind())),
|
||||
trace.WithAttributes(attribute.String("source_name", name)),
|
||||
)
|
||||
defer span.End()
|
||||
s, err := sc.Initialize(childCtx, instrumentation.Tracer)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to initialize source %q: %w", name, err)
|
||||
}
|
||||
return s, nil
|
||||
}()
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
}
|
||||
sourcesMap[name] = s
|
||||
}
|
||||
l.InfoContext(ctx, fmt.Sprintf("Initialized %d sources.", len(sourcesMap)))
|
||||
|
||||
// initialize and validate the auth services from configs
|
||||
authServicesMap := make(map[string]auth.AuthService)
|
||||
for name, sc := range cfg.AuthServiceConfigs {
|
||||
a, err := func() (auth.AuthService, error) {
|
||||
_, span := instrumentation.Tracer.Start(
|
||||
ctx,
|
||||
"toolbox/server/auth/init",
|
||||
trace.WithAttributes(attribute.String("auth_kind", sc.AuthServiceConfigKind())),
|
||||
trace.WithAttributes(attribute.String("auth_name", name)),
|
||||
)
|
||||
defer span.End()
|
||||
a, err := sc.Initialize()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to initialize auth service %q: %w", name, err)
|
||||
}
|
||||
return a, nil
|
||||
}()
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
}
|
||||
authServicesMap[name] = a
|
||||
}
|
||||
l.InfoContext(ctx, fmt.Sprintf("Initialized %d authServices.", len(authServicesMap)))
|
||||
|
||||
// initialize and validate the tools from configs
|
||||
toolsMap := make(map[string]tools.Tool)
|
||||
for name, tc := range cfg.ToolConfigs {
|
||||
t, err := func() (tools.Tool, error) {
|
||||
_, span := instrumentation.Tracer.Start(
|
||||
ctx,
|
||||
"toolbox/server/tool/init",
|
||||
trace.WithAttributes(attribute.String("tool_kind", tc.ToolConfigKind())),
|
||||
trace.WithAttributes(attribute.String("tool_name", name)),
|
||||
)
|
||||
defer span.End()
|
||||
t, err := tc.Initialize(sourcesMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to initialize tool %q: %w", name, err)
|
||||
}
|
||||
return t, nil
|
||||
}()
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
}
|
||||
toolsMap[name] = t
|
||||
}
|
||||
l.InfoContext(ctx, fmt.Sprintf("Initialized %d tools.", len(toolsMap)))
|
||||
|
||||
// create a default toolset that contains all tools
|
||||
allToolNames := make([]string, 0, len(toolsMap))
|
||||
for name := range toolsMap {
|
||||
allToolNames = append(allToolNames, name)
|
||||
}
|
||||
if cfg.ToolsetConfigs == nil {
|
||||
cfg.ToolsetConfigs = make(ToolsetConfigs)
|
||||
}
|
||||
cfg.ToolsetConfigs[""] = tools.ToolsetConfig{Name: "", ToolNames: allToolNames}
|
||||
|
||||
// initialize and validate the toolsets from configs
|
||||
toolsetsMap := make(map[string]tools.Toolset)
|
||||
for name, tc := range cfg.ToolsetConfigs {
|
||||
t, err := func() (tools.Toolset, error) {
|
||||
_, span := instrumentation.Tracer.Start(
|
||||
ctx,
|
||||
"toolbox/server/toolset/init",
|
||||
trace.WithAttributes(attribute.String("toolset_name", name)),
|
||||
)
|
||||
defer span.End()
|
||||
t, err := tc.Initialize(cfg.Version, toolsMap)
|
||||
if err != nil {
|
||||
return tools.Toolset{}, fmt.Errorf("unable to initialize toolset %q: %w", name, err)
|
||||
}
|
||||
return t, err
|
||||
}()
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
}
|
||||
toolsetsMap[name] = t
|
||||
}
|
||||
l.InfoContext(ctx, fmt.Sprintf("Initialized %d toolsets.", len(toolsetsMap)))
|
||||
|
||||
return sourcesMap, authServicesMap, toolsMap, toolsetsMap, nil
|
||||
}
|
||||
|
||||
// NewServer returns a Server object based on provided Config.
|
||||
func NewServer(ctx context.Context, cfg ServerConfig) (*Server, error) {
|
||||
instrumentation, err := util.InstrumentationFromContext(ctx)
|
||||
func NewServer(ctx context.Context, cfg ServerConfig, l log.Logger) (*Server, error) {
|
||||
instrumentation, err := CreateTelemetryInstrumentation(cfg.Version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("unable to create telemetry instrumentation: %w", err)
|
||||
}
|
||||
|
||||
ctx, span := instrumentation.Tracer.Start(ctx, "toolbox/server/init")
|
||||
defer span.End()
|
||||
|
||||
l, err := util.LoggerFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ctx = util.WithUserAgent(ctx, cfg.Version)
|
||||
|
||||
// set up http serving
|
||||
r := chi.NewRouter()
|
||||
@@ -298,18 +97,116 @@ func NewServer(ctx context.Context, cfg ServerConfig) (*Server, error) {
|
||||
httpLogger := httplog.NewLogger("httplog", httpOpts)
|
||||
r.Use(httplog.RequestLogger(httpLogger))
|
||||
|
||||
sourcesMap, authServicesMap, toolsMap, toolsetsMap, err := InitializeConfigs(ctx, cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to initialize configs: %w", err)
|
||||
// initialize and validate the sources from configs
|
||||
sourcesMap := make(map[string]sources.Source)
|
||||
for name, sc := range cfg.SourceConfigs {
|
||||
s, err := func() (sources.Source, error) {
|
||||
childCtx, span := instrumentation.Tracer.Start(
|
||||
ctx,
|
||||
"toolbox/server/source/init",
|
||||
trace.WithAttributes(attribute.String("source_kind", sc.SourceConfigKind())),
|
||||
trace.WithAttributes(attribute.String("source_name", name)),
|
||||
)
|
||||
defer span.End()
|
||||
s, err := sc.Initialize(childCtx, instrumentation.Tracer)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to initialize source %q: %w", name, err)
|
||||
}
|
||||
return s, nil
|
||||
}()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sourcesMap[name] = s
|
||||
}
|
||||
l.InfoContext(ctx, fmt.Sprintf("Initialized %d sources.", len(sourcesMap)))
|
||||
|
||||
// initialize and validate the auth services from configs
|
||||
authServicesMap := make(map[string]auth.AuthService)
|
||||
for name, sc := range cfg.AuthServiceConfigs {
|
||||
a, err := func() (auth.AuthService, error) {
|
||||
_, span := instrumentation.Tracer.Start(
|
||||
ctx,
|
||||
"toolbox/server/auth/init",
|
||||
trace.WithAttributes(attribute.String("auth_kind", sc.AuthServiceConfigKind())),
|
||||
trace.WithAttributes(attribute.String("auth_name", name)),
|
||||
)
|
||||
defer span.End()
|
||||
a, err := sc.Initialize()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to initialize auth service %q: %w", name, err)
|
||||
}
|
||||
return a, nil
|
||||
}()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
authServicesMap[name] = a
|
||||
}
|
||||
l.InfoContext(ctx, fmt.Sprintf("Initialized %d authServices.", len(authServicesMap)))
|
||||
|
||||
// initialize and validate the tools from configs
|
||||
toolsMap := make(map[string]tools.Tool)
|
||||
for name, tc := range cfg.ToolConfigs {
|
||||
t, err := func() (tools.Tool, error) {
|
||||
_, span := instrumentation.Tracer.Start(
|
||||
ctx,
|
||||
"toolbox/server/tool/init",
|
||||
trace.WithAttributes(attribute.String("tool_kind", tc.ToolConfigKind())),
|
||||
trace.WithAttributes(attribute.String("tool_name", name)),
|
||||
)
|
||||
defer span.End()
|
||||
t, err := tc.Initialize(sourcesMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to initialize tool %q: %w", name, err)
|
||||
}
|
||||
return t, nil
|
||||
}()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
toolsMap[name] = t
|
||||
}
|
||||
l.InfoContext(ctx, fmt.Sprintf("Initialized %d tools.", len(toolsMap)))
|
||||
|
||||
// create a default toolset that contains all tools
|
||||
allToolNames := make([]string, 0, len(toolsMap))
|
||||
for name := range toolsMap {
|
||||
allToolNames = append(allToolNames, name)
|
||||
}
|
||||
if cfg.ToolsetConfigs == nil {
|
||||
cfg.ToolsetConfigs = make(ToolsetConfigs)
|
||||
}
|
||||
cfg.ToolsetConfigs[""] = tools.ToolsetConfig{Name: "", ToolNames: allToolNames}
|
||||
|
||||
// initialize and validate the toolsets from configs
|
||||
toolsetsMap := make(map[string]tools.Toolset)
|
||||
for name, tc := range cfg.ToolsetConfigs {
|
||||
t, err := func() (tools.Toolset, error) {
|
||||
_, span := instrumentation.Tracer.Start(
|
||||
ctx,
|
||||
"toolbox/server/toolset/init",
|
||||
trace.WithAttributes(attribute.String("toolset_name", name)),
|
||||
)
|
||||
defer span.End()
|
||||
t, err := tc.Initialize(cfg.Version, toolsMap)
|
||||
if err != nil {
|
||||
return tools.Toolset{}, fmt.Errorf("unable to initialize toolset %q: %w", name, err)
|
||||
}
|
||||
return t, err
|
||||
}()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
toolsetsMap[name] = t
|
||||
}
|
||||
l.InfoContext(ctx, fmt.Sprintf("Initialized %d toolsets.", len(toolsetsMap)))
|
||||
|
||||
addr := net.JoinHostPort(cfg.Address, strconv.Itoa(cfg.Port))
|
||||
srv := &http.Server{Addr: addr, Handler: r}
|
||||
|
||||
sseManager := newSseManager(ctx)
|
||||
|
||||
resourceManager := NewResourceManager(sourcesMap, authServicesMap, toolsMap, toolsetsMap)
|
||||
|
||||
s := &Server{
|
||||
version: cfg.Version,
|
||||
srv: srv,
|
||||
@@ -317,7 +214,11 @@ func NewServer(ctx context.Context, cfg ServerConfig) (*Server, error) {
|
||||
logger: l,
|
||||
instrumentation: instrumentation,
|
||||
sseManager: sseManager,
|
||||
ResourceMgr: resourceManager,
|
||||
|
||||
sources: sourcesMap,
|
||||
authServices: authServicesMap,
|
||||
tools: toolsMap,
|
||||
toolsets: toolsetsMap,
|
||||
}
|
||||
// control plane
|
||||
apiR, err := apiRouter(s)
|
||||
@@ -330,13 +231,6 @@ func NewServer(ctx context.Context, cfg ServerConfig) (*Server, error) {
|
||||
return nil, err
|
||||
}
|
||||
r.Mount("/mcp", mcpR)
|
||||
if cfg.UI {
|
||||
webR, err := webRouter()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r.Mount("/ui", webR)
|
||||
}
|
||||
// default endpoint for validating server is running
|
||||
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = w.Write([]byte("🧰 Hello, World! 🧰"))
|
||||
|
||||
@@ -23,16 +23,9 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/auth"
|
||||
"github.com/googleapis/genai-toolbox/internal/log"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/alloydbpg"
|
||||
"github.com/googleapis/genai-toolbox/internal/telemetry"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
)
|
||||
|
||||
func TestServe(t *testing.T) {
|
||||
@@ -61,16 +54,8 @@ func TestServe(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
ctx = util.WithLogger(ctx, testLogger)
|
||||
|
||||
instrumentation, err := telemetry.CreateTelemetryInstrumentation(cfg.Version)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
|
||||
ctx = util.WithInstrumentation(ctx, instrumentation)
|
||||
|
||||
s, err := server.NewServer(ctx, cfg)
|
||||
s, err := server.NewServer(ctx, cfg, testLogger)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to initialize server: %v", err)
|
||||
}
|
||||
@@ -108,67 +93,3 @@ func TestServe(t *testing.T) {
|
||||
t.Fatalf("version missing from output: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateServer(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("error setting up logger: %s", err)
|
||||
}
|
||||
|
||||
addr, port := "127.0.0.1", 5000
|
||||
cfg := server.ServerConfig{
|
||||
Version: "0.0.0",
|
||||
Address: addr,
|
||||
Port: port,
|
||||
}
|
||||
|
||||
instrumentation, err := telemetry.CreateTelemetryInstrumentation(cfg.Version)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
|
||||
ctx = util.WithInstrumentation(ctx, instrumentation)
|
||||
|
||||
s, err := server.NewServer(ctx, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("error setting up server: %s", err)
|
||||
}
|
||||
|
||||
newSources := map[string]sources.Source{
|
||||
"example-source": &alloydbpg.Source{
|
||||
Name: "example-alloydb-source",
|
||||
Kind: "alloydb-postgres",
|
||||
},
|
||||
}
|
||||
newAuth := map[string]auth.AuthService{"example-auth": nil}
|
||||
newTools := map[string]tools.Tool{"example-tool": nil}
|
||||
newToolsets := map[string]tools.Toolset{
|
||||
"example-toolset": {
|
||||
Name: "example-toolset", Tools: []*tools.Tool{},
|
||||
},
|
||||
}
|
||||
s.ResourceMgr.SetResources(newSources, newAuth, newTools, newToolsets)
|
||||
if err != nil {
|
||||
t.Errorf("error updating server: %s", err)
|
||||
}
|
||||
|
||||
gotSource, _ := s.ResourceMgr.GetSource("example-source")
|
||||
if diff := cmp.Diff(gotSource, newSources["example-source"]); diff != "" {
|
||||
t.Errorf("error updating server, sources (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
gotAuthService, _ := s.ResourceMgr.GetAuthService("example-auth")
|
||||
if diff := cmp.Diff(gotAuthService, newAuth["example-auth"]); diff != "" {
|
||||
t.Errorf("error updating server, authServices (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
gotTool, _ := s.ResourceMgr.GetTool("example-tool")
|
||||
if diff := cmp.Diff(gotTool, newTools["example-tool"]); diff != "" {
|
||||
t.Errorf("error updating server, tools (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
gotToolset, _ := s.ResourceMgr.GetToolset("example-toolset")
|
||||
if diff := cmp.Diff(gotToolset, newToolsets["example-toolset"]); diff != "" {
|
||||
t.Errorf("error updating server, toolset (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 57 KiB |
@@ -1,231 +0,0 @@
|
||||
body {
|
||||
display: flex;
|
||||
height: 100vh;
|
||||
margin: 0;
|
||||
font-family: 'Trebuchet MS';
|
||||
background-color: #f8f9fa;
|
||||
}
|
||||
|
||||
.left-nav {
|
||||
flex: 0 0 200px;
|
||||
background-color: #fff;
|
||||
box-shadow: 4px 0px 12px rgba(0, 0, 0, 0.15);
|
||||
z-index: 10;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
padding: 15px;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.second-nav {
|
||||
flex: 0 0 250px;
|
||||
background-color: #fff;
|
||||
box-shadow: 4px 0px 12px rgba(0, 0, 0, 0.15);
|
||||
z-index: 10;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
padding: 15px;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.nav-logo {
|
||||
width: 100%;
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
|
||||
.nav-logo img {
|
||||
max-width: 100%;
|
||||
height: auto;
|
||||
display: block;
|
||||
}
|
||||
|
||||
.left-nav ul {
|
||||
font-family: 'Verdana';
|
||||
list-style: none;
|
||||
padding: 0;
|
||||
margin: 0;
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.left-nav ul li {
|
||||
margin-bottom: 5px;
|
||||
}
|
||||
|
||||
.left-nav ul li a {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
padding: 12px;
|
||||
text-decoration: none;
|
||||
color: #333;
|
||||
border-radius: 0;
|
||||
}
|
||||
|
||||
.left-nav ul li a:hover {
|
||||
background-color: #e9e9e9;
|
||||
border-radius: 35px;
|
||||
}
|
||||
|
||||
.left-nav ul li a.active {
|
||||
background-color: #d0d0d0;
|
||||
font-weight: bold;
|
||||
border-radius: 35px;
|
||||
}
|
||||
|
||||
.main-content-area {
|
||||
flex: 1;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
.top-bar {
|
||||
background-color: #fff;
|
||||
padding: 30px 30px;
|
||||
display: flex;
|
||||
justify-content: flex-end;
|
||||
align-items: center;
|
||||
border-bottom: 1px solid #eee;
|
||||
}
|
||||
|
||||
.top-bar span {
|
||||
font-weight: bold;
|
||||
font-size: 1.2em;
|
||||
font-family: 'Trebuchet MS';
|
||||
color: #333;
|
||||
}
|
||||
|
||||
.content {
|
||||
padding: 20px;
|
||||
flex-grow: 1;
|
||||
overflow-y: auto;
|
||||
}
|
||||
|
||||
.tool-button {
|
||||
/* --- Mimicking .left-nav ul li a styles --- */
|
||||
display: flex;
|
||||
align-items: center;
|
||||
padding: 12px;
|
||||
text-decoration: none;
|
||||
color: #333;
|
||||
background-color: transparent; /* Reset default button background */
|
||||
border: none; /* Reset default button border */
|
||||
border-radius: 0; /* Start with sharp corners */
|
||||
width: 100%;
|
||||
text-align: left;
|
||||
cursor: pointer;
|
||||
font-family: inherit; /* Inherit 'Verdana' from ul */
|
||||
font-size: inherit; /* Inherit font size */
|
||||
|
||||
/* Transition to match a tags */
|
||||
transition: background-color 0.1s ease-in-out, border-radius 0.1s ease-in-out;
|
||||
}
|
||||
|
||||
.tool-button:hover {
|
||||
/* Mimic .left-nav ul li a:hover */
|
||||
background-color: #e9e9e9;
|
||||
border-radius: 35px;
|
||||
}
|
||||
|
||||
.tool-button:focus {
|
||||
outline: none;
|
||||
/* Optional: You might want a subtle focus indicator, e.g.
|
||||
box-shadow: 0 0 0 2px rgba(208, 208, 208, 0.5); */
|
||||
}
|
||||
|
||||
.tool-button.active {
|
||||
/* Mimic .left-nav ul li a.active */
|
||||
background-color: #d0d0d0;
|
||||
font-weight: bold;
|
||||
border-radius: 35px;
|
||||
}
|
||||
|
||||
.tool-button.active:hover {
|
||||
background-color: #d0d0d0; /* Keep active color when hovered */
|
||||
}
|
||||
#secondary-panel-content ul {
|
||||
list-style: none; /* This is the main property to remove bullet points */
|
||||
padding: 0; /* Reset default browser padding on the ul */
|
||||
margin: 0; /* Reset default browser margin on the ul */
|
||||
width: 100%; /* Ensure it takes full width if needed */
|
||||
}
|
||||
|
||||
.tool-details-grid {
|
||||
display: grid;
|
||||
grid-template-columns: 2fr 1fr; /* 2/3 for info, 1/3 for params */
|
||||
gap: 20px;
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
|
||||
.tool-info {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 15px;
|
||||
}
|
||||
|
||||
.tool-params {
|
||||
background-color: #f9f9f9;
|
||||
padding: 15px;
|
||||
border-radius: 4px;
|
||||
border: 1px solid #ddd;
|
||||
}
|
||||
|
||||
.tool-box {
|
||||
background-color: #fff;
|
||||
padding: 15px;
|
||||
border-radius: 4px;
|
||||
border: 1px solid #eee;
|
||||
box-shadow: 0 1px 3px rgba(0,0,0,0.05);
|
||||
}
|
||||
|
||||
.tool-box h5 {
|
||||
margin-top: 0;
|
||||
margin-bottom: 8px;
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.param-item {
|
||||
margin-bottom: 12px;
|
||||
}
|
||||
|
||||
.param-item label {
|
||||
display: block;
|
||||
margin-bottom: 4px;
|
||||
font-weight: 500;
|
||||
}
|
||||
|
||||
.param-item input[type="text"],
|
||||
.param-item input[type="number"],
|
||||
.param-item select {
|
||||
width: calc(100% - 12px);
|
||||
padding: 6px;
|
||||
border: 1px solid #ccc;
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
.run-tool-btn {
|
||||
padding: 10px 15px;
|
||||
background-color: #4285f4; /* Google Blue */
|
||||
color: white;
|
||||
border: none;
|
||||
border-radius: 4px;
|
||||
cursor: pointer;
|
||||
font-size: 1em;
|
||||
margin-top: 10px;
|
||||
}
|
||||
|
||||
.run-tool-btn:hover {
|
||||
background-color: #357ae8;
|
||||
}
|
||||
|
||||
.tool-response {
|
||||
width: 100%;
|
||||
margin-top: 20px;
|
||||
}
|
||||
|
||||
.tool-response textarea {
|
||||
width: calc(100% - 24px);
|
||||
min-height: 150px;
|
||||
padding: 12px;
|
||||
border: 1px solid #ddd;
|
||||
border-radius: 4px;
|
||||
font-family: monospace;
|
||||
}
|
||||
@@ -1,30 +0,0 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Toolbox UI</title>
|
||||
<link rel="stylesheet" href="/ui/css/style.css">
|
||||
</head>
|
||||
<body>
|
||||
<nav class="left-nav">
|
||||
<div class="nav-logo">
|
||||
<img src="/ui/assets/mcptoolboxlogo.png" alt="App Logo">
|
||||
</div>
|
||||
<ul>
|
||||
<li><a href="/ui/sources">Sources</a></li>
|
||||
<li><a href="/ui/authservices">Auth Services</a></li>
|
||||
<li><a href="/ui/tools">Tools</a></li>
|
||||
<li><a href="/ui/toolsets">Toolsets</a></li>
|
||||
</ul>
|
||||
</nav>
|
||||
|
||||
<div class="main-content-area", align="center">
|
||||
<div class="top-bar">
|
||||
</div>
|
||||
<main class="content">
|
||||
<h1>Homepage</h1>
|
||||
</main>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
@@ -1,255 +0,0 @@
|
||||
// helper function to create form inputs for parameters
|
||||
function createParamInput(param, toolId) {
|
||||
const paramItem = document.createElement('div');
|
||||
paramItem.className = 'param-item';
|
||||
|
||||
const label = document.createElement('label');
|
||||
const inputId = `param-${toolId}-${param.name}`;
|
||||
label.setAttribute('for', inputId);
|
||||
label.textContent = param.label;
|
||||
paramItem.appendChild(label);
|
||||
|
||||
let inputElement;
|
||||
if (param.type === 'select') {
|
||||
inputElement = document.createElement('select');
|
||||
param.options.forEach(optionValue => {
|
||||
const option = document.createElement('option');
|
||||
option.value = optionValue;
|
||||
option.textContent = optionValue;
|
||||
if (optionValue === param.defaultValue) {
|
||||
option.selected = true;
|
||||
}
|
||||
inputElement.appendChild(option);
|
||||
});
|
||||
} else if (param.type === 'textarea') { // Handle textarea for arrays
|
||||
inputElement = document.createElement('textarea');
|
||||
inputElement.rows = 3;
|
||||
inputElement.value = param.defaultValue || '';
|
||||
if (param.valueType && param.valueType.startsWith('array')) {
|
||||
inputElement.placeholder = 'E.g., ["item1", "item2"] or [1, 2, 3]';
|
||||
}
|
||||
} else { // text, number, etc.
|
||||
inputElement = document.createElement('input');
|
||||
inputElement.type = param.type;
|
||||
inputElement.value = param.defaultValue || '';
|
||||
}
|
||||
// Common properties
|
||||
inputElement.id = inputId;
|
||||
inputElement.name = param.name;
|
||||
paramItem.appendChild(inputElement);
|
||||
return paramItem;
|
||||
}
|
||||
|
||||
function displayResults(results, responseArea, prettify) {
|
||||
if (results === null || results === undefined) {
|
||||
// responseArea.value = ''; // Keep placeholder or old error message
|
||||
return;
|
||||
}
|
||||
try {
|
||||
if (prettify) {
|
||||
responseArea.value = JSON.stringify(JSON.parse(results.result), null, 2);
|
||||
} else {
|
||||
responseArea.value = JSON.stringify(JSON.parse(results.result));
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error stringifying results:", error);
|
||||
responseArea.value = "Error displaying results.";
|
||||
}
|
||||
}
|
||||
|
||||
// function to run the tool (calls API version of endpoint)
|
||||
async function handleRunTool(toolId, form, responseArea, parameters, prettifyCheckbox, updateLastResults) {
|
||||
responseArea.value = 'Running tool...';
|
||||
updateLastResults(null); // Clear last results before new run
|
||||
const formData = new FormData(form);
|
||||
const typedParams = {};
|
||||
|
||||
for (const param of parameters) {
|
||||
const rawValue = formData.get(param.name);
|
||||
|
||||
if (rawValue === null || rawValue === undefined || rawValue === '') {
|
||||
if (param.required) {
|
||||
console.warn(`Required parameter ${param.name} is missing.`);
|
||||
}
|
||||
typedParams[param.name] = null;
|
||||
continue;
|
||||
}
|
||||
|
||||
const valueType = param.valueType;
|
||||
|
||||
try {
|
||||
if (valueType && valueType.startsWith('array<')) {
|
||||
const elementType = valueType.substring(6, valueType.length - 1);
|
||||
let parsedArray;
|
||||
try {
|
||||
parsedArray = JSON.parse(rawValue);
|
||||
} catch (e) {
|
||||
throw new Error(`Invalid JSON format for ${param.name}. ${e.message}`);
|
||||
}
|
||||
|
||||
if (!Array.isArray(parsedArray)) {
|
||||
throw new Error(`Input for ${param.name} must be a JSON array (e.g., ["a", "b"]).`);
|
||||
}
|
||||
|
||||
if (elementType === 'number') {
|
||||
typedParams[param.name] = parsedArray.map((item, index) => {
|
||||
const num = Number(item);
|
||||
if (isNaN(num)) {
|
||||
throw new Error(`Invalid number "${item}" found in array for ${param.name} at index ${index}.`);
|
||||
}
|
||||
return num;
|
||||
});
|
||||
} else if (elementType === 'boolean') {
|
||||
typedParams[param.name] = parsedArray.map(item => item === true || String(item).toLowerCase() === 'true');
|
||||
} else { // string or other types
|
||||
typedParams[param.name] = parsedArray;
|
||||
}
|
||||
} else {
|
||||
switch (valueType) {
|
||||
case 'number':
|
||||
const num = Number(rawValue);
|
||||
if (isNaN(num)) {
|
||||
throw new Error(`Invalid number input for ${param.name}: ${rawValue}`);
|
||||
}
|
||||
typedParams[param.name] = num;
|
||||
break;
|
||||
case 'boolean':
|
||||
typedParams[param.name] = rawValue === 'true';
|
||||
break;
|
||||
case 'string':
|
||||
default:
|
||||
typedParams[param.name] = rawValue;
|
||||
break;
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error processing parameter:', param.name, error);
|
||||
responseArea.value = `Error for ${param.name}: ${error.message}`;
|
||||
return; // Stop processing
|
||||
}
|
||||
}
|
||||
|
||||
console.log('Running tool:', toolId, 'with typed params:', typedParams);
|
||||
try {
|
||||
const response = await fetch(`/api/tool/${toolId}/invoke`, {
|
||||
method: 'POST',
|
||||
headers: {'Content-Type': 'application/json'},
|
||||
body: JSON.stringify(typedParams)
|
||||
});
|
||||
if (!response.ok) {
|
||||
const errorBody = await response.text();
|
||||
throw new Error(`HTTP error ${response.status}: ${errorBody}`);
|
||||
}
|
||||
const results = await response.json();
|
||||
updateLastResults(results); // Update the stored results
|
||||
displayResults(results, responseArea, prettifyCheckbox.checked); // Display formatted results
|
||||
} catch (error) {
|
||||
console.error('Error running tool:', error);
|
||||
responseArea.value = `Error: ${error.message}`;
|
||||
updateLastResults(null); // Clear results on error
|
||||
}
|
||||
}
|
||||
|
||||
// renders the tool display area
|
||||
export function renderToolInterface(tool, containerElement) {
|
||||
containerElement.innerHTML = '';
|
||||
const toolId = tool.id;
|
||||
|
||||
let lastResults = null; // Store the most recent successful result object
|
||||
|
||||
// Function to update lastResults, closure to keep it private to this scope
|
||||
const updateLastResults = (newResults) => {
|
||||
lastResults = newResults;
|
||||
};
|
||||
|
||||
const gridContainer = document.createElement('div');
|
||||
gridContainer.className = 'tool-details-grid';
|
||||
|
||||
const toolInfoContainer = document.createElement('div');
|
||||
toolInfoContainer.className = 'tool-info';
|
||||
|
||||
const nameBox = document.createElement('div');
|
||||
nameBox.className = 'tool-box tool-name';
|
||||
nameBox.innerHTML = `<h5>Name:</h5><p>${tool.name}</p>`;
|
||||
toolInfoContainer.appendChild(nameBox);
|
||||
|
||||
const descBox = document.createElement('div');
|
||||
descBox.className = 'tool-box tool-description';
|
||||
descBox.innerHTML = `<h5>Description:</h5><p>${tool.description}</p>`;
|
||||
toolInfoContainer.appendChild(descBox);
|
||||
|
||||
gridContainer.appendChild(toolInfoContainer);
|
||||
|
||||
const paramsContainer = document.createElement('div');
|
||||
paramsContainer.className = 'tool-params tool-box';
|
||||
paramsContainer.innerHTML = '<h5>Parameters:</h5>';
|
||||
const form = document.createElement('form');
|
||||
form.id = `tool-params-form-${toolId}`;
|
||||
|
||||
tool.parameters.forEach(param => {
|
||||
form.appendChild(createParamInput(param, toolId));
|
||||
});
|
||||
paramsContainer.appendChild(form);
|
||||
|
||||
const runButton = document.createElement('button');
|
||||
runButton.className = 'run-tool-btn';
|
||||
runButton.textContent = 'Run Tool';
|
||||
paramsContainer.appendChild(runButton);
|
||||
|
||||
gridContainer.appendChild(paramsContainer);
|
||||
containerElement.appendChild(gridContainer);
|
||||
|
||||
// Response Area
|
||||
const responseContainer = document.createElement('div');
|
||||
responseContainer.className = 'tool-response tool-box';
|
||||
|
||||
const responseHeader = document.createElement('h5');
|
||||
responseHeader.textContent = 'Response:';
|
||||
responseContainer.appendChild(responseHeader);
|
||||
|
||||
// Prettify Checkbox
|
||||
const prettifyId = `prettify-${toolId}`;
|
||||
const prettifyLabel = document.createElement('label');
|
||||
prettifyLabel.setAttribute('for', prettifyId);
|
||||
prettifyLabel.textContent = 'Prettify JSON';
|
||||
prettifyLabel.style.display = 'inline-block';
|
||||
prettifyLabel.style.marginLeft = '10px';
|
||||
prettifyLabel.style.verticalAlign = 'middle';
|
||||
prettifyLabel.style.cursor = 'pointer';
|
||||
|
||||
const prettifyCheckbox = document.createElement('input');
|
||||
prettifyCheckbox.type = 'checkbox';
|
||||
prettifyCheckbox.id = prettifyId;
|
||||
prettifyCheckbox.checked = true; // Default to pretty
|
||||
prettifyCheckbox.style.verticalAlign = 'middle';
|
||||
prettifyCheckbox.style.cursor = 'pointer';
|
||||
|
||||
const prettifyDiv = document.createElement('div');
|
||||
prettifyDiv.style.marginBottom = '5px';
|
||||
prettifyDiv.appendChild(prettifyCheckbox);
|
||||
prettifyDiv.appendChild(prettifyLabel);
|
||||
responseContainer.appendChild(prettifyDiv);
|
||||
|
||||
const responseAreaId = `tool-response-area-${toolId}`;
|
||||
const responseArea = document.createElement('textarea');
|
||||
responseArea.id = responseAreaId;
|
||||
responseArea.readOnly = true;
|
||||
responseArea.placeholder = 'Results will appear here...';
|
||||
responseArea.style.width = 'calc(100% - 12px)';
|
||||
responseArea.rows = 10;
|
||||
responseContainer.appendChild(responseArea);
|
||||
|
||||
containerElement.appendChild(responseContainer);
|
||||
|
||||
// Event Listeners
|
||||
prettifyCheckbox.addEventListener('change', () => {
|
||||
if (lastResults) {
|
||||
displayResults(lastResults, responseArea, prettifyCheckbox.checked);
|
||||
}
|
||||
});
|
||||
|
||||
runButton.addEventListener('click', (event) => {
|
||||
event.preventDefault();
|
||||
handleRunTool(toolId, form, responseArea, tool.parameters, prettifyCheckbox, updateLastResults);
|
||||
});
|
||||
}
|
||||
@@ -1,156 +0,0 @@
|
||||
import { renderToolInterface } from "./toolDisplay.js";
|
||||
|
||||
document.addEventListener('DOMContentLoaded', () => {
|
||||
const toolDisplayArea = document.getElementById('tool-display-area');
|
||||
const secondaryPanelContent = document.getElementById('secondary-panel-content');
|
||||
|
||||
if (!secondaryPanelContent || !toolDisplayArea) {
|
||||
console.error('Required DOM elements not found.');
|
||||
return;
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetches the list of tools from the API and renders them in the secondary panel.
|
||||
*/
|
||||
async function loadTools() {
|
||||
secondaryPanelContent.innerHTML = '<p>Fetching tools...</p>';
|
||||
try {
|
||||
// This endpoint should list tools, the structure you provided seems to be for a single tool
|
||||
const response = await fetch('/api/toolset');
|
||||
if (!response.ok) {
|
||||
throw new Error(`HTTP error! status: ${response.status}`);
|
||||
}
|
||||
const apiResponse = await response.json();
|
||||
renderToolList(apiResponse);
|
||||
} catch (error) {
|
||||
console.error('Failed to load tools:', error);
|
||||
secondaryPanelContent.innerHTML = '<p class="error">Failed to load tools. Please try again later.</p>';
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Renders the list of tools in the secondary navigation panel.
|
||||
* @param {object} apiResponse - The parsed JSON response from the API.
|
||||
*/
|
||||
function renderToolList(apiResponse) {
|
||||
secondaryPanelContent.innerHTML = '';
|
||||
|
||||
if (!apiResponse || typeof apiResponse.tools !== 'object' || apiResponse.tools === null) {
|
||||
console.error('Error: Expected an object with a "tools" property, but received:', apiResponse);
|
||||
secondaryPanelContent.textContent = 'Error: Invalid response format from toolset API.';
|
||||
return;
|
||||
}
|
||||
|
||||
const toolsObject = apiResponse.tools;
|
||||
const toolNames = Object.keys(toolsObject);
|
||||
|
||||
if (toolNames.length === 0) {
|
||||
secondaryPanelContent.textContent = 'No tools found.';
|
||||
return;
|
||||
}
|
||||
|
||||
const ul = document.createElement('ul');
|
||||
toolNames.forEach(toolName => {
|
||||
const li = document.createElement('li');
|
||||
const button = document.createElement('button');
|
||||
button.textContent = toolName;
|
||||
button.dataset.toolname = toolName;
|
||||
button.classList.add('tool-button');
|
||||
button.addEventListener('click', handleToolClick);
|
||||
li.appendChild(button);
|
||||
ul.appendChild(li);
|
||||
});
|
||||
secondaryPanelContent.appendChild(ul);
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles the click event on a tool button in the secondary panel.
|
||||
* @param {MouseEvent} event - The click event.
|
||||
*/
|
||||
function handleToolClick(event) {
|
||||
const toolName = event.target.dataset.toolname;
|
||||
if (toolName) {
|
||||
const currentActive = secondaryPanelContent.querySelector('.tool-button.active');
|
||||
if (currentActive) {
|
||||
currentActive.classList.remove('active');
|
||||
}
|
||||
event.target.classList.add('active');
|
||||
fetchToolDetails(toolName);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetches details for a specific tool from the API and renders the UI.
|
||||
* @param {string} toolName - The name of the tool.
|
||||
*/
|
||||
async function fetchToolDetails(toolName) {
|
||||
toolDisplayArea.innerHTML = '<p>Loading tool details...</p>';
|
||||
|
||||
try {
|
||||
const response = await fetch(`/api/tool/${encodeURIComponent(toolName)}`);
|
||||
if (!response.ok) {
|
||||
throw new Error(`HTTP error! status: ${response.status}`);
|
||||
}
|
||||
const apiResponse = await response.json();
|
||||
|
||||
if (!apiResponse.tools || !apiResponse.tools[toolName]) {
|
||||
throw new Error(`Tool "${toolName}" data not found in API response.`);
|
||||
}
|
||||
const toolObject = apiResponse.tools[toolName];
|
||||
|
||||
const toolInterfaceData = {
|
||||
id: toolName,
|
||||
name: toolName,
|
||||
description: toolObject.description || "No description provided.",
|
||||
parameters: (toolObject.parameters || []).map(param => {
|
||||
let inputType = 'text'; // Default HTML input type
|
||||
let options;
|
||||
const apiType = param.type ? param.type.toLowerCase() : 'string';
|
||||
let valueType = 'string'; // Data type for API payload
|
||||
let label = param.description || param.name;
|
||||
|
||||
if (apiType === 'integer' || apiType === 'number') {
|
||||
inputType = 'number';
|
||||
valueType = 'number';
|
||||
} else if (apiType === 'boolean') {
|
||||
inputType = 'select';
|
||||
options = ['true', 'false'];
|
||||
valueType = 'boolean';
|
||||
} else if (apiType === 'array') {
|
||||
inputType = 'textarea'; // Use textarea for array inputs
|
||||
const itemType = param.items && param.items.type ? param.items.type.toLowerCase() : 'string';
|
||||
valueType = `array<${itemType}>`;
|
||||
label += ' (JSON Array string)'; // Hint to the user
|
||||
} else if (param.enum && Array.isArray(param.enum)) {
|
||||
inputType = 'select';
|
||||
options = param.enum;
|
||||
valueType = 'string';
|
||||
}
|
||||
console.log(param.name, inputType, label, apiType, valueType)
|
||||
|
||||
return {
|
||||
name: param.name,
|
||||
type: inputType, // For HTML input element type
|
||||
valueType: valueType, // For API request payload type
|
||||
label: label,
|
||||
required: param.required || false,
|
||||
options: options,
|
||||
defaultValue: param.default,
|
||||
};
|
||||
})
|
||||
};
|
||||
|
||||
console.log("Transformed toolInterfaceData:", toolInterfaceData);
|
||||
|
||||
renderToolInterface(toolInterfaceData, toolDisplayArea);
|
||||
|
||||
} catch (error) {
|
||||
console.error(`Failed to load details for tool "${toolName}":`, error);
|
||||
toolDisplayArea.innerHTML = `<p class="error">Failed to load details for ${toolName}. ${error.message}</p>`;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Initial load of tools list
|
||||
loadTools();
|
||||
});
|
||||
@@ -1,40 +0,0 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Page Three</title>
|
||||
<link rel="stylesheet" href="css/style.css">
|
||||
</head>
|
||||
<body>
|
||||
<nav class="left-nav">
|
||||
<div class="nav-logo">
|
||||
<img src="/ui/assets/mcptoolboxlogo.png" alt="App Logo">
|
||||
</div>
|
||||
<ul>
|
||||
<li><a href="/ui/sources">Sources</a></li>
|
||||
<li><a href="/ui/authservices">Auth Services</a></li>
|
||||
<li><a href="/ui/tools" class="active">Tools</a></li>
|
||||
<li><a href="/ui/toolsets">Toolsets</a></li>
|
||||
</ul>
|
||||
</nav>
|
||||
|
||||
<aside class="second-nav">
|
||||
<h4>My Tools</h4>
|
||||
<div id="secondary-panel-content">
|
||||
<p>Fetching tools...</p>
|
||||
</div>
|
||||
</aside>
|
||||
|
||||
<div class="main-content-area">
|
||||
<div class="top-bar">
|
||||
<span>Toolbox UI Inspector</span>
|
||||
</div>
|
||||
<main class="content" id="tool-display-area">
|
||||
<h1>Welcome to the Main Page</h1>
|
||||
<p>This is the main content area. Click a tab on the left to navigate.</p>
|
||||
</main>
|
||||
</div>
|
||||
<script type="module" src="js/tools.js"></script>
|
||||
</body>
|
||||
</html>
|
||||
@@ -1,53 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"embed"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
)
|
||||
|
||||
//go:embed all:static
|
||||
var staticContent embed.FS
|
||||
|
||||
// webRouter creates a router that represents the routes under /ui
|
||||
func webRouter() (chi.Router, error) {
|
||||
r := chi.NewRouter()
|
||||
r.Use(middleware.StripSlashes)
|
||||
|
||||
// direct routes for html pages to provide clean URLs
|
||||
r.Get("/", func(w http.ResponseWriter, r *http.Request) { serveHTML(w, r, "static/index.html") })
|
||||
r.Get("/tools", func(w http.ResponseWriter, r *http.Request) { serveHTML(w, r, "static/tools.html") })
|
||||
|
||||
// handler for all other static files/assets
|
||||
staticFS, _ := fs.Sub(staticContent, "static")
|
||||
r.Handle("/*", http.StripPrefix("/ui", http.FileServer(http.FS(staticFS))))
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func serveHTML(w http.ResponseWriter, r *http.Request, filepath string) {
|
||||
file, err := staticContent.Open(filepath)
|
||||
if err != nil {
|
||||
http.Error(w, "File not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
fileBytes, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Error reading file: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
fileInfo, err := file.Stat()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
http.ServeContent(w, r, fileInfo.Name(), fileInfo.ModTime(), bytes.NewReader(fileBytes))
|
||||
}
|
||||
@@ -1,77 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/go-goquery/goquery"
|
||||
)
|
||||
|
||||
func TestWebEndpoint(t *testing.T) {
|
||||
router, err := webRouter()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create webRouter: %v", err)
|
||||
}
|
||||
|
||||
ts := httptest.NewServer(router)
|
||||
defer ts.Close()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
method string
|
||||
path string
|
||||
wantStatus int
|
||||
wantContentType string
|
||||
wantPageTitle string
|
||||
}{
|
||||
{
|
||||
name: "web index page GET",
|
||||
method: http.MethodGet,
|
||||
path: "/",
|
||||
wantStatus: http.StatusOK,
|
||||
wantContentType: "text/html; charset=utf-8",
|
||||
wantPageTitle: "Toolbox UI",
|
||||
},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
url := ts.URL + tc.path
|
||||
req, err := http.NewRequest(tc.method, url, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create request: %v", err)
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to send request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read response body: %v", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != tc.wantStatus {
|
||||
t.Errorf("Unexpected status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatus, string(body))
|
||||
}
|
||||
|
||||
if contentType := resp.Header.Get("Content-Type"); contentType != tc.wantContentType {
|
||||
t.Errorf("Unexpected Content-Type header: got %s, want %s", contentType, tc.wantContentType)
|
||||
}
|
||||
|
||||
doc, err := goquery.NewDocumentFromReader(strings.NewReader(string(body)))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse HTML: %v", err)
|
||||
}
|
||||
gotPageTitle := doc.Find("title").Text()
|
||||
|
||||
if gotPageTitle != tc.wantPageTitle {
|
||||
t.Errorf("Unexpected page title: got %q, want %q", gotPageTitle, tc.wantPageTitle)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -15,12 +15,9 @@
|
||||
package testutils
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/log"
|
||||
@@ -45,63 +42,3 @@ func ContextWithNewLogger() (context.Context, error) {
|
||||
}
|
||||
return util.WithLogger(ctx, logger), nil
|
||||
}
|
||||
|
||||
// WaitForString waits until the server logs a single line that matches the provided regex.
|
||||
// returns the output of whatever the server sent so far.
|
||||
func WaitForString(ctx context.Context, re *regexp.Regexp, pr io.ReadCloser) (string, error) {
|
||||
in := bufio.NewReader(pr)
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
// read lines in background, sending result of each read over a channel
|
||||
// this allows us to use in.ReadString without blocking
|
||||
type result struct {
|
||||
s string
|
||||
err error
|
||||
}
|
||||
output := make(chan result)
|
||||
go func() {
|
||||
defer close(output)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// if the context is canceled, the orig thread will send back the error
|
||||
// so we can just exit the goroutine here
|
||||
return
|
||||
default:
|
||||
// otherwise read a line from the output
|
||||
s, err := in.ReadString('\n')
|
||||
if err != nil {
|
||||
output <- result{err: err}
|
||||
return
|
||||
}
|
||||
output <- result{s: s}
|
||||
// if that last string matched, exit the goroutine
|
||||
if re.MatchString(s) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// collect the output until the ctx is canceled, an error was hit,
|
||||
// or match was found (which is indicated the channel is closed)
|
||||
var sb strings.Builder
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// if ctx is done, return that error
|
||||
return sb.String(), ctx.Err()
|
||||
case o, ok := <-output:
|
||||
if !ok {
|
||||
// match was found!
|
||||
return sb.String(), nil
|
||||
}
|
||||
if o.err != nil {
|
||||
// error was found!
|
||||
return sb.String(), o.err
|
||||
}
|
||||
sb.WriteString(o.s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,8 +26,6 @@ import (
|
||||
)
|
||||
|
||||
const kind string = "bigquery-get-dataset-info"
|
||||
const projectKey string = "project"
|
||||
const datasetKey string = "dataset"
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
@@ -80,9 +78,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
projectParameter := tools.NewStringParameterWithDefault(projectKey, s.BigQueryClient().Project(), "The Google Cloud project ID containing the dataset.")
|
||||
datasetParameter := tools.NewStringParameter(datasetKey, "The dataset to get metadata information.")
|
||||
parameters := tools.Parameters{projectParameter, datasetParameter}
|
||||
datasetParameter := tools.NewStringParameter("dataset", "The dataset to get metadata information.")
|
||||
parameters := tools.Parameters{datasetParameter}
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
@@ -119,18 +116,14 @@ type Tool struct {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, error) {
|
||||
mapParams := params.AsMap()
|
||||
projectId, ok := mapParams[projectKey].(string)
|
||||
|
||||
sliceParams := params.AsSlice()
|
||||
datasetId, ok := sliceParams[0].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", projectKey)
|
||||
return nil, fmt.Errorf("unable to get cast %s", sliceParams[0])
|
||||
}
|
||||
|
||||
datasetId, ok := mapParams[datasetKey].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", datasetKey)
|
||||
}
|
||||
|
||||
dsHandle := t.Client.DatasetInProject(projectId, datasetId)
|
||||
dsHandle := t.Client.Dataset(datasetId)
|
||||
|
||||
metadata, err := dsHandle.Metadata(ctx)
|
||||
if err != nil {
|
||||
|
||||
@@ -26,9 +26,6 @@ import (
|
||||
)
|
||||
|
||||
const kind string = "bigquery-get-table-info"
|
||||
const projectKey string = "project"
|
||||
const datasetKey string = "dataset"
|
||||
const tableKey string = "table"
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
@@ -81,10 +78,9 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
projectParameter := tools.NewStringParameterWithDefault(projectKey, s.BigQueryClient().Project(), "The Google Cloud project ID containing the dataset and table.")
|
||||
datasetParameter := tools.NewStringParameter(datasetKey, "The table's parent dataset.")
|
||||
tableParameter := tools.NewStringParameter(tableKey, "The table to get metadata information.")
|
||||
parameters := tools.Parameters{projectParameter, datasetParameter, tableParameter}
|
||||
datasetParameter := tools.NewStringParameter("dataset", "The table's parent dataset.")
|
||||
tableParameter := tools.NewStringParameter("table", "The table to get metadata information.")
|
||||
parameters := tools.Parameters{datasetParameter, tableParameter}
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
@@ -121,28 +117,22 @@ type Tool struct {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, error) {
|
||||
mapParams := params.AsMap()
|
||||
projectId, ok := mapParams[projectKey].(string)
|
||||
sliceParams := params.AsSlice()
|
||||
datasetId, ok := sliceParams[0].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", projectKey)
|
||||
return nil, fmt.Errorf("unable to get cast %s", sliceParams[0])
|
||||
}
|
||||
tableId, ok := sliceParams[1].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unable to get cast %s", sliceParams[1])
|
||||
}
|
||||
|
||||
datasetId, ok := mapParams[datasetKey].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", datasetKey)
|
||||
}
|
||||
|
||||
tableId, ok := mapParams[tableKey].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", tableKey)
|
||||
}
|
||||
|
||||
dsHandle := t.Client.DatasetInProject(projectId, datasetId)
|
||||
dsHandle := t.Client.Dataset(datasetId)
|
||||
tableHandle := dsHandle.Table(tableId)
|
||||
|
||||
metadata, err := tableHandle.Metadata(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get metadata for table %s.%s.%s: %w", projectId, datasetId, tableId, err)
|
||||
return nil, fmt.Errorf("failed to get metadata for table %s.%s.%s: %w", t.Client.Project(), datasetId, tableId, err)
|
||||
}
|
||||
|
||||
return []any{metadata}, nil
|
||||
|
||||
@@ -27,7 +27,6 @@ import (
|
||||
)
|
||||
|
||||
const kind string = "bigquery-list-dataset-ids"
|
||||
const projectKey string = "project"
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
@@ -80,9 +79,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
projectParameter := tools.NewStringParameterWithDefault(projectKey, s.BigQueryClient().Project(), "The Google Cloud project to list dataset ids.")
|
||||
|
||||
parameters := tools.Parameters{projectParameter}
|
||||
parameters := tools.Parameters{}
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
@@ -119,13 +116,7 @@ type Tool struct {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, error) {
|
||||
mapParams := params.AsMap()
|
||||
projectId, ok := mapParams[projectKey].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", projectKey)
|
||||
}
|
||||
datasetIterator := t.Client.Datasets(ctx)
|
||||
datasetIterator.ProjectID = projectId
|
||||
|
||||
var datasetIds []any
|
||||
for {
|
||||
|
||||
@@ -27,8 +27,6 @@ import (
|
||||
)
|
||||
|
||||
const kind string = "bigquery-list-table-ids"
|
||||
const projectKey string = "project"
|
||||
const datasetKey string = "dataset"
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
@@ -81,9 +79,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
projectParameter := tools.NewStringParameterWithDefault(projectKey, s.BigQueryClient().Project(), "The Google Cloud project ID containing the dataset.")
|
||||
datasetParameter := tools.NewStringParameter(datasetKey, "The dataset to list table ids.")
|
||||
parameters := tools.Parameters{projectParameter, datasetParameter}
|
||||
datasetParameter := tools.NewStringParameter("dataset", "The dataset to list table ids.")
|
||||
parameters := tools.Parameters{datasetParameter}
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
@@ -120,18 +117,14 @@ type Tool struct {
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, error) {
|
||||
mapParams := params.AsMap()
|
||||
projectId, ok := mapParams[projectKey].(string)
|
||||
|
||||
sliceParams := params.AsSlice()
|
||||
datasetId, ok := sliceParams[0].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", projectKey)
|
||||
return nil, fmt.Errorf("unable to get cast %s", sliceParams[0])
|
||||
}
|
||||
|
||||
datasetId, ok := mapParams[datasetKey].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", datasetKey)
|
||||
}
|
||||
|
||||
dsHandle := t.Client.DatasetInProject(projectId, datasetId)
|
||||
dsHandle := t.Client.Dataset(datasetId)
|
||||
|
||||
var tableIds []any
|
||||
tableIterator := dsHandle.Tables(ctx)
|
||||
|
||||
@@ -319,6 +319,13 @@ func parseParamFromDelayedUnmarshaler(ctx context.Context, u *util.DelayedUnmars
|
||||
if err := dec.DecodeContext(ctx, a); err != nil {
|
||||
return nil, fmt.Errorf("unable to parse as %q: %w", t, err)
|
||||
}
|
||||
if a.Default != nil {
|
||||
intD, ok := a.Default.(uint64)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("default value fail to assert to type int64")
|
||||
}
|
||||
a.Default = int(intD)
|
||||
}
|
||||
if a.AuthSources != nil {
|
||||
logger.WarnContext(ctx, "`authSources` is deprecated, use `authServices` for parameters instead")
|
||||
a.AuthServices = append(a.AuthServices, a.AuthSources...)
|
||||
@@ -411,6 +418,7 @@ type ParameterMcpManifest struct {
|
||||
type CommonParameter struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Type string `yaml:"type" validate:"required"`
|
||||
Default any `yaml:"default"`
|
||||
Desc string `yaml:"description" validate:"required"`
|
||||
AuthServices []ParamAuthService `yaml:"authServices"`
|
||||
AuthSources []ParamAuthService `yaml:"authSources"` // Deprecated: Kept for compatibility.
|
||||
@@ -426,6 +434,30 @@ func (p *CommonParameter) GetType() string {
|
||||
return p.Type
|
||||
}
|
||||
|
||||
func (p *CommonParameter) GetDefault() any {
|
||||
return p.Default
|
||||
}
|
||||
|
||||
// Manifest returns the manifest for the Parameter.
|
||||
func (p *CommonParameter) Manifest() ParameterManifest {
|
||||
// only list ParamAuthService names (without fields) in manifest
|
||||
authNames := make([]string, len(p.AuthServices))
|
||||
for i, a := range p.AuthServices {
|
||||
authNames[i] = a.Name
|
||||
}
|
||||
var required bool
|
||||
if p.Default == nil {
|
||||
required = true
|
||||
}
|
||||
return ParameterManifest{
|
||||
Name: p.Name,
|
||||
Type: p.Type,
|
||||
Required: required,
|
||||
Description: p.Desc,
|
||||
AuthServices: authNames,
|
||||
}
|
||||
}
|
||||
|
||||
// McpManifest returns the MCP manifest for the Parameter.
|
||||
func (p *CommonParameter) McpManifest() ParameterMcpManifest {
|
||||
return ParameterMcpManifest{
|
||||
@@ -468,10 +500,10 @@ func NewStringParameterWithDefault(name string, defaultV, desc string) *StringPa
|
||||
CommonParameter: CommonParameter{
|
||||
Name: name,
|
||||
Type: typeString,
|
||||
Default: defaultV,
|
||||
Desc: desc,
|
||||
AuthServices: nil,
|
||||
},
|
||||
Default: &defaultV,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -492,7 +524,6 @@ var _ Parameter = &StringParameter{}
|
||||
// StringParameter is a parameter representing the "string" type.
|
||||
type StringParameter struct {
|
||||
CommonParameter `yaml:",inline"`
|
||||
Default *string `yaml:"default"`
|
||||
}
|
||||
|
||||
// Parse casts the value "v" as a "string".
|
||||
@@ -503,35 +534,10 @@ func (p *StringParameter) Parse(v any) (any, error) {
|
||||
}
|
||||
return newV, nil
|
||||
}
|
||||
|
||||
func (p *StringParameter) GetAuthServices() []ParamAuthService {
|
||||
return p.AuthServices
|
||||
}
|
||||
|
||||
func (p *StringParameter) GetDefault() any {
|
||||
if p.Default == nil {
|
||||
return nil
|
||||
}
|
||||
return *p.Default
|
||||
}
|
||||
|
||||
// Manifest returns the manifest for the StringParameter.
|
||||
func (p *StringParameter) Manifest() ParameterManifest {
|
||||
// only list ParamAuthService names (without fields) in manifest
|
||||
authNames := make([]string, len(p.AuthServices))
|
||||
for i, a := range p.AuthServices {
|
||||
authNames[i] = a.Name
|
||||
}
|
||||
required := p.Default == nil
|
||||
return ParameterManifest{
|
||||
Name: p.Name,
|
||||
Type: p.Type,
|
||||
Required: required,
|
||||
Description: p.Desc,
|
||||
AuthServices: authNames,
|
||||
}
|
||||
}
|
||||
|
||||
// NewIntParameter is a convenience function for initializing a IntParameter.
|
||||
func NewIntParameter(name string, desc string) *IntParameter {
|
||||
return &IntParameter{
|
||||
@@ -545,15 +551,15 @@ func NewIntParameter(name string, desc string) *IntParameter {
|
||||
}
|
||||
|
||||
// NewIntParameterWithDefault is a convenience function for initializing a IntParameter with default value.
|
||||
func NewIntParameterWithDefault(name string, defaultV int, desc string) *IntParameter {
|
||||
func NewIntParameterWithDefault(name string, defaultV any, desc string) *IntParameter {
|
||||
return &IntParameter{
|
||||
CommonParameter: CommonParameter{
|
||||
Name: name,
|
||||
Type: typeInt,
|
||||
Default: defaultV,
|
||||
Desc: desc,
|
||||
AuthServices: nil,
|
||||
},
|
||||
Default: &defaultV,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -574,7 +580,6 @@ var _ Parameter = &IntParameter{}
|
||||
// IntParameter is a parameter representing the "int" type.
|
||||
type IntParameter struct {
|
||||
CommonParameter `yaml:",inline"`
|
||||
Default *int `yaml:"default"`
|
||||
}
|
||||
|
||||
func (p *IntParameter) Parse(v any) (any, error) {
|
||||
@@ -602,30 +607,6 @@ func (p *IntParameter) GetAuthServices() []ParamAuthService {
|
||||
return p.AuthServices
|
||||
}
|
||||
|
||||
func (p *IntParameter) GetDefault() any {
|
||||
if p.Default == nil {
|
||||
return nil
|
||||
}
|
||||
return *p.Default
|
||||
}
|
||||
|
||||
// Manifest returns the manifest for the IntParameter.
|
||||
func (p *IntParameter) Manifest() ParameterManifest {
|
||||
// only list ParamAuthService names (without fields) in manifest
|
||||
authNames := make([]string, len(p.AuthServices))
|
||||
for i, a := range p.AuthServices {
|
||||
authNames[i] = a.Name
|
||||
}
|
||||
required := p.Default == nil
|
||||
return ParameterManifest{
|
||||
Name: p.Name,
|
||||
Type: p.Type,
|
||||
Required: required,
|
||||
Description: p.Desc,
|
||||
AuthServices: authNames,
|
||||
}
|
||||
}
|
||||
|
||||
// NewFloatParameter is a convenience function for initializing a FloatParameter.
|
||||
func NewFloatParameter(name string, desc string) *FloatParameter {
|
||||
return &FloatParameter{
|
||||
@@ -644,10 +625,10 @@ func NewFloatParameterWithDefault(name string, defaultV float64, desc string) *F
|
||||
CommonParameter: CommonParameter{
|
||||
Name: name,
|
||||
Type: typeFloat,
|
||||
Default: defaultV,
|
||||
Desc: desc,
|
||||
AuthServices: nil,
|
||||
},
|
||||
Default: &defaultV,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -668,7 +649,6 @@ var _ Parameter = &FloatParameter{}
|
||||
// FloatParameter is a parameter representing the "float" type.
|
||||
type FloatParameter struct {
|
||||
CommonParameter `yaml:",inline"`
|
||||
Default *float64 `yaml:"default"`
|
||||
}
|
||||
|
||||
func (p *FloatParameter) Parse(v any) (any, error) {
|
||||
@@ -694,30 +674,6 @@ func (p *FloatParameter) GetAuthServices() []ParamAuthService {
|
||||
return p.AuthServices
|
||||
}
|
||||
|
||||
func (p *FloatParameter) GetDefault() any {
|
||||
if p.Default == nil {
|
||||
return nil
|
||||
}
|
||||
return *p.Default
|
||||
}
|
||||
|
||||
// Manifest returns the manifest for the FloatParameter.
|
||||
func (p *FloatParameter) Manifest() ParameterManifest {
|
||||
// only list ParamAuthService names (without fields) in manifest
|
||||
authNames := make([]string, len(p.AuthServices))
|
||||
for i, a := range p.AuthServices {
|
||||
authNames[i] = a.Name
|
||||
}
|
||||
required := p.Default == nil
|
||||
return ParameterManifest{
|
||||
Name: p.Name,
|
||||
Type: p.Type,
|
||||
Required: required,
|
||||
Description: p.Desc,
|
||||
AuthServices: authNames,
|
||||
}
|
||||
}
|
||||
|
||||
// NewBooleanParameter is a convenience function for initializing a BooleanParameter.
|
||||
func NewBooleanParameter(name string, desc string) *BooleanParameter {
|
||||
return &BooleanParameter{
|
||||
@@ -736,10 +692,10 @@ func NewBooleanParameterWithDefault(name string, defaultV bool, desc string) *Bo
|
||||
CommonParameter: CommonParameter{
|
||||
Name: name,
|
||||
Type: typeBool,
|
||||
Default: defaultV,
|
||||
Desc: desc,
|
||||
AuthServices: nil,
|
||||
},
|
||||
Default: &defaultV,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -760,7 +716,6 @@ var _ Parameter = &BooleanParameter{}
|
||||
// BooleanParameter is a parameter representing the "boolean" type.
|
||||
type BooleanParameter struct {
|
||||
CommonParameter `yaml:",inline"`
|
||||
Default *bool `yaml:"default"`
|
||||
}
|
||||
|
||||
func (p *BooleanParameter) Parse(v any) (any, error) {
|
||||
@@ -775,30 +730,6 @@ func (p *BooleanParameter) GetAuthServices() []ParamAuthService {
|
||||
return p.AuthServices
|
||||
}
|
||||
|
||||
func (p *BooleanParameter) GetDefault() any {
|
||||
if p.Default == nil {
|
||||
return nil
|
||||
}
|
||||
return *p.Default
|
||||
}
|
||||
|
||||
// Manifest returns the manifest for the BooleanParameter.
|
||||
func (p *BooleanParameter) Manifest() ParameterManifest {
|
||||
// only list ParamAuthService names (without fields) in manifest
|
||||
authNames := make([]string, len(p.AuthServices))
|
||||
for i, a := range p.AuthServices {
|
||||
authNames[i] = a.Name
|
||||
}
|
||||
required := p.Default == nil
|
||||
return ParameterManifest{
|
||||
Name: p.Name,
|
||||
Type: p.Type,
|
||||
Required: required,
|
||||
Description: p.Desc,
|
||||
AuthServices: authNames,
|
||||
}
|
||||
}
|
||||
|
||||
// NewArrayParameter is a convenience function for initializing a ArrayParameter.
|
||||
func NewArrayParameter(name string, desc string, items Parameter) *ArrayParameter {
|
||||
return &ArrayParameter{
|
||||
@@ -813,16 +744,16 @@ func NewArrayParameter(name string, desc string, items Parameter) *ArrayParamete
|
||||
}
|
||||
|
||||
// NewArrayParameterWithDefault is a convenience function for initializing a ArrayParameter with default value.
|
||||
func NewArrayParameterWithDefault(name string, defaultV []any, desc string, items Parameter) *ArrayParameter {
|
||||
func NewArrayParameterWithDefault(name string, defaultV any, desc string, items Parameter) *ArrayParameter {
|
||||
return &ArrayParameter{
|
||||
CommonParameter: CommonParameter{
|
||||
Name: name,
|
||||
Type: typeArray,
|
||||
Default: defaultV,
|
||||
Desc: desc,
|
||||
AuthServices: nil,
|
||||
},
|
||||
Items: items,
|
||||
Default: &defaultV,
|
||||
Items: items,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -844,21 +775,18 @@ var _ Parameter = &ArrayParameter{}
|
||||
// ArrayParameter is a parameter representing the "array" type.
|
||||
type ArrayParameter struct {
|
||||
CommonParameter `yaml:",inline"`
|
||||
Default *[]any `yaml:"default"`
|
||||
Items Parameter `yaml:"items"`
|
||||
}
|
||||
|
||||
func (p *ArrayParameter) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
|
||||
var rawItem struct {
|
||||
CommonParameter `yaml:",inline"`
|
||||
Default *[]any `yaml:"default"`
|
||||
Items util.DelayedUnmarshaler `yaml:"items"`
|
||||
}
|
||||
if err := unmarshal(&rawItem); err != nil {
|
||||
return err
|
||||
}
|
||||
p.CommonParameter = rawItem.CommonParameter
|
||||
p.Default = rawItem.Default
|
||||
i, err := parseParamFromDelayedUnmarshaler(ctx, &rawItem.Items)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to parse 'items' field: %w", err)
|
||||
@@ -891,13 +819,6 @@ func (p *ArrayParameter) GetAuthServices() []ParamAuthService {
|
||||
return p.AuthServices
|
||||
}
|
||||
|
||||
func (p *ArrayParameter) GetDefault() any {
|
||||
if p.Default == nil {
|
||||
return nil
|
||||
}
|
||||
return *p.Default
|
||||
}
|
||||
|
||||
// Manifest returns the manifest for the ArrayParameter.
|
||||
func (p *ArrayParameter) Manifest() ParameterManifest {
|
||||
// only list ParamAuthService names (without fields) in manifest
|
||||
@@ -906,8 +827,11 @@ func (p *ArrayParameter) Manifest() ParameterManifest {
|
||||
authNames[i] = a.Name
|
||||
}
|
||||
items := p.Items.Manifest()
|
||||
required := p.Default == nil
|
||||
items.Required = required
|
||||
var required bool
|
||||
if p.Default == nil {
|
||||
required = true
|
||||
items.Required = true
|
||||
}
|
||||
return ParameterManifest{
|
||||
Name: p.Name,
|
||||
Type: p.Type,
|
||||
|
||||
@@ -187,7 +187,7 @@ func TestParametersMarshal(t *testing.T) {
|
||||
{
|
||||
"name": "my_array",
|
||||
"type": "array",
|
||||
"default": []any{"foo", "bar"},
|
||||
"default": `["foo", "bar"]`,
|
||||
"description": "this param is an array of strings",
|
||||
"items": map[string]string{
|
||||
"name": "my_string",
|
||||
@@ -197,7 +197,7 @@ func TestParametersMarshal(t *testing.T) {
|
||||
},
|
||||
},
|
||||
want: tools.Parameters{
|
||||
tools.NewArrayParameterWithDefault("my_array", []any{"foo", "bar"}, "this param is an array of strings", tools.NewStringParameter("my_string", "string item")),
|
||||
tools.NewArrayParameterWithDefault("my_array", `["foo", "bar"]`, "this param is an array of strings", tools.NewStringParameter("my_string", "string item")),
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -206,7 +206,7 @@ func TestParametersMarshal(t *testing.T) {
|
||||
{
|
||||
"name": "my_array",
|
||||
"type": "array",
|
||||
"default": []any{1.0, 1.1},
|
||||
"default": "[1.0, 1.1]",
|
||||
"description": "this param is an array of floats",
|
||||
"items": map[string]string{
|
||||
"name": "my_float",
|
||||
@@ -216,7 +216,7 @@ func TestParametersMarshal(t *testing.T) {
|
||||
},
|
||||
},
|
||||
want: tools.Parameters{
|
||||
tools.NewArrayParameterWithDefault("my_array", []any{1.0, 1.1}, "this param is an array of floats", tools.NewFloatParameter("my_float", "float item")),
|
||||
tools.NewArrayParameterWithDefault("my_array", "[1.0, 1.1]", "this param is an array of floats", tools.NewFloatParameter("my_float", "float item")),
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -993,14 +993,14 @@ func TestParamManifest(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "array default",
|
||||
in: tools.NewArrayParameterWithDefault("foo-array", []any{"foo", "bar"}, "bar", tools.NewStringParameter("foo-string", "bar")),
|
||||
in: tools.NewArrayParameterWithDefault("foo-array", `["foo", "bar"]`, "bar", tools.NewStringParameter("foo-string", "bar")),
|
||||
want: tools.ParameterManifest{
|
||||
Name: "foo-array",
|
||||
Type: "array",
|
||||
Required: false,
|
||||
Description: "bar",
|
||||
AuthServices: []string{},
|
||||
Items: &tools.ParameterManifest{Name: "foo-string", Type: "string", Required: false, Description: "bar", AuthServices: []string{}},
|
||||
Items: &tools.ParameterManifest{Name: "foo-string", Type: "string", Required: true, Description: "bar", AuthServices: []string{}},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -1073,7 +1073,7 @@ func TestMcpManifest(t *testing.T) {
|
||||
tools.NewStringParameter("foo-string2", "bar"),
|
||||
tools.NewIntParameterWithDefault("foo-int", 1, "bar"),
|
||||
tools.NewIntParameter("foo-int2", "bar"),
|
||||
tools.NewArrayParameterWithDefault("foo-array", []any{"hello", "world"}, "bar", tools.NewStringParameter("foo-string", "bar")),
|
||||
tools.NewArrayParameterWithDefault("foo-array", []string{"hello", "world"}, "bar", tools.NewStringParameter("foo-string", "bar")),
|
||||
tools.NewArrayParameter("foo-array2", "bar", tools.NewStringParameter("foo-string", "bar")),
|
||||
},
|
||||
want: tools.McpToolsSchema{
|
||||
|
||||
@@ -23,7 +23,6 @@ import (
|
||||
"github.com/go-playground/validator/v10"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/log"
|
||||
"github.com/googleapis/genai-toolbox/internal/telemetry"
|
||||
)
|
||||
|
||||
// DecodeJSON decodes a given reader into an interface using the json decoder.
|
||||
@@ -99,25 +98,10 @@ func WithLogger(ctx context.Context, logger log.Logger) context.Context {
|
||||
return context.WithValue(ctx, loggerKey, logger)
|
||||
}
|
||||
|
||||
// LoggerFromContext retrieves the logger or return an error
|
||||
// LoggerFromContext retreives the logger or return an error
|
||||
func LoggerFromContext(ctx context.Context) (log.Logger, error) {
|
||||
if logger, ok := ctx.Value(loggerKey).(log.Logger); ok {
|
||||
return logger, nil
|
||||
}
|
||||
return nil, fmt.Errorf("unable to retrieve logger")
|
||||
}
|
||||
|
||||
const instrumentationKey contextKey = "instrumentation"
|
||||
|
||||
// WithInstrumentation adds an instrumentation into the context as a value
|
||||
func WithInstrumentation(ctx context.Context, instrumentation *telemetry.Instrumentation) context.Context {
|
||||
return context.WithValue(ctx, instrumentationKey, instrumentation)
|
||||
}
|
||||
|
||||
// InstrumentationFromContext retrieves the instrumentation or return an error
|
||||
func InstrumentationFromContext(ctx context.Context) (*telemetry.Instrumentation, error) {
|
||||
if instrumentation, ok := ctx.Value(instrumentationKey).(*telemetry.Instrumentation); ok {
|
||||
return instrumentation, nil
|
||||
}
|
||||
return nil, fmt.Errorf("unable to retrieve instrumentation")
|
||||
}
|
||||
|
||||
@@ -28,60 +28,59 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/tests"
|
||||
)
|
||||
|
||||
var (
|
||||
AlloyDBAINLSourceKind = "alloydb-postgres"
|
||||
AlloyDBAINLToolKind = "alloydb-ai-nl"
|
||||
AlloyDBAINLProject = os.Getenv("ALLOYDB_AI_NL_PROJECT")
|
||||
AlloyDBAINLRegion = os.Getenv("ALLOYDB_AI_NL_REGION")
|
||||
AlloyDBAINLCluster = os.Getenv("ALLOYDB_AI_NL_CLUSTER")
|
||||
AlloyDBAINLInstance = os.Getenv("ALLOYDB_AI_NL_INSTANCE")
|
||||
AlloyDBAINLDatabase = os.Getenv("ALLOYDB_AI_NL_DATABASE")
|
||||
AlloyDBAINLUser = os.Getenv("ALLOYDB_AI_NL_USER")
|
||||
AlloyDBAINLPass = os.Getenv("ALLOYDB_AI_NL_PASS")
|
||||
ALLOYDB_AI_NL_SOURCE_KIND = "alloydb-postgres"
|
||||
ALLOYDB_AI_NL_TOOL_KIND = "alloydb-ai-nl"
|
||||
ALLOYDB_AI_NL_PROJECT = os.Getenv("ALLOYDB_AI_NL_PROJECT")
|
||||
ALLOYDB_AI_NL_REGION = os.Getenv("ALLOYDB_AI_NL_REGION")
|
||||
ALLOYDB_AI_NL_CLUSTER = os.Getenv("ALLOYDB_AI_NL_CLUSTER")
|
||||
ALLOYDB_AI_NL_INSTANCE = os.Getenv("ALLOYDB_AI_NL_INSTANCE")
|
||||
ALLOYDB_AI_NL_DATABASE = os.Getenv("ALLOYDB_AI_NL_DATABASE")
|
||||
ALLOYDB_AI_NL_USER = os.Getenv("ALLOYDB_AI_NL_USER")
|
||||
ALLOYDB_AI_NL_PASS = os.Getenv("ALLOYDB_AI_NL_PASS")
|
||||
)
|
||||
|
||||
func getAlloyDBAINLVars(t *testing.T) map[string]any {
|
||||
func getAlloyDBAiNlVars(t *testing.T) map[string]any {
|
||||
switch "" {
|
||||
case AlloyDBAINLProject:
|
||||
case ALLOYDB_AI_NL_PROJECT:
|
||||
t.Fatal("'ALLOYDB_AI_NL_PROJECT' not set")
|
||||
case AlloyDBAINLRegion:
|
||||
case ALLOYDB_AI_NL_REGION:
|
||||
t.Fatal("'ALLOYDB_AI_NL_REGION' not set")
|
||||
case AlloyDBAINLCluster:
|
||||
case ALLOYDB_AI_NL_CLUSTER:
|
||||
t.Fatal("'ALLOYDB_AI_NL_CLUSTER' not set")
|
||||
case AlloyDBAINLInstance:
|
||||
case ALLOYDB_AI_NL_INSTANCE:
|
||||
t.Fatal("'ALLOYDB_AI_NL_INSTANCE' not set")
|
||||
case AlloyDBAINLDatabase:
|
||||
case ALLOYDB_AI_NL_DATABASE:
|
||||
t.Fatal("'ALLOYDB_AI_NL_DATABASE' not set")
|
||||
case AlloyDBAINLUser:
|
||||
case ALLOYDB_AI_NL_USER:
|
||||
t.Fatal("'ALLOYDB_AI_NL_USER' not set")
|
||||
case AlloyDBAINLPass:
|
||||
case ALLOYDB_AI_NL_PASS:
|
||||
t.Fatal("'ALLOYDB_AI_NL_PASS' not set")
|
||||
}
|
||||
return map[string]any{
|
||||
"kind": AlloyDBAINLSourceKind,
|
||||
"project": AlloyDBAINLProject,
|
||||
"cluster": AlloyDBAINLCluster,
|
||||
"instance": AlloyDBAINLInstance,
|
||||
"region": AlloyDBAINLRegion,
|
||||
"database": AlloyDBAINLDatabase,
|
||||
"user": AlloyDBAINLUser,
|
||||
"password": AlloyDBAINLPass,
|
||||
"kind": ALLOYDB_AI_NL_SOURCE_KIND,
|
||||
"project": ALLOYDB_AI_NL_PROJECT,
|
||||
"cluster": ALLOYDB_AI_NL_CLUSTER,
|
||||
"instance": ALLOYDB_AI_NL_INSTANCE,
|
||||
"region": ALLOYDB_AI_NL_REGION,
|
||||
"database": ALLOYDB_AI_NL_DATABASE,
|
||||
"user": ALLOYDB_AI_NL_USER,
|
||||
"password": ALLOYDB_AI_NL_PASS,
|
||||
}
|
||||
}
|
||||
|
||||
func TestAlloyDBAINLToolEndpoints(t *testing.T) {
|
||||
sourceConfig := getAlloyDBAINLVars(t)
|
||||
func TestAlloyDBAiNlToolEndpoints(t *testing.T) {
|
||||
sourceConfig := getAlloyDBAiNlVars(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
|
||||
var args []string
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := getAINLToolsConfig(sourceConfig)
|
||||
toolsFile := getAiNlToolsConfig(sourceConfig)
|
||||
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
@@ -91,18 +90,18 @@ func TestAlloyDBAINLToolEndpoints(t *testing.T) {
|
||||
|
||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
|
||||
if err != nil {
|
||||
t.Logf("toolbox command logs: \n%s", out)
|
||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||
}
|
||||
|
||||
runAINLToolGetTest(t)
|
||||
runAINLToolInvokeTest(t)
|
||||
runAINLMCPToolCallMethod(t)
|
||||
runAiNlToolGetTest(t)
|
||||
runAiNlToolInvokeTest(t)
|
||||
runAiNlMCPToolCallMethod(t)
|
||||
}
|
||||
|
||||
func runAINLToolGetTest(t *testing.T) {
|
||||
func runAiNlToolGetTest(t *testing.T) {
|
||||
// Test tool get endpoint
|
||||
tcs := []struct {
|
||||
name string
|
||||
@@ -157,7 +156,7 @@ func runAINLToolGetTest(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func runAINLToolInvokeTest(t *testing.T) {
|
||||
func runAiNlToolInvokeTest(t *testing.T) {
|
||||
// Get ID token
|
||||
idToken, err := tests.GetGoogleIdToken(tests.ClientId)
|
||||
if err != nil {
|
||||
@@ -277,7 +276,7 @@ func runAINLToolInvokeTest(t *testing.T) {
|
||||
|
||||
}
|
||||
|
||||
func getAINLToolsConfig(sourceConfig map[string]any) map[string]any {
|
||||
func getAiNlToolsConfig(sourceConfig map[string]any) map[string]any {
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := map[string]any{
|
||||
"sources": map[string]any{
|
||||
@@ -291,13 +290,13 @@ func getAINLToolsConfig(sourceConfig map[string]any) map[string]any {
|
||||
},
|
||||
"tools": map[string]any{
|
||||
"my-simple-tool": map[string]any{
|
||||
"kind": AlloyDBAINLToolKind,
|
||||
"kind": ALLOYDB_AI_NL_TOOL_KIND,
|
||||
"source": "my-instance",
|
||||
"description": "Simple tool to test end to end functionality.",
|
||||
"nlConfig": "my_nl_config",
|
||||
},
|
||||
"my-auth-tool": map[string]any{
|
||||
"kind": AlloyDBAINLToolKind,
|
||||
"kind": ALLOYDB_AI_NL_TOOL_KIND,
|
||||
"source": "my-instance",
|
||||
"description": "Tool to test authenticated parameters.",
|
||||
"nlConfig": "my_nl_config",
|
||||
@@ -316,7 +315,7 @@ func getAINLToolsConfig(sourceConfig map[string]any) map[string]any {
|
||||
},
|
||||
},
|
||||
"my-auth-required-tool": map[string]any{
|
||||
"kind": AlloyDBAINLToolKind,
|
||||
"kind": ALLOYDB_AI_NL_TOOL_KIND,
|
||||
"source": "my-instance",
|
||||
"description": "Tool to test auth required invocation.",
|
||||
"nlConfig": "my_nl_config",
|
||||
@@ -330,7 +329,7 @@ func getAINLToolsConfig(sourceConfig map[string]any) map[string]any {
|
||||
return toolsFile
|
||||
}
|
||||
|
||||
func runAINLMCPToolCallMethod(t *testing.T) {
|
||||
func runAiNlMCPToolCallMethod(t *testing.T) {
|
||||
sessionId := tests.RunInitialize(t, "2024-11-05")
|
||||
header := map[string]string{}
|
||||
if sessionId != "" {
|
||||
|
||||
@@ -26,66 +26,65 @@ import (
|
||||
|
||||
"cloud.google.com/go/alloydbconn"
|
||||
"github.com/google/uuid"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/tests"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
var (
|
||||
AlloyDBPostgresSourceKind = "alloydb-postgres"
|
||||
AlloyDBPostgresToolKind = "postgres-sql"
|
||||
AlloyDBPostgresProject = os.Getenv("ALLOYDB_POSTGRES_PROJECT")
|
||||
AlloyDBPostgresRegion = os.Getenv("ALLOYDB_POSTGRES_REGION")
|
||||
AlloyDBPostgresCluster = os.Getenv("ALLOYDB_POSTGRES_CLUSTER")
|
||||
AlloyDBPostgresInstance = os.Getenv("ALLOYDB_POSTGRES_INSTANCE")
|
||||
AlloyDBPostgresDatabase = os.Getenv("ALLOYDB_POSTGRES_DATABASE")
|
||||
AlloyDBPostgresUser = os.Getenv("ALLOYDB_POSTGRES_USER")
|
||||
AlloyDBPostgresPass = os.Getenv("ALLOYDB_POSTGRES_PASS")
|
||||
ALLOYDB_POSTGRES_SOURCE_KIND = "alloydb-postgres"
|
||||
ALLOYDB_POSTGRES_TOOL_KIND = "postgres-sql"
|
||||
ALLOYDB_POSTGRES_PROJECT = os.Getenv("ALLOYDB_POSTGRES_PROJECT")
|
||||
ALLOYDB_POSTGRES_REGION = os.Getenv("ALLOYDB_POSTGRES_REGION")
|
||||
ALLOYDB_POSTGRES_CLUSTER = os.Getenv("ALLOYDB_POSTGRES_CLUSTER")
|
||||
ALLOYDB_POSTGRES_INSTANCE = os.Getenv("ALLOYDB_POSTGRES_INSTANCE")
|
||||
ALLOYDB_POSTGRES_DATABASE = os.Getenv("ALLOYDB_POSTGRES_DATABASE")
|
||||
ALLOYDB_POSTGRES_USER = os.Getenv("ALLOYDB_POSTGRES_USER")
|
||||
ALLOYDB_POSTGRES_PASS = os.Getenv("ALLOYDB_POSTGRES_PASS")
|
||||
)
|
||||
|
||||
func getAlloyDBPgVars(t *testing.T) map[string]any {
|
||||
switch "" {
|
||||
case AlloyDBPostgresProject:
|
||||
case ALLOYDB_POSTGRES_PROJECT:
|
||||
t.Fatal("'ALLOYDB_POSTGRES_PROJECT' not set")
|
||||
case AlloyDBPostgresRegion:
|
||||
case ALLOYDB_POSTGRES_REGION:
|
||||
t.Fatal("'ALLOYDB_POSTGRES_REGION' not set")
|
||||
case AlloyDBPostgresCluster:
|
||||
case ALLOYDB_POSTGRES_CLUSTER:
|
||||
t.Fatal("'ALLOYDB_POSTGRES_CLUSTER' not set")
|
||||
case AlloyDBPostgresInstance:
|
||||
case ALLOYDB_POSTGRES_INSTANCE:
|
||||
t.Fatal("'ALLOYDB_POSTGRES_INSTANCE' not set")
|
||||
case AlloyDBPostgresDatabase:
|
||||
case ALLOYDB_POSTGRES_DATABASE:
|
||||
t.Fatal("'ALLOYDB_POSTGRES_DATABASE' not set")
|
||||
case AlloyDBPostgresUser:
|
||||
case ALLOYDB_POSTGRES_USER:
|
||||
t.Fatal("'ALLOYDB_POSTGRES_USER' not set")
|
||||
case AlloyDBPostgresPass:
|
||||
case ALLOYDB_POSTGRES_PASS:
|
||||
t.Fatal("'ALLOYDB_POSTGRES_PASS' not set")
|
||||
}
|
||||
return map[string]any{
|
||||
"kind": AlloyDBPostgresSourceKind,
|
||||
"project": AlloyDBPostgresProject,
|
||||
"cluster": AlloyDBPostgresCluster,
|
||||
"instance": AlloyDBPostgresInstance,
|
||||
"region": AlloyDBPostgresRegion,
|
||||
"database": AlloyDBPostgresDatabase,
|
||||
"user": AlloyDBPostgresUser,
|
||||
"password": AlloyDBPostgresPass,
|
||||
"kind": ALLOYDB_POSTGRES_SOURCE_KIND,
|
||||
"project": ALLOYDB_POSTGRES_PROJECT,
|
||||
"cluster": ALLOYDB_POSTGRES_CLUSTER,
|
||||
"instance": ALLOYDB_POSTGRES_INSTANCE,
|
||||
"region": ALLOYDB_POSTGRES_REGION,
|
||||
"database": ALLOYDB_POSTGRES_DATABASE,
|
||||
"user": ALLOYDB_POSTGRES_USER,
|
||||
"password": ALLOYDB_POSTGRES_PASS,
|
||||
}
|
||||
}
|
||||
|
||||
// Copied over from alloydb_pg.go
|
||||
func getAlloyDBDialOpts(ipType string) ([]alloydbconn.DialOption, error) {
|
||||
switch strings.ToLower(ipType) {
|
||||
func getAlloyDBDialOpts(ip_type string) ([]alloydbconn.DialOption, error) {
|
||||
switch strings.ToLower(ip_type) {
|
||||
case "private":
|
||||
return []alloydbconn.DialOption{alloydbconn.WithPrivateIP()}, nil
|
||||
case "public":
|
||||
return []alloydbconn.DialOption{alloydbconn.WithPublicIP()}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid ipType %s", ipType)
|
||||
return nil, fmt.Errorf("invalid ip_type %s", ip_type)
|
||||
}
|
||||
}
|
||||
|
||||
// Copied over from alloydb_pg.go
|
||||
func initAlloyDBPgConnectionPool(project, region, cluster, instance, ipType, user, pass, dbname string) (*pgxpool.Pool, error) {
|
||||
func initAlloyDBPgConnectionPool(project, region, cluster, instance, ip_type, user, pass, dbname string) (*pgxpool.Pool, error) {
|
||||
// Configure the driver to connect to the database
|
||||
dsn := fmt.Sprintf("user=%s password=%s dbname=%s sslmode=disable", user, pass, dbname)
|
||||
config, err := pgxpool.ParseConfig(dsn)
|
||||
@@ -94,7 +93,7 @@ func initAlloyDBPgConnectionPool(project, region, cluster, instance, ipType, use
|
||||
}
|
||||
|
||||
// Create a new dialer with options
|
||||
dialOpts, err := getAlloyDBDialOpts(ipType)
|
||||
dialOpts, err := getAlloyDBDialOpts(ip_type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -124,7 +123,7 @@ func TestAlloyDBPgToolEndpoints(t *testing.T) {
|
||||
|
||||
var args []string
|
||||
|
||||
pool, err := initAlloyDBPgConnectionPool(AlloyDBPostgresProject, AlloyDBPostgresRegion, AlloyDBPostgresCluster, AlloyDBPostgresInstance, "public", AlloyDBPostgresUser, AlloyDBPostgresPass, AlloyDBPostgresDatabase)
|
||||
pool, err := initAlloyDBPgConnectionPool(ALLOYDB_POSTGRES_PROJECT, ALLOYDB_POSTGRES_REGION, ALLOYDB_POSTGRES_CLUSTER, ALLOYDB_POSTGRES_INSTANCE, "public", ALLOYDB_POSTGRES_USER, ALLOYDB_POSTGRES_PASS, ALLOYDB_POSTGRES_DATABASE)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create AlloyDB connection pool: %s", err)
|
||||
}
|
||||
@@ -135,20 +134,20 @@ func TestAlloyDBPgToolEndpoints(t *testing.T) {
|
||||
tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
|
||||
// set up data for param tool
|
||||
createStatement1, insertStatement1, paramToolStatement1, paramToolStatement2, params1 := tests.GetPostgresSQLParamToolInfo(tableNameParam)
|
||||
teardownTable1 := tests.SetupPostgresSQLTable(t, ctx, pool, createStatement1, insertStatement1, tableNameParam, params1)
|
||||
create_statement1, insert_statement1, tool_statement1, params1 := tests.GetPostgresSQLParamToolInfo(tableNameParam)
|
||||
teardownTable1 := tests.SetupPostgresSQLTable(t, ctx, pool, create_statement1, insert_statement1, tableNameParam, params1)
|
||||
defer teardownTable1(t)
|
||||
|
||||
// set up data for auth tool
|
||||
createStatement2, insertStatement2, authToolStatement, params2 := tests.GetPostgresSQLAuthToolInfo(tableNameAuth)
|
||||
teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createStatement2, insertStatement2, tableNameAuth, params2)
|
||||
create_statement2, insert_statement2, tool_statement2, params2 := tests.GetPostgresSQLAuthToolInfo(tableNameAuth)
|
||||
teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, create_statement2, insert_statement2, tableNameAuth, params2)
|
||||
defer teardownTable2(t)
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, AlloyDBPostgresToolKind, paramToolStatement1, paramToolStatement2, authToolStatement)
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, ALLOYDB_POSTGRES_TOOL_KIND, tool_statement1, tool_statement2)
|
||||
toolsFile = tests.AddPgExecuteSqlConfig(t, toolsFile)
|
||||
tmplSelectCombined, tmplSelectFilterCombined := tests.GetPostgresSQLTmplToolStatement()
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, AlloyDBPostgresToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, ALLOYDB_POSTGRES_TOOL_KIND, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
@@ -158,7 +157,7 @@ func TestAlloyDBPgToolEndpoints(t *testing.T) {
|
||||
|
||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
|
||||
if err != nil {
|
||||
t.Logf("toolbox command logs: \n%s", out)
|
||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||
@@ -167,8 +166,8 @@ func TestAlloyDBPgToolEndpoints(t *testing.T) {
|
||||
tests.RunToolGetTest(t)
|
||||
|
||||
select1Want, failInvocationWant, createTableStatement := tests.GetPostgresWants()
|
||||
invokeParamWant, invokeParamWantNull, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeParamWantNull)
|
||||
invokeParamWant, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant)
|
||||
tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want)
|
||||
tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failInvocationWant)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tests.NewTemplateParameterTestConfig())
|
||||
@@ -194,7 +193,7 @@ func TestAlloyDBPgIpConnection(t *testing.T) {
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
sourceConfig["ipType"] = tc.ipType
|
||||
err := tests.RunSourceConnectionTest(t, sourceConfig, AlloyDBPostgresToolKind)
|
||||
err := tests.RunSourceConnectionTest(t, sourceConfig, ALLOYDB_POSTGRES_TOOL_KIND)
|
||||
if err != nil {
|
||||
t.Fatalf("Connection test failure: %s", err)
|
||||
}
|
||||
@@ -206,35 +205,35 @@ func TestAlloyDBPgIpConnection(t *testing.T) {
|
||||
func TestAlloyDBPgIAMConnection(t *testing.T) {
|
||||
getAlloyDBPgVars(t)
|
||||
// service account email used for IAM should trim the suffix
|
||||
serviceAccountEmail := strings.TrimSuffix(tests.ServiceAccountEmail, ".gserviceaccount.com")
|
||||
serviceAccountEmail := strings.TrimSuffix(tests.SERVICE_ACCOUNT_EMAIL, ".gserviceaccount.com")
|
||||
|
||||
noPassSourceConfig := map[string]any{
|
||||
"kind": AlloyDBPostgresSourceKind,
|
||||
"project": AlloyDBPostgresProject,
|
||||
"cluster": AlloyDBPostgresCluster,
|
||||
"instance": AlloyDBPostgresInstance,
|
||||
"region": AlloyDBPostgresRegion,
|
||||
"database": AlloyDBPostgresDatabase,
|
||||
"kind": ALLOYDB_POSTGRES_SOURCE_KIND,
|
||||
"project": ALLOYDB_POSTGRES_PROJECT,
|
||||
"cluster": ALLOYDB_POSTGRES_CLUSTER,
|
||||
"instance": ALLOYDB_POSTGRES_INSTANCE,
|
||||
"region": ALLOYDB_POSTGRES_REGION,
|
||||
"database": ALLOYDB_POSTGRES_DATABASE,
|
||||
"user": serviceAccountEmail,
|
||||
}
|
||||
|
||||
noUserSourceConfig := map[string]any{
|
||||
"kind": AlloyDBPostgresSourceKind,
|
||||
"project": AlloyDBPostgresProject,
|
||||
"cluster": AlloyDBPostgresCluster,
|
||||
"instance": AlloyDBPostgresInstance,
|
||||
"region": AlloyDBPostgresRegion,
|
||||
"database": AlloyDBPostgresDatabase,
|
||||
"kind": ALLOYDB_POSTGRES_SOURCE_KIND,
|
||||
"project": ALLOYDB_POSTGRES_PROJECT,
|
||||
"cluster": ALLOYDB_POSTGRES_CLUSTER,
|
||||
"instance": ALLOYDB_POSTGRES_INSTANCE,
|
||||
"region": ALLOYDB_POSTGRES_REGION,
|
||||
"database": ALLOYDB_POSTGRES_DATABASE,
|
||||
"password": "random",
|
||||
}
|
||||
|
||||
noUserNoPassSourceConfig := map[string]any{
|
||||
"kind": AlloyDBPostgresSourceKind,
|
||||
"project": AlloyDBPostgresProject,
|
||||
"cluster": AlloyDBPostgresCluster,
|
||||
"instance": AlloyDBPostgresInstance,
|
||||
"region": AlloyDBPostgresRegion,
|
||||
"database": AlloyDBPostgresDatabase,
|
||||
"kind": ALLOYDB_POSTGRES_SOURCE_KIND,
|
||||
"project": ALLOYDB_POSTGRES_PROJECT,
|
||||
"cluster": ALLOYDB_POSTGRES_CLUSTER,
|
||||
"instance": ALLOYDB_POSTGRES_INSTANCE,
|
||||
"region": ALLOYDB_POSTGRES_REGION,
|
||||
"database": ALLOYDB_POSTGRES_DATABASE,
|
||||
}
|
||||
tcs := []struct {
|
||||
name string
|
||||
@@ -259,7 +258,7 @@ func TestAlloyDBPgIAMConnection(t *testing.T) {
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := tests.RunSourceConnectionTest(t, tc.sourceConfig, AlloyDBPostgresToolKind)
|
||||
err := tests.RunSourceConnectionTest(t, tc.sourceConfig, ALLOYDB_POSTGRES_TOOL_KIND)
|
||||
if err != nil {
|
||||
if tc.isErr {
|
||||
return
|
||||
|
||||
@@ -23,7 +23,7 @@ import (
|
||||
"google.golang.org/api/idtoken"
|
||||
)
|
||||
|
||||
var ServiceAccountEmail = os.Getenv("SERVICE_ACCOUNT_EMAIL")
|
||||
var SERVICE_ACCOUNT_EMAIL = os.Getenv("SERVICE_ACCOUNT_EMAIL")
|
||||
var ClientId = os.Getenv("CLIENT_ID")
|
||||
|
||||
// GetGoogleIdToken retrieve and return the Google ID token
|
||||
|
||||
@@ -29,7 +29,6 @@ import (
|
||||
|
||||
bigqueryapi "cloud.google.com/go/bigquery"
|
||||
"github.com/google/uuid"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/tests"
|
||||
"golang.org/x/oauth2/google"
|
||||
"google.golang.org/api/googleapi"
|
||||
@@ -38,20 +37,20 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
BigquerySourceKind = "bigquery"
|
||||
BigqueryToolKind = "bigquery-sql"
|
||||
BigqueryProject = os.Getenv("BIGQUERY_PROJECT")
|
||||
BIGQUERY_SOURCE_KIND = "bigquery"
|
||||
BIGQUERY_TOOL_KIND = "bigquery-sql"
|
||||
BIGQUERY_PROJECT = os.Getenv("BIGQUERY_PROJECT")
|
||||
)
|
||||
|
||||
func getBigQueryVars(t *testing.T) map[string]any {
|
||||
switch "" {
|
||||
case BigqueryProject:
|
||||
case BIGQUERY_PROJECT:
|
||||
t.Fatal("'BIGQUERY_PROJECT' not set")
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"kind": BigquerySourceKind,
|
||||
"project": BigqueryProject,
|
||||
"kind": BIGQUERY_SOURCE_KIND,
|
||||
"project": BIGQUERY_PROJECT,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,7 +76,7 @@ func TestBigQueryToolEndpoints(t *testing.T) {
|
||||
|
||||
var args []string
|
||||
|
||||
client, err := initBigQueryConnection(BigqueryProject)
|
||||
client, err := initBigQueryConnection(BIGQUERY_PROJECT)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create Cloud SQL connection pool: %s", err)
|
||||
}
|
||||
@@ -86,36 +85,36 @@ func TestBigQueryToolEndpoints(t *testing.T) {
|
||||
datasetName := fmt.Sprintf("temp_toolbox_test_%s", strings.ReplaceAll(uuid.New().String(), "-", ""))
|
||||
tableName := fmt.Sprintf("param_table_%s", strings.ReplaceAll(uuid.New().String(), "-", ""))
|
||||
tableNameParam := fmt.Sprintf("`%s.%s.%s`",
|
||||
BigqueryProject,
|
||||
BIGQUERY_PROJECT,
|
||||
datasetName,
|
||||
tableName,
|
||||
)
|
||||
tableNameAuth := fmt.Sprintf("`%s.%s.auth_table_%s`",
|
||||
BigqueryProject,
|
||||
BIGQUERY_PROJECT,
|
||||
datasetName,
|
||||
strings.ReplaceAll(uuid.New().String(), "-", ""),
|
||||
)
|
||||
tableNameTemplateParam := fmt.Sprintf("`%s.%s.template_param_table_%s`",
|
||||
BigqueryProject,
|
||||
BIGQUERY_PROJECT,
|
||||
datasetName,
|
||||
strings.ReplaceAll(uuid.New().String(), "-", ""),
|
||||
)
|
||||
|
||||
// set up data for param tool
|
||||
createStatement1, insertStatement1, paramToolStatement1, paramToolStatement2, params1 := getBigQueryParamToolInfo(tableNameParam)
|
||||
teardownTable1 := setupBigQueryTable(t, ctx, client, createStatement1, insertStatement1, datasetName, tableNameParam, params1)
|
||||
create_statement1, insert_statement1, tool_statement1, params1 := getBigQueryParamToolInfo(tableNameParam)
|
||||
teardownTable1 := setupBigQueryTable(t, ctx, client, create_statement1, insert_statement1, datasetName, tableNameParam, params1)
|
||||
defer teardownTable1(t)
|
||||
|
||||
// set up data for auth tool
|
||||
createStatement2, insertStatement2, authToolStatement, params2 := getBigQueryAuthToolInfo(tableNameAuth)
|
||||
teardownTable2 := setupBigQueryTable(t, ctx, client, createStatement2, insertStatement2, datasetName, tableNameAuth, params2)
|
||||
create_statement2, insert_statement2, tool_statement2, params2 := getBigQueryAuthToolInfo(tableNameAuth)
|
||||
teardownTable2 := setupBigQueryTable(t, ctx, client, create_statement2, insert_statement2, datasetName, tableNameAuth, params2)
|
||||
defer teardownTable2(t)
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, BigqueryToolKind, paramToolStatement1, paramToolStatement2, authToolStatement)
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, BIGQUERY_TOOL_KIND, tool_statement1, tool_statement2)
|
||||
toolsFile = addBigQueryPrebuiltToolsConfig(t, toolsFile)
|
||||
tmplSelectCombined, tmplSelectFilterCombined := getBigQueryTmplToolStatement()
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, BigqueryToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, BIGQUERY_TOOL_KIND, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
@@ -125,7 +124,7 @@ func TestBigQueryToolEndpoints(t *testing.T) {
|
||||
|
||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
|
||||
if err != nil {
|
||||
t.Logf("toolbox command logs: \n%s", out)
|
||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||
@@ -138,8 +137,8 @@ func TestBigQueryToolEndpoints(t *testing.T) {
|
||||
failInvocationWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: googleapi: Error 400: Syntax error: Unexpected identifier \"SELEC\" at [1:1]`
|
||||
datasetInfoWant := "\"Location\":\"US\",\"DefaultTableExpiration\":0,\"Labels\":null,\"Access\":"
|
||||
tableInfoWant := "[{\"Name\":\"\",\"Location\":\"US\",\"Description\":\"\",\"Schema\":[{\"Name\":\"id\""
|
||||
invokeParamWant, invokeParamWantNull, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeParamWantNull)
|
||||
invokeParamWant, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant)
|
||||
tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failInvocationWant)
|
||||
templateParamTestConfig := tests.NewTemplateParameterTestConfig(
|
||||
tests.WithCreateColArray(`["id INT64", "name STRING", "age INT64"]`),
|
||||
@@ -154,20 +153,18 @@ func TestBigQueryToolEndpoints(t *testing.T) {
|
||||
}
|
||||
|
||||
// getBigQueryParamToolInfo returns statements and param for my-param-tool for bigquery kind
|
||||
func getBigQueryParamToolInfo(tableName string) (string, string, string, string, []bigqueryapi.QueryParameter) {
|
||||
func getBigQueryParamToolInfo(tableName string) (string, string, string, []bigqueryapi.QueryParameter) {
|
||||
createStatement := fmt.Sprintf(`
|
||||
CREATE TABLE IF NOT EXISTS %s (id INT64, name STRING);`, tableName)
|
||||
insertStatement := fmt.Sprintf(`
|
||||
INSERT INTO %s (id, name) VALUES (?, ?), (?, ?), (?, ?), (?, NULL);`, tableName)
|
||||
INSERT INTO %s (id, name) VALUES (?, ?), (?, ?), (?, ?);`, tableName)
|
||||
toolStatement := fmt.Sprintf(`SELECT * FROM %s WHERE id = ? OR name = ? ORDER BY id;`, tableName)
|
||||
toolStatement2 := fmt.Sprintf(`SELECT * FROM %s WHERE id = ? ORDER BY id;`, tableName)
|
||||
params := []bigqueryapi.QueryParameter{
|
||||
{Value: int64(1)}, {Value: "Alice"},
|
||||
{Value: int64(2)}, {Value: "Jane"},
|
||||
{Value: int64(3)}, {Value: "Sid"},
|
||||
{Value: int64(4)},
|
||||
}
|
||||
return createStatement, insertStatement, toolStatement, toolStatement2, params
|
||||
return createStatement, insertStatement, toolStatement, params
|
||||
}
|
||||
|
||||
// getBigQueryAuthToolInfo returns statements and param of my-auth-tool for bigquery kind
|
||||
@@ -179,7 +176,7 @@ func getBigQueryAuthToolInfo(tableName string) (string, string, string, []bigque
|
||||
toolStatement := fmt.Sprintf(`
|
||||
SELECT name FROM %s WHERE email = ?`, tableName)
|
||||
params := []bigqueryapi.QueryParameter{
|
||||
{Value: int64(1)}, {Value: "Alice"}, {Value: tests.ServiceAccountEmail},
|
||||
{Value: int64(1)}, {Value: "Alice"}, {Value: tests.SERVICE_ACCOUNT_EMAIL},
|
||||
{Value: int64(2)}, {Value: "Jane"}, {Value: "janedoe@gmail.com"},
|
||||
}
|
||||
return createStatement, insertStatement, toolStatement, params
|
||||
@@ -192,7 +189,7 @@ func getBigQueryTmplToolStatement() (string, string) {
|
||||
return tmplSelectCombined, tmplSelectFilterCombined
|
||||
}
|
||||
|
||||
func setupBigQueryTable(t *testing.T, ctx context.Context, client *bigqueryapi.Client, createStatement, insertStatement, datasetName string, tableName string, params []bigqueryapi.QueryParameter) func(*testing.T) {
|
||||
func setupBigQueryTable(t *testing.T, ctx context.Context, client *bigqueryapi.Client, create_statement, insert_statement, datasetName string, tableName string, params []bigqueryapi.QueryParameter) func(*testing.T) {
|
||||
// Create dataset
|
||||
dataset := client.Dataset(datasetName)
|
||||
_, err := dataset.Metadata(ctx)
|
||||
@@ -209,7 +206,7 @@ func setupBigQueryTable(t *testing.T, ctx context.Context, client *bigqueryapi.C
|
||||
}
|
||||
|
||||
// Create table
|
||||
createJob, err := client.Query(createStatement).Run(ctx)
|
||||
createJob, err := client.Query(create_statement).Run(ctx)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to start create table job for %s: %v", tableName, err)
|
||||
@@ -223,7 +220,7 @@ func setupBigQueryTable(t *testing.T, ctx context.Context, client *bigqueryapi.C
|
||||
}
|
||||
|
||||
// Insert test data
|
||||
insertQuery := client.Query(insertStatement)
|
||||
insertQuery := client.Query(insert_statement)
|
||||
insertQuery.Parameters = params
|
||||
insertJob, err := insertQuery.Run(ctx)
|
||||
if err != nil {
|
||||
@@ -343,7 +340,7 @@ func addBigQueryPrebuiltToolsConfig(t *testing.T, config map[string]any) map[str
|
||||
return config
|
||||
}
|
||||
|
||||
func runBigQueryExecuteSqlToolInvokeTest(t *testing.T, select1Want, invokeParamWant, tableNameParam string) {
|
||||
func runBigQueryExecuteSqlToolInvokeTest(t *testing.T, select_1_want, invokeParamWant, tableNameParam string) {
|
||||
// Get ID token
|
||||
idToken, err := tests.GetGoogleIdToken(tests.ClientId)
|
||||
if err != nil {
|
||||
@@ -371,7 +368,7 @@ func runBigQueryExecuteSqlToolInvokeTest(t *testing.T, select1Want, invokeParamW
|
||||
api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"sql":"SELECT 1"}`)),
|
||||
want: select1Want,
|
||||
want: select_1_want,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
@@ -417,7 +414,7 @@ func runBigQueryExecuteSqlToolInvokeTest(t *testing.T, select1Want, invokeParamW
|
||||
requestHeader: map[string]string{"my-google-auth_token": idToken},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"sql":"SELECT 1"}`)),
|
||||
isErr: false,
|
||||
want: select1Want,
|
||||
want: select_1_want,
|
||||
},
|
||||
{
|
||||
name: "Invoke my-auth-exec-sql-tool with invalid auth token",
|
||||
@@ -502,21 +499,6 @@ func runBigQueryListDatasetToolInvokeTest(t *testing.T, datasetWant string) {
|
||||
isErr: false,
|
||||
want: datasetWant,
|
||||
},
|
||||
{
|
||||
name: "invoke my-list-dataset-ids-tool with project",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-list-dataset-ids-tool/invoke",
|
||||
requestHeader: map[string]string{"my-google-auth_token": idToken},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"project\":\"%s\"}", BigqueryProject))),
|
||||
isErr: false,
|
||||
want: datasetWant,
|
||||
},
|
||||
{
|
||||
name: "invoke my-list-dataset-ids-tool with non-existent project",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-list-dataset-ids-tool/invoke",
|
||||
requestHeader: map[string]string{"my-google-auth_token": idToken},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"project\":\"%s-%s\"}", BigqueryProject, uuid.NewString()))),
|
||||
isErr: true,
|
||||
},
|
||||
{
|
||||
name: "invoke my-auth-list-dataset-ids-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-list-dataset-ids-tool/invoke",
|
||||
@@ -601,21 +583,6 @@ func runBigQueryGetDatasetInfoToolInvokeTest(t *testing.T, datasetName, datasetI
|
||||
want: datasetInfoWant,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invoke my-auth-get-dataset-info-tool with correct project",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-get-dataset-info-tool/invoke",
|
||||
requestHeader: map[string]string{"my-google-auth_token": idToken},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"project\":\"%s\", \"dataset\":\"%s\"}", BigqueryProject, datasetName))),
|
||||
want: datasetInfoWant,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invoke my-auth-get-dataset-info-tool with non-existent project",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-get-dataset-info-tool/invoke",
|
||||
requestHeader: map[string]string{"my-google-auth_token": idToken},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"project\":\"%s-%s\", \"dataset\":\"%s\"}", BigqueryProject, uuid.NewString(), datasetName))),
|
||||
isErr: true,
|
||||
},
|
||||
{
|
||||
name: "invoke my-auth-get-dataset-info-tool without body",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-get-dataset-info-tool/invoke",
|
||||
@@ -736,21 +703,6 @@ func runBigQueryListTableIdsToolInvokeTest(t *testing.T, datasetName, tablename_
|
||||
want: tablename_want,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invoke my-auth-list-table-ids-tool with correct project",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-list-table-ids-tool/invoke",
|
||||
requestHeader: map[string]string{"my-google-auth_token": idToken},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"project\":\"%s\", \"dataset\":\"%s\"}", BigqueryProject, datasetName))),
|
||||
want: tablename_want,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invoke my-auth-list-table-ids-tool with non-existent project",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-list-table-ids-tool/invoke",
|
||||
requestHeader: map[string]string{"my-google-auth_token": idToken},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"project\":\"%s-%s\", \"dataset\":\"%s\"}", BigqueryProject, uuid.NewString(), datasetName))),
|
||||
isErr: true,
|
||||
},
|
||||
{
|
||||
name: "Invoke my-auth-list-table-ids-tool with invalid auth token",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-list-table-ids-tool/invoke",
|
||||
@@ -856,21 +808,6 @@ func runBigQueryGetTableInfoToolInvokeTest(t *testing.T, datasetName, tableName,
|
||||
want: tableInfoWant,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invoke my-auth-get-table-info-tool with correct project",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-get-table-info-tool/invoke",
|
||||
requestHeader: map[string]string{"my-google-auth_token": idToken},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"project\":\"%s\", \"dataset\":\"%s\", \"table\":\"%s\"}", BigqueryProject, datasetName, tableName))),
|
||||
want: tableInfoWant,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invoke my-auth-get-table-info-tool with non-existent project",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-get-table-info-tool/invoke",
|
||||
requestHeader: map[string]string{"my-google-auth_token": idToken},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"project\":\"%s-%s\", \"dataset\":\"%s\", \"table\":\"%s\"}", BigqueryProject, uuid.NewString(), datasetName, tableName))),
|
||||
isErr: true,
|
||||
},
|
||||
{
|
||||
name: "Invoke my-auth-get-table-info-tool with invalid auth token",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-get-table-info-tool/invoke",
|
||||
|
||||
@@ -29,30 +29,29 @@ import (
|
||||
|
||||
"cloud.google.com/go/bigtable"
|
||||
"github.com/google/uuid"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/tests"
|
||||
)
|
||||
|
||||
var (
|
||||
BigtableSourceKind = "bigtable"
|
||||
BigtableToolKind = "bigtable-sql"
|
||||
BigtableProject = os.Getenv("BIGTABLE_PROJECT")
|
||||
BigtableInstance = os.Getenv("BIGTABLE_INSTANCE")
|
||||
BIGTABLE_SOURCE_KIND = "bigtable"
|
||||
BIGTABLE_TOOL_KIND = "bigtable-sql"
|
||||
BIGTABLE_PROJECT = os.Getenv("BIGTABLE_PROJECT")
|
||||
BIGTABLE_INSTANCE = os.Getenv("BIGTABLE_INSTANCE")
|
||||
)
|
||||
|
||||
func getBigtableVars(t *testing.T) map[string]any {
|
||||
switch "" {
|
||||
case BigtableProject:
|
||||
case BIGTABLE_PROJECT:
|
||||
t.Fatal("'BIGTABLE_PROJECT' not set")
|
||||
case BigtableInstance:
|
||||
case BIGTABLE_INSTANCE:
|
||||
t.Fatal("'BIGTABLE_INSTANCE' not set")
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"kind": BigtableSourceKind,
|
||||
"project": BigtableProject,
|
||||
"instance": BigtableInstance,
|
||||
"kind": BIGTABLE_SOURCE_KIND,
|
||||
"project": BIGTABLE_PROJECT,
|
||||
"instance": BIGTABLE_INSTANCE,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -78,14 +77,13 @@ func TestBigtableToolEndpoints(t *testing.T) {
|
||||
|
||||
// Do not change the shape of statement without checking tests/common_test.go.
|
||||
// The structure and value of seed data has to match https://github.com/googleapis/genai-toolbox/blob/4dba0df12dc438eca3cb476ef52aa17cdf232c12/tests/common_test.go#L200-L251
|
||||
paramTestStatement := fmt.Sprintf("SELECT TO_INT64(cf['id']) as id, CAST(cf['name'] AS string) as name, FROM %s WHERE TO_INT64(cf['id']) = @id OR CAST(cf['name'] AS string) = @name;", tableName)
|
||||
paramTestStatement2 := fmt.Sprintf("SELECT TO_INT64(cf['id']) as id, CAST(cf['name'] AS string) as name, FROM %s WHERE TO_INT64(cf['id']) = @id;", tableName)
|
||||
param_test_statement := fmt.Sprintf("SELECT TO_INT64(cf['id']) as id, CAST(cf['name'] AS string) as name, FROM %s WHERE TO_INT64(cf['id']) = @id OR CAST(cf['name'] AS string) = @name;", tableName)
|
||||
teardownTable1 := setupBtTable(t, ctx, sourceConfig["project"].(string), sourceConfig["instance"].(string), tableName, columnFamilyName, muts, rowKeys)
|
||||
defer teardownTable1(t)
|
||||
|
||||
// Do not change the shape of statement without checking tests/common_test.go.
|
||||
// The structure and value of seed data has to match https://github.com/googleapis/genai-toolbox/blob/4dba0df12dc438eca3cb476ef52aa17cdf232c12/tests/common_test.go#L200-L251
|
||||
authToolStatement := fmt.Sprintf("SELECT CAST(cf['name'] AS string) as name FROM %s WHERE CAST(cf['email'] AS string) = @email;", tableNameAuth)
|
||||
auth_tool_statement := fmt.Sprintf("SELECT CAST(cf['name'] AS string) as name FROM %s WHERE CAST(cf['email'] AS string) = @email;", tableNameAuth)
|
||||
teardownTable2 := setupBtTable(t, ctx, sourceConfig["project"].(string), sourceConfig["instance"].(string), tableNameAuth, columnFamilyName, muts, rowKeys)
|
||||
defer teardownTable2(t)
|
||||
|
||||
@@ -94,7 +92,7 @@ func TestBigtableToolEndpoints(t *testing.T) {
|
||||
defer teardownTableTmpl(t)
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, BigtableToolKind, paramTestStatement, paramTestStatement2, authToolStatement)
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, BIGTABLE_TOOL_KIND, param_test_statement, auth_tool_statement)
|
||||
toolsFile = addTemplateParamConfig(t, toolsFile)
|
||||
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
@@ -105,7 +103,7 @@ func TestBigtableToolEndpoints(t *testing.T) {
|
||||
|
||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
|
||||
if err != nil {
|
||||
t.Logf("toolbox command logs: \n%s", out)
|
||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||
@@ -116,9 +114,8 @@ func TestBigtableToolEndpoints(t *testing.T) {
|
||||
// Actual test parameters are set in https://github.com/googleapis/genai-toolbox/blob/52b09a67cb40ac0c5f461598b4673136699a3089/tests/tool_test.go#L250
|
||||
select1Want := "[{\"$col1\":1}]"
|
||||
failInvocationWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to prepare statement: rpc error: code = InvalidArgument desc = Syntax error: Unexpected identifier \"SELEC\" [at 1:1]"}],"isError":true}}`
|
||||
invokeParamWant, _, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
invokeParamWantNull := `[{"id":4,"name":""}]`
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeParamWantNull)
|
||||
invokeParamWant, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant)
|
||||
tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failInvocationWant)
|
||||
|
||||
templateParamTestConfig := tests.NewTemplateParameterTestConfig(
|
||||
@@ -142,7 +139,7 @@ func getTestData(columnFamilyName string) ([]*bigtable.Mutation, []string) {
|
||||
muts := []*bigtable.Mutation{}
|
||||
rowKeys := []string{}
|
||||
|
||||
var ids [4][]byte
|
||||
var ids [3][]byte
|
||||
for i := range ids {
|
||||
ids[i] = convertToBytes(i + 1)
|
||||
}
|
||||
@@ -154,7 +151,7 @@ func getTestData(columnFamilyName string) ([]*bigtable.Mutation, []string) {
|
||||
// Expected values are defined in https://github.com/googleapis/genai-toolbox/blob/52b09a67cb40ac0c5f461598b4673136699a3089/tests/tool_test.go#L229-L310
|
||||
"row-01": {
|
||||
"name": []byte("Alice"),
|
||||
"email": []byte(tests.ServiceAccountEmail),
|
||||
"email": []byte(tests.SERVICE_ACCOUNT_EMAIL),
|
||||
"id": ids[0],
|
||||
},
|
||||
"row-02": {
|
||||
@@ -166,10 +163,6 @@ func getTestData(columnFamilyName string) ([]*bigtable.Mutation, []string) {
|
||||
"name": []byte("Sid"),
|
||||
"id": ids[2],
|
||||
},
|
||||
"row-04": {
|
||||
"name": nil,
|
||||
"id": ids[3],
|
||||
},
|
||||
} {
|
||||
mut := bigtable.NewMutation()
|
||||
for col, v := range mutData {
|
||||
|
||||
@@ -29,54 +29,53 @@ import (
|
||||
"cloud.google.com/go/cloudsqlconn"
|
||||
"cloud.google.com/go/cloudsqlconn/sqlserver/mssql"
|
||||
"github.com/google/uuid"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/tests"
|
||||
)
|
||||
|
||||
var (
|
||||
CloudSQLMSSQLSourceKind = "cloud-sql-mssql"
|
||||
CloudSQLMSSQLToolKind = "mssql-sql"
|
||||
CloudSQLMSSQLProject = os.Getenv("CLOUD_SQL_MSSQL_PROJECT")
|
||||
CloudSQLMSSQLRegion = os.Getenv("CLOUD_SQL_MSSQL_REGION")
|
||||
CloudSQLMSSQLInstance = os.Getenv("CLOUD_SQL_MSSQL_INSTANCE")
|
||||
CloudSQLMSSQLDatabase = os.Getenv("CLOUD_SQL_MSSQL_DATABASE")
|
||||
CloudSQLMSSQLIp = os.Getenv("CLOUD_SQL_MSSQL_IP")
|
||||
CloudSQLMSSQLUser = os.Getenv("CLOUD_SQL_MSSQL_USER")
|
||||
CloudSQLMSSQLPass = os.Getenv("CLOUD_SQL_MSSQL_PASS")
|
||||
CLOUD_SQL_MSSQL_SOURCE_KIND = "cloud-sql-mssql"
|
||||
CLOUD_SQL_MSSQL_TOOL_KIND = "mssql-sql"
|
||||
CLOUD_SQL_MSSQL_PROJECT = os.Getenv("CLOUD_SQL_MSSQL_PROJECT")
|
||||
CLOUD_SQL_MSSQL_REGION = os.Getenv("CLOUD_SQL_MSSQL_REGION")
|
||||
CLOUD_SQL_MSSQL_INSTANCE = os.Getenv("CLOUD_SQL_MSSQL_INSTANCE")
|
||||
CLOUD_SQL_MSSQL_DATABASE = os.Getenv("CLOUD_SQL_MSSQL_DATABASE")
|
||||
CLOUD_SQL_MSSQL_IP = os.Getenv("CLOUD_SQL_MSSQL_IP")
|
||||
CLOUD_SQL_MSSQL_USER = os.Getenv("CLOUD_SQL_MSSQL_USER")
|
||||
CLOUD_SQL_MSSQL_PASS = os.Getenv("CLOUD_SQL_MSSQL_PASS")
|
||||
)
|
||||
|
||||
func getCloudSQLMSSQLVars(t *testing.T) map[string]any {
|
||||
func getCloudSQLMssqlVars(t *testing.T) map[string]any {
|
||||
switch "" {
|
||||
case CloudSQLMSSQLProject:
|
||||
case CLOUD_SQL_MSSQL_PROJECT:
|
||||
t.Fatal("'CLOUD_SQL_MSSQL_PROJECT' not set")
|
||||
case CloudSQLMSSQLRegion:
|
||||
case CLOUD_SQL_MSSQL_REGION:
|
||||
t.Fatal("'CLOUD_SQL_MSSQL_REGION' not set")
|
||||
case CloudSQLMSSQLInstance:
|
||||
case CLOUD_SQL_MSSQL_INSTANCE:
|
||||
t.Fatal("'CLOUD_SQL_MSSQL_INSTANCE' not set")
|
||||
case CloudSQLMSSQLIp:
|
||||
case CLOUD_SQL_MSSQL_IP:
|
||||
t.Fatal("'CLOUD_SQL_MSSQL_IP' not set")
|
||||
case CloudSQLMSSQLDatabase:
|
||||
case CLOUD_SQL_MSSQL_DATABASE:
|
||||
t.Fatal("'CLOUD_SQL_MSSQL_DATABASE' not set")
|
||||
case CloudSQLMSSQLUser:
|
||||
case CLOUD_SQL_MSSQL_USER:
|
||||
t.Fatal("'CLOUD_SQL_MSSQL_USER' not set")
|
||||
case CloudSQLMSSQLPass:
|
||||
case CLOUD_SQL_MSSQL_PASS:
|
||||
t.Fatal("'CLOUD_SQL_MSSQL_PASS' not set")
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"kind": CloudSQLMSSQLSourceKind,
|
||||
"project": CloudSQLMSSQLProject,
|
||||
"instance": CloudSQLMSSQLInstance,
|
||||
"ipAddress": CloudSQLMSSQLIp,
|
||||
"region": CloudSQLMSSQLRegion,
|
||||
"database": CloudSQLMSSQLDatabase,
|
||||
"user": CloudSQLMSSQLUser,
|
||||
"password": CloudSQLMSSQLPass,
|
||||
"kind": CLOUD_SQL_MSSQL_SOURCE_KIND,
|
||||
"project": CLOUD_SQL_MSSQL_PROJECT,
|
||||
"instance": CLOUD_SQL_MSSQL_INSTANCE,
|
||||
"ipAddress": CLOUD_SQL_MSSQL_IP,
|
||||
"region": CLOUD_SQL_MSSQL_REGION,
|
||||
"database": CLOUD_SQL_MSSQL_DATABASE,
|
||||
"user": CLOUD_SQL_MSSQL_USER,
|
||||
"password": CLOUD_SQL_MSSQL_PASS,
|
||||
}
|
||||
}
|
||||
|
||||
// Copied over from cloud_sql_mssql.go
|
||||
func initCloudSQLMSSQLConnection(project, region, instance, ipAddress, ipType, user, pass, dbname string) (*sql.DB, error) {
|
||||
func initCloudSQLMssqlConnection(project, region, instance, ipAddress, ipType, user, pass, dbname string) (*sql.DB, error) {
|
||||
// Create dsn
|
||||
query := fmt.Sprintf("database=%s&cloudsql=%s:%s:%s", dbname, project, region, instance)
|
||||
url := &url.URL{
|
||||
@@ -111,14 +110,14 @@ func initCloudSQLMSSQLConnection(project, region, instance, ipAddress, ipType, u
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func TestCloudSQLMSSQLToolEndpoints(t *testing.T) {
|
||||
sourceConfig := getCloudSQLMSSQLVars(t)
|
||||
func TestCloudSQLMssqlToolEndpoints(t *testing.T) {
|
||||
sourceConfig := getCloudSQLMssqlVars(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
|
||||
var args []string
|
||||
|
||||
db, err := initCloudSQLMSSQLConnection(CloudSQLMSSQLProject, CloudSQLMSSQLRegion, CloudSQLMSSQLInstance, CloudSQLMSSQLIp, "public", CloudSQLMSSQLUser, CloudSQLMSSQLPass, CloudSQLMSSQLDatabase)
|
||||
db, err := initCloudSQLMssqlConnection(CLOUD_SQL_MSSQL_PROJECT, CLOUD_SQL_MSSQL_REGION, CLOUD_SQL_MSSQL_INSTANCE, CLOUD_SQL_MSSQL_IP, "public", CLOUD_SQL_MSSQL_USER, CLOUD_SQL_MSSQL_PASS, CLOUD_SQL_MSSQL_DATABASE)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create Cloud SQL connection pool: %s", err)
|
||||
}
|
||||
@@ -129,20 +128,20 @@ func TestCloudSQLMSSQLToolEndpoints(t *testing.T) {
|
||||
tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
|
||||
// set up data for param tool
|
||||
createStatement1, insertStatement1, paramToolStatement1, paramToolStatement2, params1 := tests.GetMSSQLParamToolInfo(tableNameParam)
|
||||
teardownTable1 := tests.SetupMsSQLTable(t, ctx, db, createStatement1, insertStatement1, tableNameParam, params1)
|
||||
create_statement1, insert_statement1, tool_statement1, params1 := tests.GetMssqlParamToolInfo(tableNameParam)
|
||||
teardownTable1 := tests.SetupMsSQLTable(t, ctx, db, create_statement1, insert_statement1, tableNameParam, params1)
|
||||
defer teardownTable1(t)
|
||||
|
||||
// set up data for auth tool
|
||||
createStatement2, insertStatement2, authToolStatement, params2 := tests.GetMSSQLAuthToolInfo(tableNameAuth)
|
||||
teardownTable2 := tests.SetupMsSQLTable(t, ctx, db, createStatement2, insertStatement2, tableNameAuth, params2)
|
||||
create_statement2, insert_statement2, tool_statement2, params2 := tests.GetMssqlAuthToolInfo(tableNameAuth)
|
||||
teardownTable2 := tests.SetupMsSQLTable(t, ctx, db, create_statement2, insert_statement2, tableNameAuth, params2)
|
||||
defer teardownTable2(t)
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, CloudSQLMSSQLToolKind, paramToolStatement1, paramToolStatement2, authToolStatement)
|
||||
toolsFile = tests.AddMSSQLExecuteSqlConfig(t, toolsFile)
|
||||
tmplSelectCombined, tmplSelectFilterCombined := tests.GetMSSQLTmplToolStatement()
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, CloudSQLMSSQLToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, CLOUD_SQL_MSSQL_TOOL_KIND, tool_statement1, tool_statement2)
|
||||
toolsFile = tests.AddMssqlExecuteSqlConfig(t, toolsFile)
|
||||
tmplSelectCombined, tmplSelectFilterCombined := tests.GetMssqlTmplToolStatement()
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, CLOUD_SQL_MSSQL_TOOL_KIND, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
@@ -152,7 +151,7 @@ func TestCloudSQLMSSQLToolEndpoints(t *testing.T) {
|
||||
|
||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
|
||||
if err != nil {
|
||||
t.Logf("toolbox command logs: \n%s", out)
|
||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||
@@ -160,17 +159,17 @@ func TestCloudSQLMSSQLToolEndpoints(t *testing.T) {
|
||||
|
||||
tests.RunToolGetTest(t)
|
||||
|
||||
select1Want, failInvocationWant, createTableStatement := tests.GetMSSQLWants()
|
||||
invokeParamWant, invokeParamWantNull, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeParamWantNull)
|
||||
select1Want, failInvocationWant, createTableStatement := tests.GetMssqlWants()
|
||||
invokeParamWant, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant)
|
||||
tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want)
|
||||
tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failInvocationWant)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tests.NewTemplateParameterTestConfig())
|
||||
}
|
||||
|
||||
// Test connection with different IP type
|
||||
func TestCloudSQLMSSQLIpConnection(t *testing.T) {
|
||||
sourceConfig := getCloudSQLMSSQLVars(t)
|
||||
func TestCloudSQLMssqlIpConnection(t *testing.T) {
|
||||
sourceConfig := getCloudSQLMssqlVars(t)
|
||||
|
||||
tcs := []struct {
|
||||
name string
|
||||
@@ -188,7 +187,7 @@ func TestCloudSQLMSSQLIpConnection(t *testing.T) {
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
sourceConfig["ipType"] = tc.ipType
|
||||
err := tests.RunSourceConnectionTest(t, sourceConfig, CloudSQLMSSQLToolKind)
|
||||
err := tests.RunSourceConnectionTest(t, sourceConfig, CLOUD_SQL_MSSQL_TOOL_KIND)
|
||||
if err != nil {
|
||||
t.Fatalf("Connection test failure: %s", err)
|
||||
}
|
||||
|
||||
@@ -28,45 +28,44 @@ import (
|
||||
"cloud.google.com/go/cloudsqlconn"
|
||||
"cloud.google.com/go/cloudsqlconn/mysql/mysql"
|
||||
"github.com/google/uuid"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/tests"
|
||||
)
|
||||
|
||||
var (
|
||||
CloudSQLMySQLSourceKind = "cloud-sql-mysql"
|
||||
CloudSQLMySQLToolKind = "mysql-sql"
|
||||
CloudSQLMySQLProject = os.Getenv("CLOUD_SQL_MYSQL_PROJECT")
|
||||
CloudSQLMySQLRegion = os.Getenv("CLOUD_SQL_MYSQL_REGION")
|
||||
CloudSQLMySQLInstance = os.Getenv("CLOUD_SQL_MYSQL_INSTANCE")
|
||||
CloudSQLMySQLDatabase = os.Getenv("CLOUD_SQL_MYSQL_DATABASE")
|
||||
CloudSQLMySQLUser = os.Getenv("CLOUD_SQL_MYSQL_USER")
|
||||
CloudSQLMySQLPass = os.Getenv("CLOUD_SQL_MYSQL_PASS")
|
||||
CLOUD_SQL_MYSQL_SOURCE_KIND = "cloud-sql-mysql"
|
||||
CLOUD_SQL_MYSQL_TOOL_KIND = "mysql-sql"
|
||||
CLOUD_SQL_MYSQL_PROJECT = os.Getenv("CLOUD_SQL_MYSQL_PROJECT")
|
||||
CLOUD_SQL_MYSQL_REGION = os.Getenv("CLOUD_SQL_MYSQL_REGION")
|
||||
CLOUD_SQL_MYSQL_INSTANCE = os.Getenv("CLOUD_SQL_MYSQL_INSTANCE")
|
||||
CLOUD_SQL_MYSQL_DATABASE = os.Getenv("CLOUD_SQL_MYSQL_DATABASE")
|
||||
CLOUD_SQL_MYSQL_USER = os.Getenv("CLOUD_SQL_MYSQL_USER")
|
||||
CLOUD_SQL_MYSQL_PASS = os.Getenv("CLOUD_SQL_MYSQL_PASS")
|
||||
)
|
||||
|
||||
func getCloudSQLMySQLVars(t *testing.T) map[string]any {
|
||||
switch "" {
|
||||
case CloudSQLMySQLProject:
|
||||
case CLOUD_SQL_MYSQL_PROJECT:
|
||||
t.Fatal("'CLOUD_SQL_MYSQL_PROJECT' not set")
|
||||
case CloudSQLMySQLRegion:
|
||||
case CLOUD_SQL_MYSQL_REGION:
|
||||
t.Fatal("'CLOUD_SQL_MYSQL_REGION' not set")
|
||||
case CloudSQLMySQLInstance:
|
||||
case CLOUD_SQL_MYSQL_INSTANCE:
|
||||
t.Fatal("'CLOUD_SQL_MYSQL_INSTANCE' not set")
|
||||
case CloudSQLMySQLDatabase:
|
||||
case CLOUD_SQL_MYSQL_DATABASE:
|
||||
t.Fatal("'CLOUD_SQL_MYSQL_DATABASE' not set")
|
||||
case CloudSQLMySQLUser:
|
||||
case CLOUD_SQL_MYSQL_USER:
|
||||
t.Fatal("'CLOUD_SQL_MYSQL_USER' not set")
|
||||
case CloudSQLMySQLPass:
|
||||
case CLOUD_SQL_MYSQL_PASS:
|
||||
t.Fatal("'CLOUD_SQL_MYSQL_PASS' not set")
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"kind": CloudSQLMySQLSourceKind,
|
||||
"project": CloudSQLMySQLProject,
|
||||
"instance": CloudSQLMySQLInstance,
|
||||
"region": CloudSQLMySQLRegion,
|
||||
"database": CloudSQLMySQLDatabase,
|
||||
"user": CloudSQLMySQLUser,
|
||||
"password": CloudSQLMySQLPass,
|
||||
"kind": CLOUD_SQL_MYSQL_SOURCE_KIND,
|
||||
"project": CLOUD_SQL_MYSQL_PROJECT,
|
||||
"instance": CLOUD_SQL_MYSQL_INSTANCE,
|
||||
"region": CLOUD_SQL_MYSQL_REGION,
|
||||
"database": CLOUD_SQL_MYSQL_DATABASE,
|
||||
"user": CLOUD_SQL_MYSQL_USER,
|
||||
"password": CLOUD_SQL_MYSQL_PASS,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -98,14 +97,14 @@ func initCloudSQLMySQLConnectionPool(project, region, instance, ipType, user, pa
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func TestCloudSQLMySQLToolEndpoints(t *testing.T) {
|
||||
func TestCloudSQLMysqlToolEndpoints(t *testing.T) {
|
||||
sourceConfig := getCloudSQLMySQLVars(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
|
||||
var args []string
|
||||
|
||||
pool, err := initCloudSQLMySQLConnectionPool(CloudSQLMySQLProject, CloudSQLMySQLRegion, CloudSQLMySQLInstance, "public", CloudSQLMySQLUser, CloudSQLMySQLPass, CloudSQLMySQLDatabase)
|
||||
pool, err := initCloudSQLMySQLConnectionPool(CLOUD_SQL_MYSQL_PROJECT, CLOUD_SQL_MYSQL_REGION, CLOUD_SQL_MYSQL_INSTANCE, "public", CLOUD_SQL_MYSQL_USER, CLOUD_SQL_MYSQL_PASS, CLOUD_SQL_MYSQL_DATABASE)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create Cloud SQL connection pool: %s", err)
|
||||
}
|
||||
@@ -116,20 +115,20 @@ func TestCloudSQLMySQLToolEndpoints(t *testing.T) {
|
||||
tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
|
||||
// set up data for param tool
|
||||
createStatement1, insertStatement1, paramToolStatement1, paramToolStatement2, params1 := tests.GetMySQLParamToolInfo(tableNameParam)
|
||||
teardownTable1 := tests.SetupMySQLTable(t, ctx, pool, createStatement1, insertStatement1, tableNameParam, params1)
|
||||
create_statement1, insert_statement1, tool_statement1, params1 := tests.GetMysqlParamToolInfo(tableNameParam)
|
||||
teardownTable1 := tests.SetupMySQLTable(t, ctx, pool, create_statement1, insert_statement1, tableNameParam, params1)
|
||||
defer teardownTable1(t)
|
||||
|
||||
// set up data for auth tool
|
||||
createStatement2, insertStatement2, authToolStatement, params2 := tests.GetMySQLAuthToolInfo(tableNameAuth)
|
||||
teardownTable2 := tests.SetupMySQLTable(t, ctx, pool, createStatement2, insertStatement2, tableNameAuth, params2)
|
||||
create_statement2, insert_statement2, tool_statement2, params2 := tests.GetMysqlAuthToolInfo(tableNameAuth)
|
||||
teardownTable2 := tests.SetupMySQLTable(t, ctx, pool, create_statement2, insert_statement2, tableNameAuth, params2)
|
||||
defer teardownTable2(t)
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, CloudSQLMySQLToolKind, paramToolStatement1, paramToolStatement2, authToolStatement)
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, CLOUD_SQL_MYSQL_TOOL_KIND, tool_statement1, tool_statement2)
|
||||
toolsFile = tests.AddMySqlExecuteSqlConfig(t, toolsFile)
|
||||
tmplSelectCombined, tmplSelectFilterCombined := tests.GetMySQLTmplToolStatement()
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, CloudSQLMySQLToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||
tmplSelectCombined, tmplSelectFilterCombined := tests.GetMysqlTmplToolStatement()
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, CLOUD_SQL_MYSQL_TOOL_KIND, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
@@ -139,7 +138,7 @@ func TestCloudSQLMySQLToolEndpoints(t *testing.T) {
|
||||
|
||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
|
||||
if err != nil {
|
||||
t.Logf("toolbox command logs: \n%s", out)
|
||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||
@@ -147,16 +146,16 @@ func TestCloudSQLMySQLToolEndpoints(t *testing.T) {
|
||||
|
||||
tests.RunToolGetTest(t)
|
||||
|
||||
select1Want, failInvocationWant, createTableStatement := tests.GetMySQLWants()
|
||||
invokeParamWant, invokeParamWantNull, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeParamWantNull)
|
||||
select1Want, failInvocationWant, createTableStatement := tests.GetMysqlWants()
|
||||
invokeParamWant, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant)
|
||||
tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want)
|
||||
tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failInvocationWant)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tests.NewTemplateParameterTestConfig())
|
||||
}
|
||||
|
||||
// Test connection with different IP type
|
||||
func TestCloudSQLMySQLIpConnection(t *testing.T) {
|
||||
func TestCloudSQLMysqlIpConnection(t *testing.T) {
|
||||
sourceConfig := getCloudSQLMySQLVars(t)
|
||||
|
||||
tcs := []struct {
|
||||
@@ -175,7 +174,7 @@ func TestCloudSQLMySQLIpConnection(t *testing.T) {
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
sourceConfig["ipType"] = tc.ipType
|
||||
err := tests.RunSourceConnectionTest(t, sourceConfig, CloudSQLMySQLToolKind)
|
||||
err := tests.RunSourceConnectionTest(t, sourceConfig, CLOUD_SQL_MYSQL_TOOL_KIND)
|
||||
if err != nil {
|
||||
t.Fatalf("Connection test failure: %s", err)
|
||||
}
|
||||
|
||||
@@ -26,46 +26,45 @@ import (
|
||||
|
||||
"cloud.google.com/go/cloudsqlconn"
|
||||
"github.com/google/uuid"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/tests"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
var (
|
||||
CloudSQLPostgresSourceKind = "cloud-sql-postgres"
|
||||
CloudSQLPostgresToolKind = "postgres-sql"
|
||||
CloudSQLPostgresProject = os.Getenv("CLOUD_SQL_POSTGRES_PROJECT")
|
||||
CloudSQLPostgresRegion = os.Getenv("CLOUD_SQL_POSTGRES_REGION")
|
||||
CloudSQLPostgresInstance = os.Getenv("CLOUD_SQL_POSTGRES_INSTANCE")
|
||||
CloudSQLPostgresDatabase = os.Getenv("CLOUD_SQL_POSTGRES_DATABASE")
|
||||
CloudSQLPostgresUser = os.Getenv("CLOUD_SQL_POSTGRES_USER")
|
||||
CloudSQLPostgresPass = os.Getenv("CLOUD_SQL_POSTGRES_PASS")
|
||||
CLOUD_SQL_POSTGRES_SOURCE_KIND = "cloud-sql-postgres"
|
||||
CLOUD_SQL_POSTGRES_TOOL_KIND = "postgres-sql"
|
||||
CLOUD_SQL_POSTGRES_PROJECT = os.Getenv("CLOUD_SQL_POSTGRES_PROJECT")
|
||||
CLOUD_SQL_POSTGRES_REGION = os.Getenv("CLOUD_SQL_POSTGRES_REGION")
|
||||
CLOUD_SQL_POSTGRES_INSTANCE = os.Getenv("CLOUD_SQL_POSTGRES_INSTANCE")
|
||||
CLOUD_SQL_POSTGRES_DATABASE = os.Getenv("CLOUD_SQL_POSTGRES_DATABASE")
|
||||
CLOUD_SQL_POSTGRES_USER = os.Getenv("CLOUD_SQL_POSTGRES_USER")
|
||||
CLOUD_SQL_POSTGRES_PASS = os.Getenv("CLOUD_SQL_POSTGRES_PASS")
|
||||
)
|
||||
|
||||
func getCloudSQLPgVars(t *testing.T) map[string]any {
|
||||
switch "" {
|
||||
case CloudSQLPostgresProject:
|
||||
case CLOUD_SQL_POSTGRES_PROJECT:
|
||||
t.Fatal("'CLOUD_SQL_POSTGRES_PROJECT' not set")
|
||||
case CloudSQLPostgresRegion:
|
||||
case CLOUD_SQL_POSTGRES_REGION:
|
||||
t.Fatal("'CLOUD_SQL_POSTGRES_REGION' not set")
|
||||
case CloudSQLPostgresInstance:
|
||||
case CLOUD_SQL_POSTGRES_INSTANCE:
|
||||
t.Fatal("'CLOUD_SQL_POSTGRES_INSTANCE' not set")
|
||||
case CloudSQLPostgresDatabase:
|
||||
case CLOUD_SQL_POSTGRES_DATABASE:
|
||||
t.Fatal("'CLOUD_SQL_POSTGRES_DATABASE' not set")
|
||||
case CloudSQLPostgresUser:
|
||||
case CLOUD_SQL_POSTGRES_USER:
|
||||
t.Fatal("'CLOUD_SQL_POSTGRES_USER' not set")
|
||||
case CloudSQLPostgresPass:
|
||||
case CLOUD_SQL_POSTGRES_PASS:
|
||||
t.Fatal("'CLOUD_SQL_POSTGRES_PASS' not set")
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"kind": CloudSQLPostgresSourceKind,
|
||||
"project": CloudSQLPostgresProject,
|
||||
"instance": CloudSQLPostgresInstance,
|
||||
"region": CloudSQLPostgresRegion,
|
||||
"database": CloudSQLPostgresDatabase,
|
||||
"user": CloudSQLPostgresUser,
|
||||
"password": CloudSQLPostgresPass,
|
||||
"kind": CLOUD_SQL_POSTGRES_SOURCE_KIND,
|
||||
"project": CLOUD_SQL_POSTGRES_PROJECT,
|
||||
"instance": CLOUD_SQL_POSTGRES_INSTANCE,
|
||||
"region": CLOUD_SQL_POSTGRES_REGION,
|
||||
"database": CLOUD_SQL_POSTGRES_DATABASE,
|
||||
"user": CLOUD_SQL_POSTGRES_USER,
|
||||
"password": CLOUD_SQL_POSTGRES_PASS,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -109,7 +108,7 @@ func TestCloudSQLPgSimpleToolEndpoints(t *testing.T) {
|
||||
|
||||
var args []string
|
||||
|
||||
pool, err := initCloudSQLPgConnectionPool(CloudSQLPostgresProject, CloudSQLPostgresRegion, CloudSQLPostgresInstance, "public", CloudSQLPostgresUser, CloudSQLPostgresPass, CloudSQLPostgresDatabase)
|
||||
pool, err := initCloudSQLPgConnectionPool(CLOUD_SQL_POSTGRES_PROJECT, CLOUD_SQL_POSTGRES_REGION, CLOUD_SQL_POSTGRES_INSTANCE, "public", CLOUD_SQL_POSTGRES_USER, CLOUD_SQL_POSTGRES_PASS, CLOUD_SQL_POSTGRES_DATABASE)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create Cloud SQL connection pool: %s", err)
|
||||
}
|
||||
@@ -120,20 +119,20 @@ func TestCloudSQLPgSimpleToolEndpoints(t *testing.T) {
|
||||
tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
|
||||
// set up data for param tool
|
||||
createStatement1, insertStatement1, paramToolStatement1, paramToolStatement2, params1 := tests.GetPostgresSQLParamToolInfo(tableNameParam)
|
||||
teardownTable1 := tests.SetupPostgresSQLTable(t, ctx, pool, createStatement1, insertStatement1, tableNameParam, params1)
|
||||
create_statement1, insert_statement1, tool_statement1, params1 := tests.GetPostgresSQLParamToolInfo(tableNameParam)
|
||||
teardownTable1 := tests.SetupPostgresSQLTable(t, ctx, pool, create_statement1, insert_statement1, tableNameParam, params1)
|
||||
defer teardownTable1(t)
|
||||
|
||||
// set up data for auth tool
|
||||
createStatement2, insertStatement2, authToolStatement, params2 := tests.GetPostgresSQLAuthToolInfo(tableNameAuth)
|
||||
teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createStatement2, insertStatement2, tableNameAuth, params2)
|
||||
create_statement2, insert_statement2, tool_statement2, params2 := tests.GetPostgresSQLAuthToolInfo(tableNameAuth)
|
||||
teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, create_statement2, insert_statement2, tableNameAuth, params2)
|
||||
defer teardownTable2(t)
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, CloudSQLPostgresToolKind, paramToolStatement1, paramToolStatement2, authToolStatement)
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, CLOUD_SQL_POSTGRES_TOOL_KIND, tool_statement1, tool_statement2)
|
||||
toolsFile = tests.AddPgExecuteSqlConfig(t, toolsFile)
|
||||
tmplSelectCombined, tmplSelectFilterCombined := tests.GetPostgresSQLTmplToolStatement()
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, CloudSQLPostgresToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, CLOUD_SQL_POSTGRES_TOOL_KIND, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
@@ -143,7 +142,7 @@ func TestCloudSQLPgSimpleToolEndpoints(t *testing.T) {
|
||||
|
||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
|
||||
if err != nil {
|
||||
t.Logf("toolbox command logs: \n%s", out)
|
||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||
@@ -152,8 +151,8 @@ func TestCloudSQLPgSimpleToolEndpoints(t *testing.T) {
|
||||
tests.RunToolGetTest(t)
|
||||
|
||||
select1Want, failInvocationWant, createTableStatement := tests.GetPostgresWants()
|
||||
invokeParamWant, invokeParamWantNull, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeParamWantNull)
|
||||
invokeParamWant, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant)
|
||||
tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want)
|
||||
tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failInvocationWant)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tests.NewTemplateParameterTestConfig())
|
||||
@@ -179,7 +178,7 @@ func TestCloudSQLPgIpConnection(t *testing.T) {
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
sourceConfig["ipType"] = tc.ipType
|
||||
err := tests.RunSourceConnectionTest(t, sourceConfig, CloudSQLPostgresToolKind)
|
||||
err := tests.RunSourceConnectionTest(t, sourceConfig, CLOUD_SQL_POSTGRES_TOOL_KIND)
|
||||
if err != nil {
|
||||
t.Fatalf("Connection test failure: %s", err)
|
||||
}
|
||||
@@ -190,32 +189,32 @@ func TestCloudSQLPgIpConnection(t *testing.T) {
|
||||
func TestCloudSQLPgIAMConnection(t *testing.T) {
|
||||
getCloudSQLPgVars(t)
|
||||
// service account email used for IAM should trim the suffix
|
||||
serviceAccountEmail := strings.TrimSuffix(tests.ServiceAccountEmail, ".gserviceaccount.com")
|
||||
serviceAccountEmail := strings.TrimSuffix(tests.SERVICE_ACCOUNT_EMAIL, ".gserviceaccount.com")
|
||||
|
||||
noPassSourceConfig := map[string]any{
|
||||
"kind": CloudSQLPostgresSourceKind,
|
||||
"project": CloudSQLPostgresProject,
|
||||
"instance": CloudSQLPostgresInstance,
|
||||
"region": CloudSQLPostgresRegion,
|
||||
"database": CloudSQLPostgresDatabase,
|
||||
"kind": CLOUD_SQL_POSTGRES_SOURCE_KIND,
|
||||
"project": CLOUD_SQL_POSTGRES_PROJECT,
|
||||
"instance": CLOUD_SQL_POSTGRES_INSTANCE,
|
||||
"region": CLOUD_SQL_POSTGRES_REGION,
|
||||
"database": CLOUD_SQL_POSTGRES_DATABASE,
|
||||
"user": serviceAccountEmail,
|
||||
}
|
||||
|
||||
noUserSourceConfig := map[string]any{
|
||||
"kind": CloudSQLPostgresSourceKind,
|
||||
"project": CloudSQLPostgresProject,
|
||||
"instance": CloudSQLPostgresInstance,
|
||||
"region": CloudSQLPostgresRegion,
|
||||
"database": CloudSQLPostgresDatabase,
|
||||
"kind": CLOUD_SQL_POSTGRES_SOURCE_KIND,
|
||||
"project": CLOUD_SQL_POSTGRES_PROJECT,
|
||||
"instance": CLOUD_SQL_POSTGRES_INSTANCE,
|
||||
"region": CLOUD_SQL_POSTGRES_REGION,
|
||||
"database": CLOUD_SQL_POSTGRES_DATABASE,
|
||||
"password": "random",
|
||||
}
|
||||
|
||||
noUserNoPassSourceConfig := map[string]any{
|
||||
"kind": CloudSQLPostgresSourceKind,
|
||||
"project": CloudSQLPostgresProject,
|
||||
"instance": CloudSQLPostgresInstance,
|
||||
"region": CloudSQLPostgresRegion,
|
||||
"database": CloudSQLPostgresDatabase,
|
||||
"kind": CLOUD_SQL_POSTGRES_SOURCE_KIND,
|
||||
"project": CLOUD_SQL_POSTGRES_PROJECT,
|
||||
"instance": CLOUD_SQL_POSTGRES_INSTANCE,
|
||||
"region": CLOUD_SQL_POSTGRES_REGION,
|
||||
"database": CLOUD_SQL_POSTGRES_DATABASE,
|
||||
}
|
||||
tcs := []struct {
|
||||
name string
|
||||
@@ -240,7 +239,7 @@ func TestCloudSQLPgIAMConnection(t *testing.T) {
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := tests.RunSourceConnectionTest(t, tc.sourceConfig, CloudSQLPostgresToolKind)
|
||||
err := tests.RunSourceConnectionTest(t, tc.sourceConfig, CLOUD_SQL_POSTGRES_TOOL_KIND)
|
||||
if err != nil {
|
||||
if tc.isErr {
|
||||
return
|
||||
|
||||
163
tests/common.go
163
tests/common.go
@@ -28,7 +28,7 @@ import (
|
||||
)
|
||||
|
||||
// GetToolsConfig returns a mock tools config file
|
||||
func GetToolsConfig(sourceConfig map[string]any, toolKind, paramToolStatement, paramToolStatement2, authToolStatement string) map[string]any {
|
||||
func GetToolsConfig(sourceConfig map[string]any, toolKind, param_tool_statement, auth_tool_statement string) map[string]any {
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := map[string]any{
|
||||
"sources": map[string]any{
|
||||
@@ -51,7 +51,7 @@ func GetToolsConfig(sourceConfig map[string]any, toolKind, paramToolStatement, p
|
||||
"kind": toolKind,
|
||||
"source": "my-instance",
|
||||
"description": "Tool to test invocation with params.",
|
||||
"statement": paramToolStatement,
|
||||
"statement": param_tool_statement,
|
||||
"parameters": []any{
|
||||
map[string]any{
|
||||
"name": "id",
|
||||
@@ -65,25 +65,12 @@ func GetToolsConfig(sourceConfig map[string]any, toolKind, paramToolStatement, p
|
||||
},
|
||||
},
|
||||
},
|
||||
"my-param-tool2": map[string]any{
|
||||
"kind": toolKind,
|
||||
"source": "my-instance",
|
||||
"description": "Tool to test invocation with params.",
|
||||
"statement": paramToolStatement2,
|
||||
"parameters": []any{
|
||||
map[string]any{
|
||||
"name": "id",
|
||||
"type": "integer",
|
||||
"description": "user ID",
|
||||
},
|
||||
},
|
||||
},
|
||||
"my-auth-tool": map[string]any{
|
||||
"kind": toolKind,
|
||||
"source": "my-instance",
|
||||
"description": "Tool to test authenticated parameters.",
|
||||
// statement to auto-fill authenticated parameter
|
||||
"statement": authToolStatement,
|
||||
"statement": auth_tool_statement,
|
||||
"parameters": []map[string]any{
|
||||
{
|
||||
"name": "email",
|
||||
@@ -250,8 +237,8 @@ func AddMySqlExecuteSqlConfig(t *testing.T, config map[string]any) map[string]an
|
||||
return config
|
||||
}
|
||||
|
||||
// AddMSSQLExecuteSqlConfig gets the tools config for `mssql-execute-sql`
|
||||
func AddMSSQLExecuteSqlConfig(t *testing.T, config map[string]any) map[string]any {
|
||||
// AddMssqlExecuteSqlConfig gets the tools config for `mssql-execute-sql`
|
||||
func AddMssqlExecuteSqlConfig(t *testing.T, config map[string]any) map[string]any {
|
||||
tools, ok := config["tools"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("unable to get tools from config")
|
||||
@@ -274,22 +261,21 @@ func AddMSSQLExecuteSqlConfig(t *testing.T, config map[string]any) map[string]an
|
||||
}
|
||||
|
||||
// GetPostgresSQLParamToolInfo returns statements and param for my-param-tool postgres-sql kind
|
||||
func GetPostgresSQLParamToolInfo(tableName string) (string, string, string, string, []any) {
|
||||
createStatement := fmt.Sprintf("CREATE TABLE %s (id SERIAL PRIMARY KEY, name TEXT);", tableName)
|
||||
insertStatement := fmt.Sprintf("INSERT INTO %s (name) VALUES ($1), ($2), ($3), ($4);", tableName)
|
||||
toolStatement := fmt.Sprintf("SELECT * FROM %s WHERE id = $1 OR name = $2;", tableName)
|
||||
toolStatement2 := fmt.Sprintf("SELECT * FROM %s WHERE id = $1;", tableName)
|
||||
params := []any{"Alice", "Jane", "Sid", nil}
|
||||
return createStatement, insertStatement, toolStatement, toolStatement2, params
|
||||
func GetPostgresSQLParamToolInfo(tableName string) (string, string, string, []any) {
|
||||
create_statement := fmt.Sprintf("CREATE TABLE %s (id SERIAL PRIMARY KEY, name TEXT);", tableName)
|
||||
insert_statement := fmt.Sprintf("INSERT INTO %s (name) VALUES ($1), ($2), ($3);", tableName)
|
||||
tool_statement := fmt.Sprintf("SELECT * FROM %s WHERE id = $1 OR name = $2;", tableName)
|
||||
params := []any{"Alice", "Jane", "Sid"}
|
||||
return create_statement, insert_statement, tool_statement, params
|
||||
}
|
||||
|
||||
// GetPostgresSQLAuthToolInfo returns statements and param of my-auth-tool for postgres-sql kind
|
||||
func GetPostgresSQLAuthToolInfo(tableName string) (string, string, string, []any) {
|
||||
createStatement := fmt.Sprintf("CREATE TABLE %s (id SERIAL PRIMARY KEY, name TEXT, email TEXT);", tableName)
|
||||
insertStatement := fmt.Sprintf("INSERT INTO %s (name, email) VALUES ($1, $2), ($3, $4)", tableName)
|
||||
toolStatement := fmt.Sprintf("SELECT name FROM %s WHERE email = $1;", tableName)
|
||||
params := []any{"Alice", ServiceAccountEmail, "Jane", "janedoe@gmail.com"}
|
||||
return createStatement, insertStatement, toolStatement, params
|
||||
create_statement := fmt.Sprintf("CREATE TABLE %s (id SERIAL PRIMARY KEY, name TEXT, email TEXT);", tableName)
|
||||
insert_statement := fmt.Sprintf("INSERT INTO %s (name, email) VALUES ($1, $2), ($3, $4)", tableName)
|
||||
tool_statement := fmt.Sprintf("SELECT name FROM %s WHERE email = $1;", tableName)
|
||||
params := []any{"Alice", SERVICE_ACCOUNT_EMAIL, "Jane", "janedoe@gmail.com"}
|
||||
return create_statement, insert_statement, tool_statement, params
|
||||
}
|
||||
|
||||
// GetPostgresSQLTmplToolStatement returns statements and param for template parameter test cases for postgres-sql kind
|
||||
@@ -299,63 +285,60 @@ func GetPostgresSQLTmplToolStatement() (string, string) {
|
||||
return tmplSelectCombined, tmplSelectFilterCombined
|
||||
}
|
||||
|
||||
// GetMSSQLParamToolInfo returns statements and param for my-param-tool mssql-sql kind
|
||||
func GetMSSQLParamToolInfo(tableName string) (string, string, string, string, []any) {
|
||||
createStatement := fmt.Sprintf("CREATE TABLE %s (id INT IDENTITY(1,1) PRIMARY KEY, name VARCHAR(255));", tableName)
|
||||
insertStatement := fmt.Sprintf("INSERT INTO %s (name) VALUES (@alice), (@jane), (@sid), (@nil);", tableName)
|
||||
toolStatement := fmt.Sprintf("SELECT * FROM %s WHERE id = @id OR name = @p2;", tableName)
|
||||
toolStatement2 := fmt.Sprintf("SELECT * FROM %s WHERE id = @id;", tableName)
|
||||
params := []any{sql.Named("alice", "Alice"), sql.Named("jane", "Jane"), sql.Named("sid", "Sid"), sql.Named("nil", nil)}
|
||||
return createStatement, insertStatement, toolStatement, toolStatement2, params
|
||||
// GetMssqlParamToolInfo returns statements and param for my-param-tool mssql-sql kind
|
||||
func GetMssqlParamToolInfo(tableName string) (string, string, string, []any) {
|
||||
create_statement := fmt.Sprintf("CREATE TABLE %s (id INT IDENTITY(1,1) PRIMARY KEY, name VARCHAR(255));", tableName)
|
||||
insert_statement := fmt.Sprintf("INSERT INTO %s (name) VALUES (@alice), (@jane), (@sid);", tableName)
|
||||
tool_statement := fmt.Sprintf("SELECT * FROM %s WHERE id = @id OR name = @p2;", tableName)
|
||||
params := []any{sql.Named("alice", "Alice"), sql.Named("jane", "Jane"), sql.Named("sid", "Sid")}
|
||||
return create_statement, insert_statement, tool_statement, params
|
||||
}
|
||||
|
||||
// GetMSSQLAuthToolInfo returns statements and param of my-auth-tool for mssql-sql kind
|
||||
func GetMSSQLAuthToolInfo(tableName string) (string, string, string, []any) {
|
||||
createStatement := fmt.Sprintf("CREATE TABLE %s (id INT IDENTITY(1,1) PRIMARY KEY, name VARCHAR(255), email VARCHAR(255));", tableName)
|
||||
insertStatement := fmt.Sprintf("INSERT INTO %s (name, email) VALUES (@alice, @aliceemail), (@jane, @janeemail);", tableName)
|
||||
toolStatement := fmt.Sprintf("SELECT name FROM %s WHERE email = @email;", tableName)
|
||||
params := []any{sql.Named("alice", "Alice"), sql.Named("aliceemail", ServiceAccountEmail), sql.Named("jane", "Jane"), sql.Named("janeemail", "janedoe@gmail.com")}
|
||||
return createStatement, insertStatement, toolStatement, params
|
||||
// GetMssqlAuthToolInfo returns statements and param of my-auth-tool for mssql-sql kind
|
||||
func GetMssqlAuthToolInfo(tableName string) (string, string, string, []any) {
|
||||
create_statement := fmt.Sprintf("CREATE TABLE %s (id INT IDENTITY(1,1) PRIMARY KEY, name VARCHAR(255), email VARCHAR(255));", tableName)
|
||||
insert_statement := fmt.Sprintf("INSERT INTO %s (name, email) VALUES (@alice, @aliceemail), (@jane, @janeemail);", tableName)
|
||||
tool_statement := fmt.Sprintf("SELECT name FROM %s WHERE email = @email;", tableName)
|
||||
params := []any{sql.Named("alice", "Alice"), sql.Named("aliceemail", SERVICE_ACCOUNT_EMAIL), sql.Named("jane", "Jane"), sql.Named("janeemail", "janedoe@gmail.com")}
|
||||
return create_statement, insert_statement, tool_statement, params
|
||||
}
|
||||
|
||||
// GetMSSQLTmplToolStatement returns statements and param for template parameter test cases for mysql-sql kind
|
||||
func GetMSSQLTmplToolStatement() (string, string) {
|
||||
// GetMssqlTmplToolStatement returns statements and param for template parameter test cases for mysql-sql kind
|
||||
func GetMssqlTmplToolStatement() (string, string) {
|
||||
tmplSelectCombined := "SELECT * FROM {{.tableName}} WHERE id = @id"
|
||||
tmplSelectFilterCombined := "SELECT * FROM {{.tableName}} WHERE {{.columnFilter}} = @name"
|
||||
return tmplSelectCombined, tmplSelectFilterCombined
|
||||
}
|
||||
|
||||
// GetMySQLParamToolInfo returns statements and param for my-param-tool mysql-sql kind
|
||||
func GetMySQLParamToolInfo(tableName string) (string, string, string, string, []any) {
|
||||
createStatement := fmt.Sprintf("CREATE TABLE %s (id INT NOT NULL AUTO_INCREMENT PRIMARY KEY, name VARCHAR(255));", tableName)
|
||||
insertStatement := fmt.Sprintf("INSERT INTO %s (name) VALUES (?), (?), (?), (?);", tableName)
|
||||
toolStatement := fmt.Sprintf("SELECT * FROM %s WHERE id = ? OR name = ?;", tableName)
|
||||
toolStatement2 := fmt.Sprintf("SELECT * FROM %s WHERE id = ?;", tableName)
|
||||
params := []any{"Alice", "Jane", "Sid", nil}
|
||||
return createStatement, insertStatement, toolStatement, toolStatement2, params
|
||||
// GetMysqlParamToolInfo returns statements and param for my-param-tool mysql-sql kind
|
||||
func GetMysqlParamToolInfo(tableName string) (string, string, string, []any) {
|
||||
create_statement := fmt.Sprintf("CREATE TABLE %s (id INT NOT NULL AUTO_INCREMENT PRIMARY KEY, name VARCHAR(255));", tableName)
|
||||
insert_statement := fmt.Sprintf("INSERT INTO %s (name) VALUES (?), (?), (?);", tableName)
|
||||
tool_statement := fmt.Sprintf("SELECT * FROM %s WHERE id = ? OR name = ?;", tableName)
|
||||
params := []any{"Alice", "Jane", "Sid"}
|
||||
return create_statement, insert_statement, tool_statement, params
|
||||
}
|
||||
|
||||
// GetMySQLAuthToolInfo returns statements and param of my-auth-tool for mysql-sql kind
|
||||
func GetMySQLAuthToolInfo(tableName string) (string, string, string, []any) {
|
||||
createStatement := fmt.Sprintf("CREATE TABLE %s (id INT NOT NULL AUTO_INCREMENT PRIMARY KEY, name VARCHAR(255), email VARCHAR(255));", tableName)
|
||||
insertStatement := fmt.Sprintf("INSERT INTO %s (name, email) VALUES (?, ?), (?, ?)", tableName)
|
||||
toolStatement := fmt.Sprintf("SELECT name FROM %s WHERE email = ?;", tableName)
|
||||
params := []any{"Alice", ServiceAccountEmail, "Jane", "janedoe@gmail.com"}
|
||||
return createStatement, insertStatement, toolStatement, params
|
||||
// GetMysqlAuthToolInfo returns statements and param of my-auth-tool for mysql-sql kind
|
||||
func GetMysqlAuthToolInfo(tableName string) (string, string, string, []any) {
|
||||
create_statement := fmt.Sprintf("CREATE TABLE %s (id INT NOT NULL AUTO_INCREMENT PRIMARY KEY, name VARCHAR(255), email VARCHAR(255));", tableName)
|
||||
insert_statement := fmt.Sprintf("INSERT INTO %s (name, email) VALUES (?, ?), (?, ?)", tableName)
|
||||
tool_statement := fmt.Sprintf("SELECT name FROM %s WHERE email = ?;", tableName)
|
||||
params := []any{"Alice", SERVICE_ACCOUNT_EMAIL, "Jane", "janedoe@gmail.com"}
|
||||
return create_statement, insert_statement, tool_statement, params
|
||||
}
|
||||
|
||||
// GetMySQLTmplToolStatement returns statements and param for template parameter test cases for mysql-sql kind
|
||||
func GetMySQLTmplToolStatement() (string, string) {
|
||||
// GetMysqlTmplToolStatement returns statements and param for template parameter test cases for mysql-sql kind
|
||||
func GetMysqlTmplToolStatement() (string, string) {
|
||||
tmplSelectCombined := "SELECT * FROM {{.tableName}} WHERE id = ?"
|
||||
tmplSelectFilterCombined := "SELECT * FROM {{.tableName}} WHERE {{.columnFilter}} = ?"
|
||||
return tmplSelectCombined, tmplSelectFilterCombined
|
||||
}
|
||||
|
||||
func GetNonSpannerInvokeParamWant() (string, string, string) {
|
||||
func GetNonSpannerInvokeParamWant() (string, string) {
|
||||
invokeParamWant := "[{\"id\":1,\"name\":\"Alice\"},{\"id\":3,\"name\":\"Sid\"}]"
|
||||
invokeParamWantNull := "[{\"id\":4,\"name\":null}]"
|
||||
mcpInvokeParamWant := `{"jsonrpc":"2.0","id":"my-param-tool","result":{"content":[{"type":"text","text":"{\"id\":1,\"name\":\"Alice\"}"},{"type":"text","text":"{\"id\":3,\"name\":\"Sid\"}"}]}}`
|
||||
return invokeParamWant, invokeParamWantNull, mcpInvokeParamWant
|
||||
return invokeParamWant, mcpInvokeParamWant
|
||||
}
|
||||
|
||||
// GetPostgresWants return the expected wants for postgres
|
||||
@@ -366,16 +349,16 @@ func GetPostgresWants() (string, string, string) {
|
||||
return select1Want, failInvocationWant, createTableStatement
|
||||
}
|
||||
|
||||
// GetMSSQLWants return the expected wants for mssql
|
||||
func GetMSSQLWants() (string, string, string) {
|
||||
// GetMssqlWants return the expected wants for mssql
|
||||
func GetMssqlWants() (string, string, string) {
|
||||
select1Want := "[{\"\":1}]"
|
||||
failInvocationWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: mssql: Could not find stored procedure 'SELEC'."}],"isError":true}}`
|
||||
createTableStatement := `"CREATE TABLE t (id INT IDENTITY(1,1) PRIMARY KEY, name NVARCHAR(MAX))"`
|
||||
return select1Want, failInvocationWant, createTableStatement
|
||||
}
|
||||
|
||||
// GetMySQLWants return the expected wants for mysql
|
||||
func GetMySQLWants() (string, string, string) {
|
||||
// GetMysqlWants return the expected wants for mysql
|
||||
func GetMysqlWants() (string, string, string) {
|
||||
select1Want := "[{\"1\":1}]"
|
||||
failInvocationWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: Error 1064 (42000): You have an error in your SQL syntax; check the manual that corresponds to your MySQL server version for the right syntax to use near 'SELEC 1' at line 1"}],"isError":true}}`
|
||||
createTableStatement := `"CREATE TABLE t (id SERIAL PRIMARY KEY, name TEXT)"`
|
||||
@@ -384,20 +367,20 @@ func GetMySQLWants() (string, string, string) {
|
||||
|
||||
// SetupPostgresSQLTable creates and inserts data into a table of tool
|
||||
// compatible with postgres-sql tool
|
||||
func SetupPostgresSQLTable(t *testing.T, ctx context.Context, pool *pgxpool.Pool, createStatement, insertStatement, tableName string, params []any) func(*testing.T) {
|
||||
func SetupPostgresSQLTable(t *testing.T, ctx context.Context, pool *pgxpool.Pool, create_statement, insert_statement, tableName string, params []any) func(*testing.T) {
|
||||
err := pool.Ping(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to connect to test database: %s", err)
|
||||
}
|
||||
|
||||
// Create table
|
||||
_, err = pool.Query(ctx, createStatement)
|
||||
_, err = pool.Query(ctx, create_statement)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create test table %s: %s", tableName, err)
|
||||
}
|
||||
|
||||
// Insert test data
|
||||
_, err = pool.Query(ctx, insertStatement, params...)
|
||||
_, err = pool.Query(ctx, insert_statement, params...)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to insert test data: %s", err)
|
||||
}
|
||||
@@ -413,20 +396,20 @@ func SetupPostgresSQLTable(t *testing.T, ctx context.Context, pool *pgxpool.Pool
|
||||
|
||||
// SetupMsSQLTable creates and inserts data into a table of tool
|
||||
// compatible with mssql-sql tool
|
||||
func SetupMsSQLTable(t *testing.T, ctx context.Context, pool *sql.DB, createStatement, insertStatement, tableName string, params []any) func(*testing.T) {
|
||||
func SetupMsSQLTable(t *testing.T, ctx context.Context, pool *sql.DB, create_statement, insert_statement, tableName string, params []any) func(*testing.T) {
|
||||
err := pool.PingContext(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to connect to test database: %s", err)
|
||||
}
|
||||
|
||||
// Create table
|
||||
_, err = pool.QueryContext(ctx, createStatement)
|
||||
_, err = pool.QueryContext(ctx, create_statement)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create test table %s: %s", tableName, err)
|
||||
}
|
||||
|
||||
// Insert test data
|
||||
_, err = pool.QueryContext(ctx, insertStatement, params...)
|
||||
_, err = pool.QueryContext(ctx, insert_statement, params...)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to insert test data: %s", err)
|
||||
}
|
||||
@@ -442,20 +425,20 @@ func SetupMsSQLTable(t *testing.T, ctx context.Context, pool *sql.DB, createStat
|
||||
|
||||
// SetupMySQLTable creates and inserts data into a table of tool
|
||||
// compatible with mysql-sql tool
|
||||
func SetupMySQLTable(t *testing.T, ctx context.Context, pool *sql.DB, createStatement, insertStatement, tableName string, params []any) func(*testing.T) {
|
||||
func SetupMySQLTable(t *testing.T, ctx context.Context, pool *sql.DB, create_statement, insert_statement, tableName string, params []any) func(*testing.T) {
|
||||
err := pool.PingContext(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to connect to test database: %s", err)
|
||||
}
|
||||
|
||||
// Create table
|
||||
_, err = pool.QueryContext(ctx, createStatement)
|
||||
_, err = pool.QueryContext(ctx, create_statement)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create test table %s: %s", tableName, err)
|
||||
}
|
||||
|
||||
// Insert test data
|
||||
_, err = pool.QueryContext(ctx, insertStatement, params...)
|
||||
_, err = pool.QueryContext(ctx, insert_statement, params...)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to insert test data: %s", err)
|
||||
}
|
||||
@@ -470,13 +453,12 @@ func SetupMySQLTable(t *testing.T, ctx context.Context, pool *sql.DB, createStat
|
||||
}
|
||||
|
||||
// GetRedisWants return the expected wants for redis
|
||||
func GetRedisValkeyWants() (string, string, string, string, string) {
|
||||
func GetRedisValkeyWants() (string, string, string, string) {
|
||||
select1Want := "[\"PONG\"]"
|
||||
failInvocationWant := `unknown command 'SELEC 1;', with args beginning with: \""}]}}`
|
||||
invokeParamWant := "[{\"id\":\"1\",\"name\":\"Alice\"},{\"id\":\"3\",\"name\":\"Sid\"}]"
|
||||
invokeParamWantNull := `[{"id":"4","name":""}]`
|
||||
mcpInvokeParamWant := `{"jsonrpc":"2.0","id":"my-param-tool","result":{"content":[{"type":"text","text":"{\"id\":\"1\",\"name\":\"Alice\"}"},{"type":"text","text":"{\"id\":\"3\",\"name\":\"Sid\"}"}]}}`
|
||||
return select1Want, failInvocationWant, invokeParamWant, invokeParamWantNull, mcpInvokeParamWant
|
||||
return select1Want, failInvocationWant, invokeParamWant, mcpInvokeParamWant
|
||||
}
|
||||
|
||||
func GetRedisValkeyToolsConfig(sourceConfig map[string]any, toolKind string) map[string]any {
|
||||
@@ -515,19 +497,6 @@ func GetRedisValkeyToolsConfig(sourceConfig map[string]any, toolKind string) map
|
||||
},
|
||||
},
|
||||
},
|
||||
"my-param-tool2": map[string]any{
|
||||
"kind": toolKind,
|
||||
"source": "my-instance",
|
||||
"description": "Tool to test invocation with params.",
|
||||
"commands": [][]string{{"HGETALL", "row4"}},
|
||||
"parameters": []any{
|
||||
map[string]any{
|
||||
"name": "id",
|
||||
"type": "integer",
|
||||
"description": "user ID",
|
||||
},
|
||||
},
|
||||
},
|
||||
"my-auth-tool": map[string]any{
|
||||
"kind": toolKind,
|
||||
"source": "my-instance",
|
||||
@@ -565,5 +534,7 @@ func GetRedisValkeyToolsConfig(sourceConfig map[string]any, toolKind string) map
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return toolsFile
|
||||
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package couchbase
|
||||
package tests
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -25,7 +25,6 @@ import (
|
||||
|
||||
"github.com/couchbase/gocb/v2"
|
||||
"github.com/google/uuid"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/tests"
|
||||
)
|
||||
|
||||
@@ -35,11 +34,12 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
couchbaseConnection = os.Getenv("COUCHBASE_CONNECTION")
|
||||
couchbaseBucket = os.Getenv("COUCHBASE_BUCKET")
|
||||
couchbaseScope = os.Getenv("COUCHBASE_SCOPE")
|
||||
couchbaseUser = os.Getenv("COUCHBASE_USER")
|
||||
couchbasePass = os.Getenv("COUCHBASE_PASS")
|
||||
couchbaseConnection = os.Getenv("COUCHBASE_CONNECTION")
|
||||
couchbaseBucket = os.Getenv("COUCHBASE_BUCKET")
|
||||
couchbaseScope = os.Getenv("COUCHBASE_SCOPE")
|
||||
couchbaseUser = os.Getenv("COUCHBASE_USER")
|
||||
couchbasePass = os.Getenv("COUCHBASE_PASS")
|
||||
SERVICE_ACCOUNT_EMAIL = os.Getenv("SERVICE_ACCOUNT_EMAIL")
|
||||
)
|
||||
|
||||
// getCouchbaseVars validates and returns Couchbase configuration variables
|
||||
@@ -103,7 +103,7 @@ func TestCouchbaseToolEndpoints(t *testing.T) {
|
||||
collectionNameTemplateParam := "template_param_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
|
||||
// Set up data for param tool
|
||||
paramToolStatement1, paramToolStatement2, params1 := getCouchbaseParamToolInfo(collectionNameParam)
|
||||
paramToolStatement, params1 := getCouchbaseParamToolInfo(collectionNameParam)
|
||||
teardownCollection1 := setupCouchbaseCollection(t, ctx, cluster, couchbaseBucket, couchbaseScope, collectionNameParam, params1)
|
||||
defer teardownCollection1(t)
|
||||
|
||||
@@ -118,7 +118,7 @@ func TestCouchbaseToolEndpoints(t *testing.T) {
|
||||
defer teardownCollection3(t)
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, couchbaseToolKind, paramToolStatement1, paramToolStatement2, authToolStatement)
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, couchbaseToolKind, paramToolStatement, authToolStatement)
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, couchbaseToolKind, tmplSelectCombined, tmplSelectFilterCombined, tmplSelectAll)
|
||||
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
@@ -129,7 +129,7 @@ func TestCouchbaseToolEndpoints(t *testing.T) {
|
||||
|
||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
|
||||
if err != nil {
|
||||
t.Logf("toolbox command logs: \n%s", out)
|
||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||
@@ -140,8 +140,8 @@ func TestCouchbaseToolEndpoints(t *testing.T) {
|
||||
select1Want := "[{\"$1\":1}]"
|
||||
failMcpInvocationWant := "{\"jsonrpc\":\"2.0\",\"id\":\"invoke-fail-tool\",\"result\":{\"content\":[{\"type\":\"text\",\"text\":\"unable to execute query: parsing failure | {\\\"statement\\\":\\\"SELEC 1;\\\""
|
||||
|
||||
invokeParamWant, invokeParamWantNull, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeParamWantNull)
|
||||
invokeParamWant, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant)
|
||||
tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failMcpInvocationWant)
|
||||
|
||||
templateParamTestConfig := tests.NewTemplateParameterTestConfig(
|
||||
@@ -231,22 +231,18 @@ func setupCouchbaseCollection(t *testing.T, ctx context.Context, cluster *gocb.C
|
||||
}
|
||||
|
||||
// getCouchbaseParamToolInfo returns statements and params for my-param-tool couchbase-sql kind
|
||||
func getCouchbaseParamToolInfo(collectionName string) (string, string, []map[string]any) {
|
||||
func getCouchbaseParamToolInfo(collectionName string) (string, []map[string]any) {
|
||||
// N1QL uses positional or named parameters with $ prefix
|
||||
toolStatement := fmt.Sprintf("SELECT TONUMBER(meta().id) as id, "+
|
||||
"%s.* FROM %s WHERE meta().id = TOSTRING($id) OR name = $name order by meta().id",
|
||||
collectionName, collectionName)
|
||||
toolStatement2 := fmt.Sprintf("SELECT TONUMBER(meta().id) as id, "+
|
||||
"%s.* FROM %s WHERE meta().id = TOSTRING($id) order by meta().id",
|
||||
collectionName, collectionName)
|
||||
|
||||
params := []map[string]any{
|
||||
{"name": "Alice"},
|
||||
{"name": "Jane"},
|
||||
{"name": "Sid"},
|
||||
{"name": nil},
|
||||
}
|
||||
return toolStatement, toolStatement2, params
|
||||
return toolStatement, params
|
||||
}
|
||||
|
||||
// getCouchbaseAuthToolInfo returns statements and param of my-auth-tool for couchbase-sql kind
|
||||
@@ -254,7 +250,7 @@ func getCouchbaseAuthToolInfo(collectionName string) (string, []map[string]any)
|
||||
toolStatement := fmt.Sprintf("SELECT name FROM %s WHERE email = $email", collectionName)
|
||||
|
||||
params := []map[string]any{
|
||||
{"name": "Alice", "email": tests.ServiceAccountEmail},
|
||||
{"name": "Alice", "email": SERVICE_ACCOUNT_EMAIL},
|
||||
{"name": "Jane", "email": "janedoe@gmail.com"},
|
||||
}
|
||||
return toolStatement, params
|
||||
|
||||
@@ -26,24 +26,24 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/tests"
|
||||
)
|
||||
|
||||
var (
|
||||
DgraphSourceKind = "dgraph"
|
||||
DgraphApiKey = "api-key"
|
||||
DgraphUrl = os.Getenv("DGRAPH_URL")
|
||||
DGRAPH_SOURCE_KIND = "dgraph"
|
||||
DGRAPH_TOOL_KIND = "dgraph-dql"
|
||||
DGRAPH_API_KEY = "api-key"
|
||||
DGRAPH_URL = os.Getenv("DGRAPH_URL")
|
||||
)
|
||||
|
||||
func getDgraphVars(t *testing.T) map[string]any {
|
||||
if DgraphUrl == "" {
|
||||
if DGRAPH_URL == "" {
|
||||
t.Fatal("'DGRAPH_URL' not set")
|
||||
}
|
||||
return map[string]any{
|
||||
"kind": DgraphSourceKind,
|
||||
"dgraphUrl": DgraphUrl,
|
||||
"apiKey": DgraphApiKey,
|
||||
"kind": DGRAPH_SOURCE_KIND,
|
||||
"dgraphUrl": DGRAPH_URL,
|
||||
"apiKey": DGRAPH_API_KEY,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -79,7 +79,7 @@ func TestDgraphToolEndpoints(t *testing.T) {
|
||||
|
||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
|
||||
if err != nil {
|
||||
t.Logf("toolbox command logs: \n%s", out)
|
||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||
|
||||
@@ -28,14 +28,13 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/tests"
|
||||
)
|
||||
|
||||
var (
|
||||
HttpSourceKind = "http"
|
||||
HttpToolKind = "http"
|
||||
HTTP_SOURCE_KIND = "http"
|
||||
HTTP_TOOL_KIND = "http"
|
||||
)
|
||||
|
||||
func getHTTPSourceConfig(t *testing.T) map[string]any {
|
||||
@@ -45,7 +44,7 @@ func getHTTPSourceConfig(t *testing.T) map[string]any {
|
||||
}
|
||||
idToken = "Bearer " + idToken
|
||||
return map[string]any{
|
||||
"kind": HttpSourceKind,
|
||||
"kind": HTTP_SOURCE_KIND,
|
||||
"headers": map[string]string{"Authorization": idToken},
|
||||
}
|
||||
}
|
||||
@@ -60,8 +59,6 @@ func multiTool(w http.ResponseWriter, r *http.Request) {
|
||||
handleTool0(w, r)
|
||||
case "tool1":
|
||||
handleTool1(w, r)
|
||||
case "tool1a":
|
||||
handleTool1a(w, r)
|
||||
case "tool2":
|
||||
handleTool2(w, r)
|
||||
case "tool3":
|
||||
@@ -133,27 +130,6 @@ func handleTool1(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
|
||||
// handler function for the test server
|
||||
func handleTool1a(w http.ResponseWriter, r *http.Request) {
|
||||
// expect GET method
|
||||
if r.Method != http.MethodGet {
|
||||
errorMessage := fmt.Sprintf("expected GET method but got: %s", string(r.Method))
|
||||
http.Error(w, errorMessage, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
id := r.URL.Query().Get("id")
|
||||
if id == "4" {
|
||||
response := `[{"id":4,"name":null}]`
|
||||
_, err := w.Write([]byte(response))
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to write response", http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
|
||||
// handler function for the test server
|
||||
func handleTool2(w http.ResponseWriter, r *http.Request) {
|
||||
// expect GET method
|
||||
@@ -269,7 +245,7 @@ func TestHttpToolEndpoints(t *testing.T) {
|
||||
|
||||
var args []string
|
||||
|
||||
toolsFile := getHTTPToolsConfig(sourceConfig, HttpToolKind)
|
||||
toolsFile := getHTTPToolsConfig(sourceConfig, HTTP_TOOL_KIND)
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
t.Fatalf("command initialization returned an error: %s", err)
|
||||
@@ -278,16 +254,16 @@ func TestHttpToolEndpoints(t *testing.T) {
|
||||
|
||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
|
||||
if err != nil {
|
||||
t.Logf("toolbox command logs: \n%s", out)
|
||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||
}
|
||||
|
||||
select1Want := `["Hello","World"]`
|
||||
invokeParamWant, invokeParamWantNull, _ := tests.GetNonSpannerInvokeParamWant()
|
||||
invokeParamWant, _ := tests.GetNonSpannerInvokeParamWant()
|
||||
tests.RunToolGetTest(t)
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeParamWantNull)
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant)
|
||||
runAdvancedHTTPInvokeTest(t)
|
||||
}
|
||||
|
||||
@@ -407,16 +383,6 @@ func getHTTPToolsConfig(sourceConfig map[string]any, toolKind string) map[string
|
||||
"bodyParams": []tools.Parameter{tools.NewStringParameter("name", "user name")},
|
||||
"headers": map[string]string{"Content-Type": "application/json"},
|
||||
},
|
||||
"my-param-tool2": map[string]any{
|
||||
"kind": toolKind,
|
||||
"source": "my-instance",
|
||||
"method": "GET",
|
||||
"path": "/tool1a",
|
||||
"description": "some description",
|
||||
"queryParams": []tools.Parameter{
|
||||
tools.NewIntParameter("id", "user ID")},
|
||||
"headers": map[string]string{"Content-Type": "application/json"},
|
||||
},
|
||||
"my-auth-tool": map[string]any{
|
||||
"kind": toolKind,
|
||||
"source": "my-instance",
|
||||
|
||||
@@ -26,46 +26,45 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/tests"
|
||||
)
|
||||
|
||||
var (
|
||||
MSSQLSourceKind = "mssql"
|
||||
MSSQLToolKind = "mssql-sql"
|
||||
MSSQLDatabase = os.Getenv("MSSQL_DATABASE")
|
||||
MSSQLHost = os.Getenv("MSSQL_HOST")
|
||||
MSSQLPort = os.Getenv("MSSQL_PORT")
|
||||
MSSQLUser = os.Getenv("MSSQL_USER")
|
||||
MSSQLPass = os.Getenv("MSSQL_PASS")
|
||||
MSSQL_SOURCE_KIND = "mssql"
|
||||
MSSQL_TOOL_KIND = "mssql-sql"
|
||||
MSSQL_DATABASE = os.Getenv("MSSQL_DATABASE")
|
||||
MSSQL_HOST = os.Getenv("MSSQL_HOST")
|
||||
MSSQL_PORT = os.Getenv("MSSQL_PORT")
|
||||
MSSQL_USER = os.Getenv("MSSQL_USER")
|
||||
MSSQL_PASS = os.Getenv("MSSQL_PASS")
|
||||
)
|
||||
|
||||
func getMsSQLVars(t *testing.T) map[string]any {
|
||||
switch "" {
|
||||
case MSSQLDatabase:
|
||||
case MSSQL_DATABASE:
|
||||
t.Fatal("'MSSQL_DATABASE' not set")
|
||||
case MSSQLHost:
|
||||
case MSSQL_HOST:
|
||||
t.Fatal("'MSSQL_HOST' not set")
|
||||
case MSSQLPort:
|
||||
case MSSQL_PORT:
|
||||
t.Fatal("'MSSQL_PORT' not set")
|
||||
case MSSQLUser:
|
||||
case MSSQL_USER:
|
||||
t.Fatal("'MSSQL_USER' not set")
|
||||
case MSSQLPass:
|
||||
case MSSQL_PASS:
|
||||
t.Fatal("'MSSQL_PASS' not set")
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"kind": MSSQLSourceKind,
|
||||
"host": MSSQLHost,
|
||||
"port": MSSQLPort,
|
||||
"database": MSSQLDatabase,
|
||||
"user": MSSQLUser,
|
||||
"password": MSSQLPass,
|
||||
"kind": MSSQL_SOURCE_KIND,
|
||||
"host": MSSQL_HOST,
|
||||
"port": MSSQL_PORT,
|
||||
"database": MSSQL_DATABASE,
|
||||
"user": MSSQL_USER,
|
||||
"password": MSSQL_PASS,
|
||||
}
|
||||
}
|
||||
|
||||
// Copied over from mssql.go
|
||||
func initMSSQLConnection(host, port, user, pass, dbname string) (*sql.DB, error) {
|
||||
func initMssqlConnection(host, port, user, pass, dbname string) (*sql.DB, error) {
|
||||
// Create dsn
|
||||
query := url.Values{}
|
||||
query.Add("database", dbname)
|
||||
@@ -84,14 +83,14 @@ func initMSSQLConnection(host, port, user, pass, dbname string) (*sql.DB, error)
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func TestMSSQLToolEndpoints(t *testing.T) {
|
||||
func TestMssqlToolEndpoints(t *testing.T) {
|
||||
sourceConfig := getMsSQLVars(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
|
||||
var args []string
|
||||
|
||||
pool, err := initMSSQLConnection(MSSQLHost, MSSQLPort, MSSQLUser, MSSQLPass, MSSQLDatabase)
|
||||
pool, err := initMssqlConnection(MSSQL_HOST, MSSQL_PORT, MSSQL_USER, MSSQL_PASS, MSSQL_DATABASE)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create SQL Server connection pool: %s", err)
|
||||
}
|
||||
@@ -102,20 +101,20 @@ func TestMSSQLToolEndpoints(t *testing.T) {
|
||||
tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
|
||||
// set up data for param tool
|
||||
createStatement1, insertStatement1, paramToolStatement1, paramToolStatement2, params1 := tests.GetMSSQLParamToolInfo(tableNameParam)
|
||||
teardownTable1 := tests.SetupMsSQLTable(t, ctx, pool, createStatement1, insertStatement1, tableNameParam, params1)
|
||||
create_statement1, insert_statement1, tool_statement1, params1 := tests.GetMssqlParamToolInfo(tableNameParam)
|
||||
teardownTable1 := tests.SetupMsSQLTable(t, ctx, pool, create_statement1, insert_statement1, tableNameParam, params1)
|
||||
defer teardownTable1(t)
|
||||
|
||||
// set up data for auth tool
|
||||
createStatement2, insertStatement2, authToolStatement, params2 := tests.GetMSSQLAuthToolInfo(tableNameAuth)
|
||||
teardownTable2 := tests.SetupMsSQLTable(t, ctx, pool, createStatement2, insertStatement2, tableNameAuth, params2)
|
||||
create_statement2, insert_statement2, tool_statement2, params2 := tests.GetMssqlAuthToolInfo(tableNameAuth)
|
||||
teardownTable2 := tests.SetupMsSQLTable(t, ctx, pool, create_statement2, insert_statement2, tableNameAuth, params2)
|
||||
defer teardownTable2(t)
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, MSSQLToolKind, paramToolStatement1, paramToolStatement2, authToolStatement)
|
||||
toolsFile = tests.AddMSSQLExecuteSqlConfig(t, toolsFile)
|
||||
tmplSelectCombined, tmplSelectFilterCombined := tests.GetMSSQLTmplToolStatement()
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, MSSQLToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, MSSQL_TOOL_KIND, tool_statement1, tool_statement2)
|
||||
toolsFile = tests.AddMssqlExecuteSqlConfig(t, toolsFile)
|
||||
tmplSelectCombined, tmplSelectFilterCombined := tests.GetMssqlTmplToolStatement()
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, MSSQL_TOOL_KIND, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
@@ -125,7 +124,7 @@ func TestMSSQLToolEndpoints(t *testing.T) {
|
||||
|
||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
|
||||
if err != nil {
|
||||
t.Logf("toolbox command logs: \n%s", out)
|
||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||
@@ -133,9 +132,9 @@ func TestMSSQLToolEndpoints(t *testing.T) {
|
||||
|
||||
tests.RunToolGetTest(t)
|
||||
|
||||
select1Want, failInvocationWant, createTableStatement := tests.GetMSSQLWants()
|
||||
invokeParamWant, invokeParamWantNull, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeParamWantNull)
|
||||
select1Want, failInvocationWant, createTableStatement := tests.GetMssqlWants()
|
||||
invokeParamWant, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant)
|
||||
tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want)
|
||||
tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failInvocationWant)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tests.NewTemplateParameterTestConfig())
|
||||
|
||||
@@ -25,41 +25,40 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/tests"
|
||||
)
|
||||
|
||||
var (
|
||||
MySQLSourceKind = "mysql"
|
||||
MySQLToolKind = "mysql-sql"
|
||||
MySQLDatabase = os.Getenv("MYSQL_DATABASE")
|
||||
MySQLHost = os.Getenv("MYSQL_HOST")
|
||||
MySQLPort = os.Getenv("MYSQL_PORT")
|
||||
MySQLUser = os.Getenv("MYSQL_USER")
|
||||
MySQLPass = os.Getenv("MYSQL_PASS")
|
||||
MYSQL_SOURCE_KIND = "mysql"
|
||||
MYSQL_TOOL_KIND = "mysql-sql"
|
||||
MYSQL_DATABASE = os.Getenv("MYSQL_DATABASE")
|
||||
MYSQL_HOST = os.Getenv("MYSQL_HOST")
|
||||
MYSQL_PORT = os.Getenv("MYSQL_PORT")
|
||||
MYSQL_USER = os.Getenv("MYSQL_USER")
|
||||
MYSQL_PASS = os.Getenv("MYSQL_PASS")
|
||||
)
|
||||
|
||||
func getMySQLVars(t *testing.T) map[string]any {
|
||||
switch "" {
|
||||
case MySQLDatabase:
|
||||
case MYSQL_DATABASE:
|
||||
t.Fatal("'MYSQL_DATABASE' not set")
|
||||
case MySQLHost:
|
||||
case MYSQL_HOST:
|
||||
t.Fatal("'MYSQL_HOST' not set")
|
||||
case MySQLPort:
|
||||
case MYSQL_PORT:
|
||||
t.Fatal("'MYSQL_PORT' not set")
|
||||
case MySQLUser:
|
||||
case MYSQL_USER:
|
||||
t.Fatal("'MYSQL_USER' not set")
|
||||
case MySQLPass:
|
||||
case MYSQL_PASS:
|
||||
t.Fatal("'MYSQL_PASS' not set")
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"kind": MySQLSourceKind,
|
||||
"host": MySQLHost,
|
||||
"port": MySQLPort,
|
||||
"database": MySQLDatabase,
|
||||
"user": MySQLUser,
|
||||
"password": MySQLPass,
|
||||
"kind": MYSQL_SOURCE_KIND,
|
||||
"host": MYSQL_HOST,
|
||||
"port": MYSQL_PORT,
|
||||
"database": MYSQL_DATABASE,
|
||||
"user": MYSQL_USER,
|
||||
"password": MYSQL_PASS,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -75,14 +74,14 @@ func initMySQLConnectionPool(host, port, user, pass, dbname string) (*sql.DB, er
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
func TestMySQLToolEndpoints(t *testing.T) {
|
||||
func TestMysqlToolEndpoints(t *testing.T) {
|
||||
sourceConfig := getMySQLVars(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
|
||||
var args []string
|
||||
|
||||
pool, err := initMySQLConnectionPool(MySQLHost, MySQLPort, MySQLUser, MySQLPass, MySQLDatabase)
|
||||
pool, err := initMySQLConnectionPool(MYSQL_HOST, MYSQL_PORT, MYSQL_USER, MYSQL_PASS, MYSQL_DATABASE)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create MySQL connection pool: %s", err)
|
||||
}
|
||||
@@ -93,20 +92,20 @@ func TestMySQLToolEndpoints(t *testing.T) {
|
||||
tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
|
||||
// set up data for param tool
|
||||
createStatement1, insertStatement1, paramToolStatement1, paramToolStatement2, params1 := tests.GetMySQLParamToolInfo(tableNameParam)
|
||||
teardownTable1 := tests.SetupMySQLTable(t, ctx, pool, createStatement1, insertStatement1, tableNameParam, params1)
|
||||
create_statement1, insert_statement1, tool_statement1, params1 := tests.GetMysqlParamToolInfo(tableNameParam)
|
||||
teardownTable1 := tests.SetupMySQLTable(t, ctx, pool, create_statement1, insert_statement1, tableNameParam, params1)
|
||||
defer teardownTable1(t)
|
||||
|
||||
// set up data for auth tool
|
||||
createStatement2, insertStatement2, authToolStatement, params2 := tests.GetMySQLAuthToolInfo(tableNameAuth)
|
||||
teardownTable2 := tests.SetupMySQLTable(t, ctx, pool, createStatement2, insertStatement2, tableNameAuth, params2)
|
||||
create_statement2, insert_statement2, tool_statement2, params2 := tests.GetMysqlAuthToolInfo(tableNameAuth)
|
||||
teardownTable2 := tests.SetupMySQLTable(t, ctx, pool, create_statement2, insert_statement2, tableNameAuth, params2)
|
||||
defer teardownTable2(t)
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, MySQLToolKind, paramToolStatement1, paramToolStatement2, authToolStatement)
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, MYSQL_TOOL_KIND, tool_statement1, tool_statement2)
|
||||
toolsFile = tests.AddMySqlExecuteSqlConfig(t, toolsFile)
|
||||
tmplSelectCombined, tmplSelectFilterCombined := tests.GetMySQLTmplToolStatement()
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, MySQLToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||
tmplSelectCombined, tmplSelectFilterCombined := tests.GetMysqlTmplToolStatement()
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, MYSQL_TOOL_KIND, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
@@ -116,7 +115,7 @@ func TestMySQLToolEndpoints(t *testing.T) {
|
||||
|
||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
|
||||
if err != nil {
|
||||
t.Logf("toolbox command logs: \n%s", out)
|
||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||
@@ -124,9 +123,9 @@ func TestMySQLToolEndpoints(t *testing.T) {
|
||||
|
||||
tests.RunToolGetTest(t)
|
||||
|
||||
select1Want, failInvocationWant, createTableStatement := tests.GetMySQLWants()
|
||||
invokeParamWant, invokeParamWantNull, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeParamWantNull)
|
||||
select1Want, failInvocationWant, createTableStatement := tests.GetMysqlWants()
|
||||
invokeParamWant, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant)
|
||||
tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want)
|
||||
tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failInvocationWant)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tests.NewTemplateParameterTestConfig())
|
||||
|
||||
@@ -26,36 +26,36 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/tests"
|
||||
)
|
||||
|
||||
var (
|
||||
Neo4jSourceKind = "neo4j"
|
||||
Neo4jDatabase = os.Getenv("NEO4J_DATABASE")
|
||||
Neo4jUri = os.Getenv("NEO4J_URI")
|
||||
Neo4jUser = os.Getenv("NEO4J_USER")
|
||||
Neo4jPass = os.Getenv("NEO4J_PASS")
|
||||
NEO4J_SOURCE_KIND = "neo4j"
|
||||
NEO4J_TOOL_KIND = "neo4j-cypher"
|
||||
NEO4J_DATABASE = os.Getenv("NEO4J_DATABASE")
|
||||
NEO4J_URI = os.Getenv("NEO4J_URI")
|
||||
NEO4J_USER = os.Getenv("NEO4J_USER")
|
||||
NEO4J_PASS = os.Getenv("NEO4J_PASS")
|
||||
)
|
||||
|
||||
func getNeo4jVars(t *testing.T) map[string]any {
|
||||
switch "" {
|
||||
case Neo4jDatabase:
|
||||
case NEO4J_DATABASE:
|
||||
t.Fatal("'NEO4J_DATABASE' not set")
|
||||
case Neo4jUri:
|
||||
case NEO4J_URI:
|
||||
t.Fatal("'NEO4J_URI' not set")
|
||||
case Neo4jUser:
|
||||
case NEO4J_USER:
|
||||
t.Fatal("'NEO4J_USER' not set")
|
||||
case Neo4jPass:
|
||||
case NEO4J_PASS:
|
||||
t.Fatal("'NEO4J_PASS' not set")
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"kind": Neo4jSourceKind,
|
||||
"uri": Neo4jUri,
|
||||
"database": Neo4jDatabase,
|
||||
"user": Neo4jUser,
|
||||
"password": Neo4jPass,
|
||||
"kind": NEO4J_SOURCE_KIND,
|
||||
"uri": NEO4J_URI,
|
||||
"database": NEO4J_DATABASE,
|
||||
"user": NEO4J_USER,
|
||||
"password": NEO4J_PASS,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -88,7 +88,7 @@ func TestNeo4jToolEndpoints(t *testing.T) {
|
||||
|
||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
|
||||
if err != nil {
|
||||
t.Logf("toolbox command logs: \n%s", out)
|
||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||
|
||||
@@ -25,42 +25,41 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/tests"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
var (
|
||||
PostgresSourceKind = "postgres"
|
||||
PostgresToolKind = "postgres-sql"
|
||||
PostgresDatabase = os.Getenv("POSTGRES_DATABASE")
|
||||
PostgresHost = os.Getenv("POSTGRES_HOST")
|
||||
PostgresPort = os.Getenv("POSTGRES_PORT")
|
||||
PostgresUser = os.Getenv("POSTGRES_USER")
|
||||
PostgresPass = os.Getenv("POSTGRES_PASS")
|
||||
POSTGRES_SOURCE_KIND = "postgres"
|
||||
POSTGRES_TOOL_KIND = "postgres-sql"
|
||||
POSTGRES_DATABASE = os.Getenv("POSTGRES_DATABASE")
|
||||
POSTGRES_HOST = os.Getenv("POSTGRES_HOST")
|
||||
POSTGRES_PORT = os.Getenv("POSTGRES_PORT")
|
||||
POSTGRES_USER = os.Getenv("POSTGRES_USER")
|
||||
POSTGRES_PASS = os.Getenv("POSTGRES_PASS")
|
||||
)
|
||||
|
||||
func getPostgresVars(t *testing.T) map[string]any {
|
||||
switch "" {
|
||||
case PostgresDatabase:
|
||||
case POSTGRES_DATABASE:
|
||||
t.Fatal("'POSTGRES_DATABASE' not set")
|
||||
case PostgresHost:
|
||||
case POSTGRES_HOST:
|
||||
t.Fatal("'POSTGRES_HOST' not set")
|
||||
case PostgresPort:
|
||||
case POSTGRES_PORT:
|
||||
t.Fatal("'POSTGRES_PORT' not set")
|
||||
case PostgresUser:
|
||||
case POSTGRES_USER:
|
||||
t.Fatal("'POSTGRES_USER' not set")
|
||||
case PostgresPass:
|
||||
case POSTGRES_PASS:
|
||||
t.Fatal("'POSTGRES_PASS' not set")
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"kind": PostgresSourceKind,
|
||||
"host": PostgresHost,
|
||||
"port": PostgresPort,
|
||||
"database": PostgresDatabase,
|
||||
"user": PostgresUser,
|
||||
"password": PostgresPass,
|
||||
"kind": POSTGRES_SOURCE_KIND,
|
||||
"host": POSTGRES_HOST,
|
||||
"port": POSTGRES_PORT,
|
||||
"database": POSTGRES_DATABASE,
|
||||
"user": POSTGRES_USER,
|
||||
"password": POSTGRES_PASS,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -88,7 +87,7 @@ func TestPostgres(t *testing.T) {
|
||||
|
||||
var args []string
|
||||
|
||||
pool, err := initPostgresConnectionPool(PostgresHost, PostgresPort, PostgresUser, PostgresPass, PostgresDatabase)
|
||||
pool, err := initPostgresConnectionPool(POSTGRES_HOST, POSTGRES_PORT, POSTGRES_USER, POSTGRES_PASS, POSTGRES_DATABASE)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create postgres connection pool: %s", err)
|
||||
}
|
||||
@@ -99,20 +98,20 @@ func TestPostgres(t *testing.T) {
|
||||
tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
|
||||
// set up data for param tool
|
||||
createStatement1, insertStatement1, paramToolStatement1, paramToolStatement2, params1 := tests.GetPostgresSQLParamToolInfo(tableNameParam)
|
||||
teardownTable1 := tests.SetupPostgresSQLTable(t, ctx, pool, createStatement1, insertStatement1, tableNameParam, params1)
|
||||
create_statement1, insert_statement1, tool_statement1, params1 := tests.GetPostgresSQLParamToolInfo(tableNameParam)
|
||||
teardownTable1 := tests.SetupPostgresSQLTable(t, ctx, pool, create_statement1, insert_statement1, tableNameParam, params1)
|
||||
defer teardownTable1(t)
|
||||
|
||||
// set up data for auth tool
|
||||
createStatement2, insertStatement2, authToolStatement, params2 := tests.GetPostgresSQLAuthToolInfo(tableNameAuth)
|
||||
teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createStatement2, insertStatement2, tableNameAuth, params2)
|
||||
create_statement2, insert_statement2, tool_statement2, params2 := tests.GetPostgresSQLAuthToolInfo(tableNameAuth)
|
||||
teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, create_statement2, insert_statement2, tableNameAuth, params2)
|
||||
defer teardownTable2(t)
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, PostgresToolKind, paramToolStatement1, paramToolStatement2, authToolStatement)
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, POSTGRES_TOOL_KIND, tool_statement1, tool_statement2)
|
||||
toolsFile = tests.AddPgExecuteSqlConfig(t, toolsFile)
|
||||
tmplSelectCombined, tmplSelectFilterCombined := tests.GetPostgresSQLTmplToolStatement()
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, PostgresToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, POSTGRES_TOOL_KIND, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
@@ -122,7 +121,7 @@ func TestPostgres(t *testing.T) {
|
||||
|
||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
|
||||
if err != nil {
|
||||
t.Logf("toolbox command logs: \n%s", out)
|
||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||
@@ -131,8 +130,8 @@ func TestPostgres(t *testing.T) {
|
||||
tests.RunToolGetTest(t)
|
||||
|
||||
select1Want, failInvocationWant, createTableStatement := tests.GetPostgresWants()
|
||||
invokeParamWant, invokeParamWantNull, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeParamWantNull)
|
||||
invokeParamWant, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant)
|
||||
tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want)
|
||||
tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failInvocationWant)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tests.NewTemplateParameterTestConfig())
|
||||
|
||||
@@ -22,29 +22,28 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/tests"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
var (
|
||||
RedisSourceKind = "redis"
|
||||
RedisToolKind = "redis"
|
||||
RedisAddress = os.Getenv("REDIS_ADDRESS")
|
||||
RedisPass = os.Getenv("REDIS_PASS")
|
||||
REDIS_SOURCE_KIND = "redis"
|
||||
REDIS_TOOL_KIND = "redis"
|
||||
REDIS_ADDRESS = os.Getenv("REDIS_ADDRESS")
|
||||
REDIS_PASS = os.Getenv("REDIS_PASS")
|
||||
)
|
||||
|
||||
func getRedisVars(t *testing.T) map[string]any {
|
||||
switch "" {
|
||||
case RedisAddress:
|
||||
case REDIS_ADDRESS:
|
||||
t.Fatal("'REDIS_ADDRESS' not set")
|
||||
case RedisPass:
|
||||
case REDIS_PASS:
|
||||
t.Fatal("'REDIS_PASS' not set")
|
||||
}
|
||||
return map[string]any{
|
||||
"kind": RedisSourceKind,
|
||||
"address": []string{RedisAddress},
|
||||
"password": RedisPass,
|
||||
"kind": REDIS_SOURCE_KIND,
|
||||
"address": []string{REDIS_ADDRESS},
|
||||
"password": REDIS_PASS,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -71,7 +70,7 @@ func TestRedisToolEndpoints(t *testing.T) {
|
||||
|
||||
var args []string
|
||||
|
||||
client, err := initRedisClient(ctx, RedisAddress, RedisPass)
|
||||
client, err := initRedisClient(ctx, REDIS_ADDRESS, REDIS_PASS)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create Redis connection: %s", err)
|
||||
}
|
||||
@@ -81,7 +80,7 @@ func TestRedisToolEndpoints(t *testing.T) {
|
||||
defer teardownDB(t)
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := tests.GetRedisValkeyToolsConfig(sourceConfig, RedisToolKind)
|
||||
toolsFile := tests.GetRedisValkeyToolsConfig(sourceConfig, REDIS_TOOL_KIND)
|
||||
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
@@ -91,7 +90,7 @@ func TestRedisToolEndpoints(t *testing.T) {
|
||||
|
||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
|
||||
if err != nil {
|
||||
t.Logf("toolbox command logs: \n%s", out)
|
||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||
@@ -99,19 +98,18 @@ func TestRedisToolEndpoints(t *testing.T) {
|
||||
|
||||
tests.RunToolGetTest(t)
|
||||
|
||||
select1Want, failInvocationWant, invokeParamWant, invokeParamWantNull, mcpInvokeParamWant := tests.GetRedisValkeyWants()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeParamWantNull)
|
||||
select1Want, failInvocationWant, invokeParamWant, mcpInvokeParamWant := tests.GetRedisValkeyWants()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant)
|
||||
tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failInvocationWant)
|
||||
}
|
||||
|
||||
func setupRedisDB(t *testing.T, ctx context.Context, client *redis.Client) func(*testing.T) {
|
||||
keys := []string{"row1", "row2", "row3", "row4"}
|
||||
keys := []string{"row1", "row2", "row3"}
|
||||
commands := [][]any{
|
||||
{"HSET", keys[0], "id", 1, "name", "Alice"},
|
||||
{"HSET", keys[1], "id", 2, "name", "Jane"},
|
||||
{"HSET", keys[2], "id", 3, "name", "Sid"},
|
||||
{"HSET", keys[3], "id", 4, "name", nil},
|
||||
{"HSET", tests.ServiceAccountEmail, "name", "Alice"},
|
||||
{"HSET", tests.SERVICE_ACCOUNT_EMAIL, "name", "Alice"},
|
||||
}
|
||||
for _, c := range commands {
|
||||
resp := client.Do(ctx, c...)
|
||||
|
||||
@@ -15,10 +15,13 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
|
||||
@@ -130,3 +133,63 @@ func (c *CmdExec) Close() {
|
||||
c.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// WaitForString waits until the server logs a single line that matches the provided regex.
|
||||
// returns the output of whatever the server sent so far.
|
||||
func (c *CmdExec) WaitForString(ctx context.Context, re *regexp.Regexp) (string, error) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
in := bufio.NewReader(c.Out)
|
||||
|
||||
// read lines in background, sending result of each read over a channel
|
||||
// this allows us to use in.ReadString without blocking
|
||||
type result struct {
|
||||
s string
|
||||
err error
|
||||
}
|
||||
output := make(chan result)
|
||||
go func() {
|
||||
defer close(output)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// if the context is canceled, the orig thread will send back the error
|
||||
// so we can just exit the goroutine here
|
||||
return
|
||||
default:
|
||||
// otherwise read a line from the output
|
||||
s, err := in.ReadString('\n')
|
||||
if err != nil {
|
||||
output <- result{err: err}
|
||||
return
|
||||
}
|
||||
output <- result{s: s}
|
||||
// if that last string matched, exit the goroutine
|
||||
if re.MatchString(s) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// collect the output until the ctx is canceled, an error was hit,
|
||||
// or match was found (which is indicated the channel is closed)
|
||||
var sb strings.Builder
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// if ctx is done, return that error
|
||||
return sb.String(), ctx.Err()
|
||||
case o, ok := <-output:
|
||||
if !ok {
|
||||
// match was found!
|
||||
return sb.String(), nil
|
||||
}
|
||||
if o.err != nil {
|
||||
// error was found!
|
||||
return sb.String(), o.err
|
||||
}
|
||||
sb.WriteString(o.s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,7 +26,6 @@ import (
|
||||
"time"
|
||||
|
||||
"cloud.google.com/go/cloudsqlconn"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
|
||||
// RunSourceConnection test for source connection
|
||||
@@ -58,7 +57,7 @@ func RunSourceConnectionTest(t *testing.T, sourceConfig map[string]any, toolKind
|
||||
|
||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
|
||||
if err != nil {
|
||||
t.Logf("toolbox command logs: \n%s", out)
|
||||
return fmt.Errorf("toolbox didn't start successfully: %s", err)
|
||||
|
||||
@@ -31,34 +31,33 @@ import (
|
||||
database "cloud.google.com/go/spanner/admin/database/apiv1"
|
||||
"cloud.google.com/go/spanner/admin/database/apiv1/databasepb"
|
||||
"github.com/google/uuid"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/tests"
|
||||
)
|
||||
|
||||
var (
|
||||
SpannerSourceKind = "spanner"
|
||||
SpannerToolKind = "spanner-sql"
|
||||
SpannerProject = os.Getenv("SPANNER_PROJECT")
|
||||
SpannerDatabase = os.Getenv("SPANNER_DATABASE")
|
||||
SpannerInstance = os.Getenv("SPANNER_INSTANCE")
|
||||
SPANNER_SOURCE_KIND = "spanner"
|
||||
SPANNER_TOOL_KIND = "spanner-sql"
|
||||
SPANNER_PROJECT = os.Getenv("SPANNER_PROJECT")
|
||||
SPANNER_DATABASE = os.Getenv("SPANNER_DATABASE")
|
||||
SPANNER_INSTANCE = os.Getenv("SPANNER_INSTANCE")
|
||||
)
|
||||
|
||||
func getSpannerVars(t *testing.T) map[string]any {
|
||||
switch "" {
|
||||
case SpannerProject:
|
||||
case SPANNER_PROJECT:
|
||||
t.Fatal("'SPANNER_PROJECT' not set")
|
||||
case SpannerDatabase:
|
||||
case SPANNER_DATABASE:
|
||||
t.Fatal("'SPANNER_DATABASE' not set")
|
||||
case SpannerInstance:
|
||||
case SPANNER_INSTANCE:
|
||||
t.Fatal("'SPANNER_INSTANCE' not set")
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"kind": SpannerSourceKind,
|
||||
"project": SpannerProject,
|
||||
"instance": SpannerInstance,
|
||||
"database": SpannerDatabase,
|
||||
"kind": SPANNER_SOURCE_KIND,
|
||||
"project": SPANNER_PROJECT,
|
||||
"instance": SPANNER_INSTANCE,
|
||||
"database": SPANNER_DATABASE,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -91,13 +90,13 @@ func initSpannerClients(ctx context.Context, project, instance, dbname string) (
|
||||
|
||||
func TestSpannerToolEndpoints(t *testing.T) {
|
||||
sourceConfig := getSpannerVars(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
var args []string
|
||||
|
||||
// Create Spanner client
|
||||
dataClient, adminClient, err := initSpannerClients(ctx, SpannerProject, SpannerInstance, SpannerDatabase)
|
||||
dataClient, adminClient, err := initSpannerClients(ctx, SPANNER_PROJECT, SPANNER_INSTANCE, SPANNER_DATABASE)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create Spanner client: %s", err)
|
||||
}
|
||||
@@ -108,19 +107,19 @@ func TestSpannerToolEndpoints(t *testing.T) {
|
||||
tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
|
||||
// set up data for param tool
|
||||
createStatement1, insertStatement1, paramToolStatement1, paramToolStatement2, params1 := getSpannerParamToolInfo(tableNameParam)
|
||||
create_statement1, insert_statement1, tool_statement1, params1 := getSpannerParamToolInfo(tableNameParam)
|
||||
dbString := fmt.Sprintf(
|
||||
"projects/%s/instances/%s/databases/%s",
|
||||
SpannerProject,
|
||||
SpannerInstance,
|
||||
SpannerDatabase,
|
||||
SPANNER_PROJECT,
|
||||
SPANNER_INSTANCE,
|
||||
SPANNER_DATABASE,
|
||||
)
|
||||
teardownTable1 := setupSpannerTable(t, ctx, adminClient, dataClient, createStatement1, insertStatement1, tableNameParam, dbString, params1)
|
||||
teardownTable1 := setupSpannerTable(t, ctx, adminClient, dataClient, create_statement1, insert_statement1, tableNameParam, dbString, params1)
|
||||
defer teardownTable1(t)
|
||||
|
||||
// set up data for auth tool
|
||||
createStatement2, insertStatement2, authToolStatement, params2 := getSpannerAuthToolInfo(tableNameAuth)
|
||||
teardownTable2 := setupSpannerTable(t, ctx, adminClient, dataClient, createStatement2, insertStatement2, tableNameAuth, dbString, params2)
|
||||
create_statement2, insert_statement2, tool_statement2, params2 := getSpannerAuthToolInfo(tableNameAuth)
|
||||
teardownTable2 := setupSpannerTable(t, ctx, adminClient, dataClient, create_statement2, insert_statement2, tableNameAuth, dbString, params2)
|
||||
defer teardownTable2(t)
|
||||
|
||||
// set up data for template param tool
|
||||
@@ -129,7 +128,7 @@ func TestSpannerToolEndpoints(t *testing.T) {
|
||||
defer teardownTableTmpl(t)
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, SpannerToolKind, paramToolStatement1, paramToolStatement2, authToolStatement)
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, SPANNER_TOOL_KIND, tool_statement1, tool_statement2)
|
||||
toolsFile = addSpannerExecuteSqlConfig(t, toolsFile)
|
||||
toolsFile = addSpannerReadOnlyConfig(t, toolsFile)
|
||||
toolsFile = addTemplateParamConfig(t, toolsFile)
|
||||
@@ -142,7 +141,7 @@ func TestSpannerToolEndpoints(t *testing.T) {
|
||||
|
||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
|
||||
if err != nil {
|
||||
t.Logf("toolbox command logs: \n%s", out)
|
||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||
@@ -153,11 +152,10 @@ func TestSpannerToolEndpoints(t *testing.T) {
|
||||
select1Want := "[{\"\":\"1\"}]"
|
||||
accessSchemaWant := "[{\"schema_name\":\"INFORMATION_SCHEMA\"}]"
|
||||
invokeParamWant := "[{\"id\":\"1\",\"name\":\"Alice\"},{\"id\":\"3\",\"name\":\"Sid\"}]"
|
||||
invokeParamWantNull := `[{"id":"4","name":null}]`
|
||||
mcpInvokeParamWant := `{"jsonrpc":"2.0","id":"my-param-tool","result":{"content":[{"type":"text","text":"{\"id\":\"1\",\"name\":\"Alice\"}"},{"type":"text","text":"{\"id\":\"3\",\"name\":\"Sid\"}"}]}}`
|
||||
failInvocationWant := `"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute client: unable to parse row: spanner: code = \"InvalidArgument\", desc = \"Syntax error: Unexpected identifier \\\\\\\"SELEC\\\\\\\" [at 1:1]\\\\nSELEC 1;\\\\n^\"`
|
||||
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeParamWantNull)
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant)
|
||||
tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failInvocationWant)
|
||||
runSpannerSchemaToolInvokeTest(t, accessSchemaWant)
|
||||
runSpannerExecuteSqlToolInvokeTest(t, select1Want, invokeParamWant, tableNameParam, tableNameAuth)
|
||||
@@ -171,27 +169,26 @@ func TestSpannerToolEndpoints(t *testing.T) {
|
||||
}
|
||||
|
||||
// getSpannerToolInfo returns statements and param for my-param-tool for spanner-sql kind
|
||||
func getSpannerParamToolInfo(tableName string) (string, string, string, string, map[string]any) {
|
||||
createStatement := fmt.Sprintf("CREATE TABLE %s (id INT64, name STRING(MAX)) PRIMARY KEY (id)", tableName)
|
||||
insertStatement := fmt.Sprintf("INSERT INTO %s (id, name) VALUES (1, @name1), (2, @name2), (3, @name3), (4, @name4)", tableName)
|
||||
toolStatement := fmt.Sprintf("SELECT * FROM %s WHERE id = @id OR name = @name", tableName)
|
||||
toolStatement2 := fmt.Sprintf("SELECT * FROM %s WHERE id = @id", tableName)
|
||||
params := map[string]any{"name1": "Alice", "name2": "Jane", "name3": "Sid", "name4": nil}
|
||||
return createStatement, insertStatement, toolStatement, toolStatement2, params
|
||||
func getSpannerParamToolInfo(tableName string) (string, string, string, map[string]any) {
|
||||
create_statement := fmt.Sprintf("CREATE TABLE %s (id INT64, name STRING(MAX)) PRIMARY KEY (id)", tableName)
|
||||
insert_statement := fmt.Sprintf("INSERT INTO %s (id, name) VALUES (1, @name1), (2, @name2), (3, @name3)", tableName)
|
||||
tool_statement := fmt.Sprintf("SELECT * FROM %s WHERE id = @id OR name = @name", tableName)
|
||||
params := map[string]any{"name1": "Alice", "name2": "Jane", "name3": "Sid"}
|
||||
return create_statement, insert_statement, tool_statement, params
|
||||
}
|
||||
|
||||
// getSpannerAuthToolInfo returns statements and param of my-auth-tool for spanner-sql kind
|
||||
func getSpannerAuthToolInfo(tableName string) (string, string, string, map[string]any) {
|
||||
createStatement := fmt.Sprintf("CREATE TABLE %s (id INT64, name STRING(MAX), email STRING(MAX)) PRIMARY KEY (id)", tableName)
|
||||
insertStatement := fmt.Sprintf("INSERT INTO %s (id, name, email) VALUES (1, @name1, @email1), (2, @name2, @email2)", tableName)
|
||||
toolStatement := fmt.Sprintf("SELECT name FROM %s WHERE email = @email", tableName)
|
||||
create_statement := fmt.Sprintf("CREATE TABLE %s (id INT64, name STRING(MAX), email STRING(MAX)) PRIMARY KEY (id)", tableName)
|
||||
insert_statement := fmt.Sprintf("INSERT INTO %s (id, name, email) VALUES (1, @name1, @email1), (2, @name2, @email2)", tableName)
|
||||
tool_statement := fmt.Sprintf("SELECT name FROM %s WHERE email = @email", tableName)
|
||||
params := map[string]any{
|
||||
"name1": "Alice",
|
||||
"email1": tests.ServiceAccountEmail,
|
||||
"email1": tests.SERVICE_ACCOUNT_EMAIL,
|
||||
"name2": "Jane",
|
||||
"email2": "janedoe@gmail.com",
|
||||
}
|
||||
return createStatement, insertStatement, toolStatement, params
|
||||
return create_statement, insert_statement, tool_statement, params
|
||||
}
|
||||
|
||||
// setupSpannerTable creates and inserts data into a table of tool
|
||||
@@ -233,13 +230,13 @@ func setupSpannerTable(t *testing.T, ctx context.Context, adminClient *database.
|
||||
Statements: []string{fmt.Sprintf("DROP TABLE %s", tableName)},
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("unable to start drop %s operation: %s", tableName, err)
|
||||
t.Errorf("unable to start drop table operation: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
opErr := op.Wait(ctx)
|
||||
if opErr != nil {
|
||||
t.Errorf("Teardown failed: %s", opErr)
|
||||
err = op.Wait(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("Teardown failed: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -355,7 +352,7 @@ func addTemplateParamConfig(t *testing.T, config map[string]any) map[string]any
|
||||
return config
|
||||
}
|
||||
|
||||
func runSpannerExecuteSqlToolInvokeTest(t *testing.T, select1Want, invokeParamWant, tableNameParam, tableNameAuth string) {
|
||||
func runSpannerExecuteSqlToolInvokeTest(t *testing.T, select_1_want, invokeParamWant, tableNameParam, tableNameAuth string) {
|
||||
// Get ID token
|
||||
idToken, err := tests.GetGoogleIdToken(tests.ClientId)
|
||||
if err != nil {
|
||||
@@ -376,7 +373,7 @@ func runSpannerExecuteSqlToolInvokeTest(t *testing.T, select1Want, invokeParamWa
|
||||
api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool-read-only/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"sql":"SELECT 1"}`)),
|
||||
want: select1Want,
|
||||
want: select_1_want,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
@@ -420,7 +417,7 @@ func runSpannerExecuteSqlToolInvokeTest(t *testing.T, select1Want, invokeParamWa
|
||||
api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"sql":"SELECT 1"}`)),
|
||||
want: select1Want,
|
||||
want: select_1_want,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
@@ -441,7 +438,7 @@ func runSpannerExecuteSqlToolInvokeTest(t *testing.T, select1Want, invokeParamWa
|
||||
name: "invoke my-exec-sql-tool insert entry",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"sql\":\"INSERT INTO %s (id, name) VALUES (5, 'test_name')\"}", tableNameParam))),
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"sql\":\"INSERT INTO %s (id, name) VALUES (4, 'test_name')\"}", tableNameParam))),
|
||||
want: "null",
|
||||
isErr: false,
|
||||
},
|
||||
@@ -458,7 +455,7 @@ func runSpannerExecuteSqlToolInvokeTest(t *testing.T, select1Want, invokeParamWa
|
||||
requestHeader: map[string]string{"my-google-auth_token": idToken},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"sql":"SELECT 1"}`)),
|
||||
isErr: false,
|
||||
want: select1Want,
|
||||
want: select_1_want,
|
||||
},
|
||||
{
|
||||
name: "Invoke my-auth-exec-sql-tool with invalid auth token",
|
||||
|
||||
@@ -25,20 +25,19 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/tests"
|
||||
)
|
||||
|
||||
var (
|
||||
SQLiteSourceKind = "sqlite"
|
||||
SQLiteToolKind = "sqlite-sql"
|
||||
SQLiteDatabase = os.Getenv("SQLITE_DATABASE")
|
||||
SQLITE_SOURCE_KIND = "sqlite"
|
||||
SQLITE_TOOL_KIND = "sqlite-sql"
|
||||
SQLITE_DATABASE = os.Getenv("SQLITE_DATABASE")
|
||||
)
|
||||
|
||||
func getSQLiteVars(t *testing.T) map[string]any {
|
||||
return map[string]any{
|
||||
"kind": SQLiteSourceKind,
|
||||
"database": SQLiteDatabase,
|
||||
"kind": SQLITE_SOURCE_KIND,
|
||||
"database": SQLITE_DATABASE,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -81,21 +80,20 @@ func setupSQLiteTestDB(t *testing.T, ctx context.Context, db *sql.DB, createStat
|
||||
}
|
||||
}
|
||||
|
||||
func getSQLiteParamToolInfo(tableName string) (string, string, string, string, []any) {
|
||||
createStatement := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (id INTEGER PRIMARY KEY, name TEXT);", tableName)
|
||||
insertStatement := fmt.Sprintf("INSERT INTO %s (name) VALUES (?), (?), (?), (?);", tableName)
|
||||
toolStatement := fmt.Sprintf("SELECT * FROM %s WHERE id = ? OR name = ?;", tableName)
|
||||
toolStatement2 := fmt.Sprintf("SELECT * FROM %s WHERE id = ?;", tableName)
|
||||
params := []any{"Alice", "Jane", "Sid", nil}
|
||||
return createStatement, insertStatement, toolStatement, toolStatement2, params
|
||||
func getSQLiteParamToolInfo(tableName string) (string, string, string, []any) {
|
||||
create_statement := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (id INTEGER PRIMARY KEY, name TEXT);", tableName)
|
||||
insert_statement := fmt.Sprintf("INSERT INTO %s (name) VALUES (?), (?), (?);", tableName)
|
||||
tool_statement := fmt.Sprintf("SELECT * FROM %s WHERE id = ? OR name = ?;", tableName)
|
||||
params := []any{"Alice", "Jane", "Sid"}
|
||||
return create_statement, insert_statement, tool_statement, params
|
||||
}
|
||||
|
||||
func getSQLiteAuthToolInfo(tableName string) (string, string, string, []any) {
|
||||
createStatement := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (id INTEGER PRIMARY KEY, name TEXT NOT NULL, email TEXT)", tableName)
|
||||
insertStatement := fmt.Sprintf("INSERT INTO %s (name, email) VALUES (?, ?), (?,?) RETURNING id, name, email;", tableName)
|
||||
toolStatement := fmt.Sprintf("SELECT name FROM %s WHERE email = ?;", tableName)
|
||||
params := []any{"Alice", tests.ServiceAccountEmail, "Jane", "janedoe@gmail.com"}
|
||||
return createStatement, insertStatement, toolStatement, params
|
||||
create_statement := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (id INTEGER PRIMARY KEY, name TEXT NOT NULL, email TEXT)", tableName)
|
||||
insert_statement := fmt.Sprintf("INSERT INTO %s (name, email) VALUES (?, ?), (?,?) RETURNING id, name, email;", tableName)
|
||||
tool_statement := fmt.Sprintf("SELECT name FROM %s WHERE email = ?;", tableName)
|
||||
params := []any{"Alice", tests.SERVICE_ACCOUNT_EMAIL, "Jane", "janedoe@gmail.com"}
|
||||
return create_statement, insert_statement, tool_statement, params
|
||||
}
|
||||
|
||||
func getSQLiteTmplToolStatement() (string, string) {
|
||||
@@ -105,7 +103,7 @@ func getSQLiteTmplToolStatement() (string, string) {
|
||||
}
|
||||
|
||||
func TestSQLiteToolEndpoint(t *testing.T) {
|
||||
db, teardownDb, sqliteDb, err := initSQLiteDb(t, SQLiteDatabase)
|
||||
db, teardownDb, sqliteDb, err := initSQLiteDb(t, SQLITE_DATABASE)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -125,17 +123,17 @@ func TestSQLiteToolEndpoint(t *testing.T) {
|
||||
tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
|
||||
// set up data for param tool
|
||||
createStatement1, insertStatement1, paramToolStatement1, paramToolStatement2, params1 := getSQLiteParamToolInfo(tableNameParam)
|
||||
setupSQLiteTestDB(t, ctx, db, createStatement1, insertStatement1, tableNameParam, params1)
|
||||
create_statement1, insert_statement1, tool_statement1, params1 := getSQLiteParamToolInfo(tableNameParam)
|
||||
setupSQLiteTestDB(t, ctx, db, create_statement1, insert_statement1, tableNameParam, params1)
|
||||
|
||||
// set up data for auth tool
|
||||
createStatement2, insertStatement2, authToolStatement, params2 := getSQLiteAuthToolInfo(tableNameAuth)
|
||||
setupSQLiteTestDB(t, ctx, db, createStatement2, insertStatement2, tableNameAuth, params2)
|
||||
create_statement2, insert_statement2, tool_statement2, params2 := getSQLiteAuthToolInfo(tableNameAuth)
|
||||
setupSQLiteTestDB(t, ctx, db, create_statement2, insert_statement2, tableNameAuth, params2)
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, SQLiteToolKind, paramToolStatement1, paramToolStatement2, authToolStatement)
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, SQLITE_TOOL_KIND, tool_statement1, tool_statement2)
|
||||
tmplSelectCombined, tmplSelectFilterCombined := getSQLiteTmplToolStatement()
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, SQLiteToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, SQLITE_TOOL_KIND, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
@@ -145,7 +143,7 @@ func TestSQLiteToolEndpoint(t *testing.T) {
|
||||
|
||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
|
||||
if err != nil {
|
||||
t.Logf("toolbox command logs: \n%s", out)
|
||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||
@@ -155,8 +153,8 @@ func TestSQLiteToolEndpoint(t *testing.T) {
|
||||
|
||||
select1Want := "[{\"1\":1}]"
|
||||
failInvocationWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: SQL logic error: near \"SELEC\": syntax error (1)"}],"isError":true}}`
|
||||
invokeParamWant, invokeParamWantNull, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeParamWantNull)
|
||||
invokeParamWant, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant)
|
||||
tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failInvocationWant)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tests.NewTemplateParameterTestConfig())
|
||||
}
|
||||
|
||||
@@ -76,7 +76,7 @@ func RunToolGetTest(t *testing.T) {
|
||||
}
|
||||
|
||||
// RunToolInvoke runs the tool invoke endpoint
|
||||
func RunToolInvokeTest(t *testing.T, select1Want, invokeParamWant, invokeParamWantNull string) {
|
||||
func RunToolInvokeTest(t *testing.T, select1Want, invokeParamWant string) {
|
||||
// Get ID token
|
||||
idToken, err := GetGoogleIdToken(ClientId)
|
||||
if err != nil {
|
||||
@@ -109,23 +109,15 @@ func RunToolInvokeTest(t *testing.T, select1Want, invokeParamWant, invokeParamWa
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "invoke my-param-tool2 with nil response",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-param-tool2/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"id": 4}`)),
|
||||
want: invokeParamWantNull,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invoke my-param-tool without parameters",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-param-tool/invoke",
|
||||
name: "Invoke my-tool without parameters",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
isErr: true,
|
||||
},
|
||||
{
|
||||
name: "Invoke my-param-tool with insufficient parameters",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-param-tool/invoke",
|
||||
name: "Invoke my-tool with insufficient parameters",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"id": 1}`)),
|
||||
isErr: true,
|
||||
@@ -436,7 +428,7 @@ func RunToolInvokeWithTemplateParameters(t *testing.T, tableName string, config
|
||||
}
|
||||
}
|
||||
|
||||
func RunExecuteSqlToolInvokeTest(t *testing.T, createTableStatement string, select1Want string) {
|
||||
func RunExecuteSqlToolInvokeTest(t *testing.T, createTableStatement string, select_1_want string) {
|
||||
// Get ID token
|
||||
idToken, err := GetGoogleIdToken(ClientId)
|
||||
if err != nil {
|
||||
@@ -457,7 +449,7 @@ func RunExecuteSqlToolInvokeTest(t *testing.T, createTableStatement string, sele
|
||||
api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"sql":"SELECT 1"}`)),
|
||||
want: select1Want,
|
||||
want: select_1_want,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
@@ -497,7 +489,7 @@ func RunExecuteSqlToolInvokeTest(t *testing.T, createTableStatement string, sele
|
||||
requestHeader: map[string]string{"my-google-auth_token": idToken},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"sql":"SELECT 1"}`)),
|
||||
isErr: false,
|
||||
want: select1Want,
|
||||
want: select_1_want,
|
||||
},
|
||||
{
|
||||
name: "Invoke my-auth-exec-sql-tool with invalid auth token",
|
||||
@@ -605,7 +597,7 @@ func RunInitialize(t *testing.T, protocolVersion string) string {
|
||||
}
|
||||
|
||||
// RunMCPToolCallMethod runs the tool/call for mcp endpoint
|
||||
func RunMCPToolCallMethod(t *testing.T, invokeParamWant, failInvocationWant string) {
|
||||
func RunMCPToolCallMethod(t *testing.T, invokeParamWant, fail_invocation_want string) {
|
||||
sessionId := RunInitialize(t, "2024-11-05")
|
||||
header := map[string]string{}
|
||||
if sessionId != "" {
|
||||
@@ -723,7 +715,7 @@ func RunMCPToolCallMethod(t *testing.T, invokeParamWant, failInvocationWant stri
|
||||
"arguments": map[string]any{"id": 1},
|
||||
},
|
||||
},
|
||||
want: failInvocationWant,
|
||||
want: fail_invocation_want,
|
||||
},
|
||||
}
|
||||
for _, tc := range invokeTcs {
|
||||
|
||||
@@ -22,25 +22,24 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/tests"
|
||||
"github.com/valkey-io/valkey-go"
|
||||
)
|
||||
|
||||
var (
|
||||
ValkeySourceKind = "valkey"
|
||||
ValkeyToolKind = "valkey"
|
||||
ValkeyAddress = os.Getenv("VALKEY_ADDRESS")
|
||||
VALKEY_SOURCE_KIND = "valkey"
|
||||
VALKEY_TOOL_KIND = "valkey"
|
||||
VALKEY_ADDRESS = os.Getenv("VALKEY_ADDRESS")
|
||||
)
|
||||
|
||||
func getValkeyVars(t *testing.T) map[string]any {
|
||||
switch "" {
|
||||
case ValkeyAddress:
|
||||
case VALKEY_ADDRESS:
|
||||
t.Fatal("'VALKEY_ADDRESS' not set")
|
||||
}
|
||||
return map[string]any{
|
||||
"kind": ValkeySourceKind,
|
||||
"address": []string{ValkeyAddress},
|
||||
"kind": VALKEY_SOURCE_KIND,
|
||||
"address": []string{VALKEY_ADDRESS},
|
||||
"disableCache": true,
|
||||
}
|
||||
}
|
||||
@@ -74,7 +73,7 @@ func TestValkeyToolEndpoints(t *testing.T) {
|
||||
|
||||
var args []string
|
||||
|
||||
client, err := initValkeyClient(ctx, []string{ValkeyAddress})
|
||||
client, err := initValkeyClient(ctx, []string{VALKEY_ADDRESS})
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create Valkey connection: %s", err)
|
||||
}
|
||||
@@ -84,7 +83,7 @@ func TestValkeyToolEndpoints(t *testing.T) {
|
||||
defer teardownDB(t)
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := tests.GetRedisValkeyToolsConfig(sourceConfig, ValkeyToolKind)
|
||||
toolsFile := tests.GetRedisValkeyToolsConfig(sourceConfig, VALKEY_TOOL_KIND)
|
||||
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
@@ -94,7 +93,7 @@ func TestValkeyToolEndpoints(t *testing.T) {
|
||||
|
||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
|
||||
if err != nil {
|
||||
t.Logf("toolbox command logs: \n%s", out)
|
||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||
@@ -102,19 +101,18 @@ func TestValkeyToolEndpoints(t *testing.T) {
|
||||
|
||||
tests.RunToolGetTest(t)
|
||||
|
||||
select1Want, failInvocationWant, invokeParamWant, invokeParamWantNull, mcpInvokeParamWant := tests.GetRedisValkeyWants()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeParamWantNull)
|
||||
select1Want, failInvocationWant, invokeParamWant, mcpInvokeParamWant := tests.GetRedisValkeyWants()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant)
|
||||
tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failInvocationWant)
|
||||
}
|
||||
|
||||
func setupValkeyDB(t *testing.T, ctx context.Context, client valkey.Client) func(*testing.T) {
|
||||
keys := []string{"row1", "row2", "row3", "row4"}
|
||||
keys := []string{"row1", "row2", "row3"}
|
||||
commands := [][]string{
|
||||
{"HSET", keys[0], "name", "Alice", "id", "1"},
|
||||
{"HSET", keys[1], "name", "Jane", "id", "2"},
|
||||
{"HSET", keys[2], "name", "Sid", "id", "3"},
|
||||
{"HSET", keys[3], "name", "", "id", "4"},
|
||||
{"HSET", tests.ServiceAccountEmail, "name", "Alice"},
|
||||
{"HSET", tests.SERVICE_ACCOUNT_EMAIL, "name", "Alice"},
|
||||
}
|
||||
builtCmds := make(valkey.Commands, len(commands))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user