Compare commits

..

1 Commits

Author SHA1 Message Date
Yuan Teoh
4416b14ee6 fix: optional parameter int type conversion 2025-07-01 11:27:23 -06:00
37 changed files with 604 additions and 831 deletions

View File

@@ -51,7 +51,6 @@ jobs:
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
with:
ref: ${{ github.event.pull_request.head.sha }}
fetch-depth: 0 # Fetch all history for .GitInfo and .Lastmod
- name: Setup Hugo

View File

@@ -1,37 +1,5 @@
# Changelog
## [0.8.0](https://github.com/googleapis/genai-toolbox/compare/v0.7.0...v0.8.0) (2025-07-02)
### ⚠ BREAKING CHANGES
* **postgres,mssql,cloudsqlmssql:** encode source connection url for sources ([#727](https://github.com/googleapis/genai-toolbox/issues/727))
### Features
* Add support for multiple YAML configuration files ([#760](https://github.com/googleapis/genai-toolbox/issues/760)) ([40679d7](https://github.com/googleapis/genai-toolbox/commit/40679d700eded50d19569923e2a71c51e907a8bf))
* Add support for optional parameters ([#617](https://github.com/googleapis/genai-toolbox/issues/617)) ([4827771](https://github.com/googleapis/genai-toolbox/commit/4827771b78dee9a1284a898b749509b472061527)), closes [#475](https://github.com/googleapis/genai-toolbox/issues/475)
* **mcp:** Support MCP version 2025-03-26 ([#755](https://github.com/googleapis/genai-toolbox/issues/755)) ([474df57](https://github.com/googleapis/genai-toolbox/commit/474df57d62de683079f8d12c31db53396a545fd1))
* **sources/http:** Support disable SSL verification for HTTP Source ([#674](https://github.com/googleapis/genai-toolbox/issues/674)) ([4055b0c](https://github.com/googleapis/genai-toolbox/commit/4055b0c3569c527560d7ad34262963b3dd4e282d))
* **tools/bigquery:** Add templateParameters field for bigquery ([#699](https://github.com/googleapis/genai-toolbox/issues/699)) ([f5f771b](https://github.com/googleapis/genai-toolbox/commit/f5f771b0f3d159630ff602ff55c6c66b61981446))
* **tools/bigtable:** Add templateParameters field for bigtable ([#692](https://github.com/googleapis/genai-toolbox/issues/692)) ([1c06771](https://github.com/googleapis/genai-toolbox/commit/1c067715fac06479eb0060d7067b73dba099ed92))
* **tools/couchbase:** Add templateParameters field for couchbase ([#723](https://github.com/googleapis/genai-toolbox/issues/723)) ([9197186](https://github.com/googleapis/genai-toolbox/commit/9197186b8bea1ac4ec1b39c9c5c110807c8b2ba9))
* **tools/http:** Add support for HTTP Tool pathParams ([#726](https://github.com/googleapis/genai-toolbox/issues/726)) ([fd300dc](https://github.com/googleapis/genai-toolbox/commit/fd300dc606d88bf9f7bba689e2cee4e3565537dd))
* **tools/redis:** Add Redis Source and Tool ([#519](https://github.com/googleapis/genai-toolbox/issues/519)) ([f0aef29](https://github.com/googleapis/genai-toolbox/commit/f0aef29b0c2563e2a00277fbe2784f39f16d2835))
* **tools/spanner:** Add templateParameters field for spanner ([#691](https://github.com/googleapis/genai-toolbox/issues/691)) ([075dfa4](https://github.com/googleapis/genai-toolbox/commit/075dfa47e1fd92be4847bd0aec63296146b66455))
* **tools/sqlitesql:** Add templateParameters field for sqlitesql ([#687](https://github.com/googleapis/genai-toolbox/issues/687)) ([75e254c](https://github.com/googleapis/genai-toolbox/commit/75e254c0a4ce690ca5fa4d1741550ce54734b226))
* **tools/valkey:** Add Valkey Source and Tool ([#532](https://github.com/googleapis/genai-toolbox/issues/532)) ([054ec19](https://github.com/googleapis/genai-toolbox/commit/054ec198b97ba9f36f67dd12b2eff0cc6bc4d080))
### Bug Fixes
* **bigquery,mssql:** Fix panic on tools with array param ([#722](https://github.com/googleapis/genai-toolbox/issues/722)) ([7a6644c](https://github.com/googleapis/genai-toolbox/commit/7a6644cf0c5413e5c803955c88a2cfd0a2233ed3))
* **postgres,mssql,cloudsqlmssql:** Encode source connection url for sources ([#727](https://github.com/googleapis/genai-toolbox/issues/727)) ([67964d9](https://github.com/googleapis/genai-toolbox/commit/67964d939f27320b63b5759f4b3f3fdaa0c76fbf)), closes [#717](https://github.com/googleapis/genai-toolbox/issues/717)
* Set default value to field's type during unmarshalling ([#774](https://github.com/googleapis/genai-toolbox/issues/774)) ([fafed24](https://github.com/googleapis/genai-toolbox/commit/fafed2485839cf1acc1350e8a24103d2e6356ee0)), closes [#771](https://github.com/googleapis/genai-toolbox/issues/771)
* **server/mcp:** Do not listen from port for stdio ([#719](https://github.com/googleapis/genai-toolbox/issues/719)) ([d51dbc7](https://github.com/googleapis/genai-toolbox/commit/d51dbc759ba493021d3ec6f5417fc04c21f7044f)), closes [#711](https://github.com/googleapis/genai-toolbox/issues/711)
* **tools/mysqlexecutesql:** Handle nil panic and connection leak in Invoke ([#757](https://github.com/googleapis/genai-toolbox/issues/757)) ([7badba4](https://github.com/googleapis/genai-toolbox/commit/7badba42eefb34252be77b852a57d6bd78dd267d))
* **tools/mysqlsql:** Handle nil panic and connection leak in invoke ([#758](https://github.com/googleapis/genai-toolbox/issues/758)) ([cbb4a33](https://github.com/googleapis/genai-toolbox/commit/cbb4a333517313744800d148840312e56340f3fd))
## [0.7.0](https://github.com/googleapis/genai-toolbox/compare/v0.6.0...v0.7.0) (2025-06-10)

159
README.md
View File

@@ -111,7 +111,7 @@ To install Toolbox as a binary:
<!-- {x-release-please-start-version} -->
```sh
# see releases page for other versions
export VERSION=0.8.0
export VERSION=0.7.0
curl -O https://storage.googleapis.com/genai-toolbox/v$VERSION/linux/amd64/toolbox
chmod +x toolbox
```
@@ -124,7 +124,7 @@ You can also install Toolbox as a container:
```sh
# see releases page for other versions
export VERSION=0.8.0
export VERSION=0.7.0
docker pull us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:$VERSION
```
@@ -137,7 +137,7 @@ To install from source, ensure you have the latest version of
[Go installed](https://go.dev/doc/install), and then run the following command:
```sh
go install github.com/googleapis/genai-toolbox@v0.8.0
go install github.com/googleapis/genai-toolbox@v0.7.0
```
<!-- {x-release-please-end} -->
@@ -165,12 +165,7 @@ Once your server is up and running, you can load the tools into your
application. See below the list of Client SDKs for using various frameworks:
<details open>
<summary>Python</summary>
<br>
<blockquote>
<details open>
<summary>Core</summary>
<summary>Core</summary>
1. Install [Toolbox Core SDK][toolbox-core]:
@@ -196,9 +191,9 @@ For more detailed instructions on using the Toolbox Core SDK, see the
[toolbox-core]: https://pypi.org/project/toolbox-core/
[toolbox-core-readme]: https://github.com/googleapis/mcp-toolbox-sdk-python/tree/main/packages/toolbox-core/README.md
</details>
<details>
<summary>LangChain / LangGraph</summary>
</details>
<details>
<summary>LangChain / LangGraph</summary>
1. Install [Toolbox LangChain SDK][toolbox-langchain]:
@@ -218,15 +213,16 @@ For more detailed instructions on using the Toolbox Core SDK, see the
tools = client.load_toolset()
```
For more detailed instructions on using the Toolbox LangChain SDK, see the
[project's README][toolbox-langchain-readme].
For more detailed instructions on using the Toolbox LangChain SDK, see the
[project's README][toolbox-langchain-readme].
[toolbox-langchain]: https://pypi.org/project/toolbox-langchain/
[toolbox-langchain-readme]: https://github.com/googleapis/mcp-toolbox-sdk-python/blob/main/packages/toolbox-langchain/README.md
[toolbox-langchain]: https://pypi.org/project/toolbox-langchain/
[toolbox-langchain-readme]: https://github.com/googleapis/mcp-toolbox-sdk-python/blob/main/packages/toolbox-langchain/README.md
</details>
<details>
<summary>LlamaIndex</summary>
</details>
<details>
<summary>LlamaIndex</summary>
1. Install [Toolbox Llamaindex SDK][toolbox-llamaindex]:
@@ -246,129 +242,12 @@ For more detailed instructions on using the Toolbox Core SDK, see the
tools = client.load_toolset()
```
For more detailed instructions on using the Toolbox Llamaindex SDK, see the
[project's README][toolbox-llamaindex-readme].
For more detailed instructions on using the Toolbox Llamaindex SDK, see the
[project's README][toolbox-llamaindex-readme].
[toolbox-llamaindex]: https://pypi.org/project/toolbox-llamaindex/
[toolbox-llamaindex-readme]: https://github.com/googleapis/genai-toolbox-llamaindex-python/blob/main/README.md
[toolbox-llamaindex]: https://pypi.org/project/toolbox-llamaindex/
[toolbox-llamaindex-readme]: https://github.com/googleapis/genai-toolbox-llamaindex-python/blob/main/README.md
</details>
</details>
</blockquote>
<details>
<summary>Javascript/Typescript</summary>
<br>
<blockquote>
<details open>
<summary>Core</summary>
1. Install [Toolbox Core SDK][toolbox-core-js]:
```bash
npm install @toolbox-sdk/core
```
1. Load tools:
```javascript
import { ToolboxClient } from '@toolbox-sdk/core';
// update the url to point to your server
const URL = 'http://127.0.0.1:5000';
let client = new ToolboxClient(URL);
// these tools can be passed to your application!
const tools = await client.loadToolset('toolsetName');
```
For more detailed instructions on using the Toolbox Core SDK, see the
[project's README][toolbox-core-js-readme].
[toolbox-core-js]: https://www.npmjs.com/package/@toolbox-sdk/core
[toolbox-core-js-readme]: https://github.com/googleapis/mcp-toolbox-sdk-js/blob/main/packages/toolbox-core/README.md
</details>
<details>
<summary>LangChain / LangGraph</summary>
1. Install [Toolbox Core SDK][toolbox-core-js]:
```bash
npm install @toolbox-sdk/core
```
2. Load tools:
```javascript
import { ToolboxClient } from '@toolbox-sdk/core';
// update the url to point to your server
const URL = 'http://127.0.0.1:5000';
let client = new ToolboxClient(URL);
// these tools can be passed to your application!
const toolboxTools = await client.loadToolset('toolsetName');
// Define the basics of the tool: name, description, schema and core logic
const getTool = (toolboxTool) => tool(currTool, {
name: toolboxTool.getName(),
description: toolboxTool.getDescription(),
schema: toolboxTool.getParamSchema()
});
// Use these tools in your Langchain/Langraph applications
const tools = toolboxTools.map(getTool);
```
</details>
<details>
<summary>Genkit</summary>
1. Install [Toolbox Core SDK][toolbox-core-js]:
```bash
npm install @toolbox-sdk/core
```
2. Load tools:
```javascript
import { ToolboxClient } from '@toolbox-sdk/core';
import { genkit } from 'genkit';
// Initialise genkit
const ai = genkit({
plugins: [
googleAI({
apiKey: process.env.GEMINI_API_KEY || process.env.GOOGLE_API_KEY
})
],
model: googleAI.model('gemini-2.0-flash'),
});
// update the url to point to your server
const URL = 'http://127.0.0.1:5000';
let client = new ToolboxClient(URL);
// these tools can be passed to your application!
const toolboxTools = await client.loadToolset('toolsetName');
// Define the basics of the tool: name, description, schema and core logic
const getTool = (toolboxTool) => ai.defineTool({
name: toolboxTool.getName(),
description: toolboxTool.getDescription(),
schema: toolboxTool.getParamSchema()
}, toolboxTool)
// Use these tools in your Genkit applications
const tools = toolboxTools.map(getTool);
```
</details>
</details>
</blockquote>
</details>
## Configuration

View File

@@ -1 +1 @@
0.8.0
0.7.0

View File

@@ -222,7 +222,7 @@
},
"outputs": [],
"source": [
"version = \"0.8.0\" # x-release-please-version\n",
"version = \"0.7.0\" # x-release-please-version\n",
"! curl -O https://storage.googleapis.com/genai-toolbox/v{version}/linux/amd64/toolbox\n",
"\n",
"# Make the binary executable\n",

View File

@@ -86,7 +86,7 @@ To install Toolbox as a binary:
```sh
# see releases page for other versions
export VERSION=0.8.0
export VERSION=0.7.0
curl -O https://storage.googleapis.com/genai-toolbox/v$VERSION/linux/amd64/toolbox
chmod +x toolbox
```
@@ -97,7 +97,7 @@ You can also install Toolbox as a container:
```sh
# see releases page for other versions
export VERSION=0.8.0
export VERSION=0.7.0
docker pull us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:$VERSION
```
@@ -108,7 +108,7 @@ To install from source, ensure you have the latest version of
[Go installed](https://go.dev/doc/install), and then run the following command:
```sh
go install github.com/googleapis/genai-toolbox@v0.8.0
go install github.com/googleapis/genai-toolbox@v0.7.0
```
{{% /tab %}}

View File

@@ -156,7 +156,7 @@ In this section, we will download Toolbox, configure our tools in a
<!-- {x-release-please-start-version} -->
```bash
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
curl -O https://storage.googleapis.com/genai-toolbox/v0.8.0/$OS/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.7.0/$OS/toolbox
```
<!-- {x-release-please-end} -->

View File

@@ -105,7 +105,7 @@ In this section, we will download Toolbox, configure our tools in a
<!-- {x-release-please-start-version} -->
```bash
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
curl -O https://storage.googleapis.com/genai-toolbox/v0.8.0/$OS/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.7.0/$OS/toolbox
```
<!-- {x-release-please-end} -->

View File

@@ -52,19 +52,19 @@ Omni](https://cloud.google.com/alloydb/omni/current/docs/overview).
<!-- {x-release-please-start-version} -->
{{< tabpane persist=header >}}
{{< tab header="linux/amd64" lang="bash" >}}
curl -O <https://storage.googleapis.com/genai-toolbox/v0.8.0/linux/amd64/toolbox>
curl -O <https://storage.googleapis.com/genai-toolbox/v0.7.0/linux/amd64/toolbox>
{{< /tab >}}
{{< tab header="darwin/arm64" lang="bash" >}}
curl -O <https://storage.googleapis.com/genai-toolbox/v0.8.0/darwin/arm64/toolbox>
curl -O <https://storage.googleapis.com/genai-toolbox/v0.7.0/darwin/arm64/toolbox>
{{< /tab >}}
{{< tab header="darwin/amd64" lang="bash" >}}
curl -O <https://storage.googleapis.com/genai-toolbox/v0.8.0/darwin/amd64/toolbox>
curl -O <https://storage.googleapis.com/genai-toolbox/v0.7.0/darwin/amd64/toolbox>
{{< /tab >}}
{{< tab header="windows/amd64" lang="bash" >}}
curl -O <https://storage.googleapis.com/genai-toolbox/v0.8.0/windows/amd64/toolbox>
curl -O <https://storage.googleapis.com/genai-toolbox/v0.7.0/windows/amd64/toolbox>
{{< /tab >}}
{{< /tabpane >}}
<!-- {x-release-please-end} -->

View File

@@ -220,7 +220,7 @@
},
"outputs": [],
"source": [
"version = \"0.8.0\" # x-release-please-version\n",
"version = \"0.7.0\" # x-release-please-version\n",
"! curl -O https://storage.googleapis.com/genai-toolbox/v{version}/linux/amd64/toolbox\n",
"\n",
"# Make the binary executable\n",

View File

@@ -179,7 +179,7 @@ to use BigQuery, and then run the Toolbox server.
<!-- {x-release-please-start-version} -->
```bash
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
curl -O https://storage.googleapis.com/genai-toolbox/v0.8.0/$OS/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.7.0/$OS/toolbox
```
<!-- {x-release-please-end} -->

View File

@@ -98,7 +98,7 @@ In this section, we will download Toolbox, configure our tools in a
<!-- {x-release-please-start-version} -->
```bash
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
curl -O https://storage.googleapis.com/genai-toolbox/v0.8.0/$OS/toolbox
curl -O https://storage.googleapis.com/genai-toolbox/v0.7.0/$OS/toolbox
```
<!-- {x-release-please-end} -->

4
go.mod
View File

@@ -7,7 +7,7 @@ toolchain go1.24.4
require (
cloud.google.com/go/alloydbconn v1.15.3
cloud.google.com/go/bigquery v1.69.0
cloud.google.com/go/bigtable v1.38.0
cloud.google.com/go/bigtable v1.37.0
cloud.google.com/go/cloudsqlconn v1.17.2
cloud.google.com/go/spanner v1.83.0
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.53.0
@@ -17,7 +17,7 @@ require (
github.com/go-chi/chi/v5 v5.2.2
github.com/go-chi/httplog/v2 v2.1.1
github.com/go-chi/render v1.0.3
github.com/go-playground/validator/v10 v10.27.0
github.com/go-playground/validator/v10 v10.26.0
github.com/go-sql-driver/mysql v1.9.3
github.com/goccy/go-yaml v1.18.0
github.com/google/go-cmp v0.7.0

8
go.sum
View File

@@ -139,8 +139,8 @@ cloud.google.com/go/bigquery v1.49.0/go.mod h1:Sv8hMmTFFYBlt/ftw2uN6dFdQPzBlREY9
cloud.google.com/go/bigquery v1.50.0/go.mod h1:YrleYEh2pSEbgTBZYMJ5SuSr0ML3ypjRB1zgf7pvQLU=
cloud.google.com/go/bigquery v1.69.0 h1:rZvHnjSUs5sHK3F9awiuFk2PeOaB8suqNuim21GbaTc=
cloud.google.com/go/bigquery v1.69.0/go.mod h1:TdGLquA3h/mGg+McX+GsqG9afAzTAcldMjqhdjHTLew=
cloud.google.com/go/bigtable v1.38.0 h1:L/PnUXRtAzFfa7qMULJHt4cXa/O2dqPJEkzYNGA4hfo=
cloud.google.com/go/bigtable v1.38.0/go.mod h1:o/lntJarF3Y5C0XYLMJLjLYwxaRbcrtM0BiV57ymXbI=
cloud.google.com/go/bigtable v1.37.0 h1:Q+x7y04lQ0B+WXp03wc1/FLhFt4CwcQdkwWT0M4Jp3w=
cloud.google.com/go/bigtable v1.37.0/go.mod h1:HXqddP6hduwzrtiTCqZPpj9ij4hGZb4Zy1WF/dT+yaU=
cloud.google.com/go/billing v1.4.0/go.mod h1:g9IdKBEFlItS8bTtlrZdVLWSSdSyFUZKXNS02zKMOZY=
cloud.google.com/go/billing v1.5.0/go.mod h1:mztb1tBc3QekhjSgmpf/CV4LzWXLzCArwpLmP2Gm88s=
cloud.google.com/go/billing v1.6.0/go.mod h1:WoXzguj+BeHXPbKfNWkqVtDdzORazmCjraY+vrxcyvI=
@@ -801,8 +801,8 @@ github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/o
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.27.0 h1:w8+XrWVMhGkxOaaowyKH35gFydVHOvC0/uWoy2Fzwn4=
github.com/go-playground/validator/v10 v10.27.0/go.mod h1:I5QpIEbmr8On7W0TktmJAumgzX4CA1XNl4ZmDuVHKKo=
github.com/go-playground/validator/v10 v10.26.0 h1:SP05Nqhjcvz81uJaRfEV0YBSSSGMc/iMaVtFbr3Sw2k=
github.com/go-playground/validator/v10 v10.26.0/go.mod h1:I5QpIEbmr8On7W0TktmJAumgzX4CA1XNl4ZmDuVHKKo=
github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo=
github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU=
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=

View File

@@ -319,6 +319,13 @@ func parseParamFromDelayedUnmarshaler(ctx context.Context, u *util.DelayedUnmars
if err := dec.DecodeContext(ctx, a); err != nil {
return nil, fmt.Errorf("unable to parse as %q: %w", t, err)
}
if a.Default != nil {
intD, ok := a.Default.(uint64)
if !ok {
return nil, fmt.Errorf("default value fail to assert to type int64")
}
a.Default = int(intD)
}
if a.AuthSources != nil {
logger.WarnContext(ctx, "`authSources` is deprecated, use `authServices` for parameters instead")
a.AuthServices = append(a.AuthServices, a.AuthSources...)
@@ -411,6 +418,7 @@ type ParameterMcpManifest struct {
type CommonParameter struct {
Name string `yaml:"name" validate:"required"`
Type string `yaml:"type" validate:"required"`
Default any `yaml:"default"`
Desc string `yaml:"description" validate:"required"`
AuthServices []ParamAuthService `yaml:"authServices"`
AuthSources []ParamAuthService `yaml:"authSources"` // Deprecated: Kept for compatibility.
@@ -426,6 +434,30 @@ func (p *CommonParameter) GetType() string {
return p.Type
}
func (p *CommonParameter) GetDefault() any {
return p.Default
}
// Manifest returns the manifest for the Parameter.
func (p *CommonParameter) Manifest() ParameterManifest {
// only list ParamAuthService names (without fields) in manifest
authNames := make([]string, len(p.AuthServices))
for i, a := range p.AuthServices {
authNames[i] = a.Name
}
var required bool
if p.Default == nil {
required = true
}
return ParameterManifest{
Name: p.Name,
Type: p.Type,
Required: required,
Description: p.Desc,
AuthServices: authNames,
}
}
// McpManifest returns the MCP manifest for the Parameter.
func (p *CommonParameter) McpManifest() ParameterMcpManifest {
return ParameterMcpManifest{
@@ -468,10 +500,10 @@ func NewStringParameterWithDefault(name string, defaultV, desc string) *StringPa
CommonParameter: CommonParameter{
Name: name,
Type: typeString,
Default: defaultV,
Desc: desc,
AuthServices: nil,
},
Default: &defaultV,
}
}
@@ -492,7 +524,6 @@ var _ Parameter = &StringParameter{}
// StringParameter is a parameter representing the "string" type.
type StringParameter struct {
CommonParameter `yaml:",inline"`
Default *string `yaml:"default"`
}
// Parse casts the value "v" as a "string".
@@ -503,35 +534,10 @@ func (p *StringParameter) Parse(v any) (any, error) {
}
return newV, nil
}
func (p *StringParameter) GetAuthServices() []ParamAuthService {
return p.AuthServices
}
func (p *StringParameter) GetDefault() any {
if p.Default == nil {
return nil
}
return *p.Default
}
// Manifest returns the manifest for the StringParameter.
func (p *StringParameter) Manifest() ParameterManifest {
// only list ParamAuthService names (without fields) in manifest
authNames := make([]string, len(p.AuthServices))
for i, a := range p.AuthServices {
authNames[i] = a.Name
}
required := p.Default == nil
return ParameterManifest{
Name: p.Name,
Type: p.Type,
Required: required,
Description: p.Desc,
AuthServices: authNames,
}
}
// NewIntParameter is a convenience function for initializing a IntParameter.
func NewIntParameter(name string, desc string) *IntParameter {
return &IntParameter{
@@ -545,15 +551,15 @@ func NewIntParameter(name string, desc string) *IntParameter {
}
// NewIntParameterWithDefault is a convenience function for initializing a IntParameter with default value.
func NewIntParameterWithDefault(name string, defaultV int, desc string) *IntParameter {
func NewIntParameterWithDefault(name string, defaultV any, desc string) *IntParameter {
return &IntParameter{
CommonParameter: CommonParameter{
Name: name,
Type: typeInt,
Default: defaultV,
Desc: desc,
AuthServices: nil,
},
Default: &defaultV,
}
}
@@ -574,7 +580,6 @@ var _ Parameter = &IntParameter{}
// IntParameter is a parameter representing the "int" type.
type IntParameter struct {
CommonParameter `yaml:",inline"`
Default *int `yaml:"default"`
}
func (p *IntParameter) Parse(v any) (any, error) {
@@ -602,30 +607,6 @@ func (p *IntParameter) GetAuthServices() []ParamAuthService {
return p.AuthServices
}
func (p *IntParameter) GetDefault() any {
if p.Default == nil {
return nil
}
return *p.Default
}
// Manifest returns the manifest for the IntParameter.
func (p *IntParameter) Manifest() ParameterManifest {
// only list ParamAuthService names (without fields) in manifest
authNames := make([]string, len(p.AuthServices))
for i, a := range p.AuthServices {
authNames[i] = a.Name
}
required := p.Default == nil
return ParameterManifest{
Name: p.Name,
Type: p.Type,
Required: required,
Description: p.Desc,
AuthServices: authNames,
}
}
// NewFloatParameter is a convenience function for initializing a FloatParameter.
func NewFloatParameter(name string, desc string) *FloatParameter {
return &FloatParameter{
@@ -644,10 +625,10 @@ func NewFloatParameterWithDefault(name string, defaultV float64, desc string) *F
CommonParameter: CommonParameter{
Name: name,
Type: typeFloat,
Default: defaultV,
Desc: desc,
AuthServices: nil,
},
Default: &defaultV,
}
}
@@ -668,7 +649,6 @@ var _ Parameter = &FloatParameter{}
// FloatParameter is a parameter representing the "float" type.
type FloatParameter struct {
CommonParameter `yaml:",inline"`
Default *float64 `yaml:"default"`
}
func (p *FloatParameter) Parse(v any) (any, error) {
@@ -694,30 +674,6 @@ func (p *FloatParameter) GetAuthServices() []ParamAuthService {
return p.AuthServices
}
func (p *FloatParameter) GetDefault() any {
if p.Default == nil {
return nil
}
return *p.Default
}
// Manifest returns the manifest for the FloatParameter.
func (p *FloatParameter) Manifest() ParameterManifest {
// only list ParamAuthService names (without fields) in manifest
authNames := make([]string, len(p.AuthServices))
for i, a := range p.AuthServices {
authNames[i] = a.Name
}
required := p.Default == nil
return ParameterManifest{
Name: p.Name,
Type: p.Type,
Required: required,
Description: p.Desc,
AuthServices: authNames,
}
}
// NewBooleanParameter is a convenience function for initializing a BooleanParameter.
func NewBooleanParameter(name string, desc string) *BooleanParameter {
return &BooleanParameter{
@@ -736,10 +692,10 @@ func NewBooleanParameterWithDefault(name string, defaultV bool, desc string) *Bo
CommonParameter: CommonParameter{
Name: name,
Type: typeBool,
Default: defaultV,
Desc: desc,
AuthServices: nil,
},
Default: &defaultV,
}
}
@@ -760,7 +716,6 @@ var _ Parameter = &BooleanParameter{}
// BooleanParameter is a parameter representing the "boolean" type.
type BooleanParameter struct {
CommonParameter `yaml:",inline"`
Default *bool `yaml:"default"`
}
func (p *BooleanParameter) Parse(v any) (any, error) {
@@ -775,30 +730,6 @@ func (p *BooleanParameter) GetAuthServices() []ParamAuthService {
return p.AuthServices
}
func (p *BooleanParameter) GetDefault() any {
if p.Default == nil {
return nil
}
return *p.Default
}
// Manifest returns the manifest for the BooleanParameter.
func (p *BooleanParameter) Manifest() ParameterManifest {
// only list ParamAuthService names (without fields) in manifest
authNames := make([]string, len(p.AuthServices))
for i, a := range p.AuthServices {
authNames[i] = a.Name
}
required := p.Default == nil
return ParameterManifest{
Name: p.Name,
Type: p.Type,
Required: required,
Description: p.Desc,
AuthServices: authNames,
}
}
// NewArrayParameter is a convenience function for initializing a ArrayParameter.
func NewArrayParameter(name string, desc string, items Parameter) *ArrayParameter {
return &ArrayParameter{
@@ -813,16 +744,16 @@ func NewArrayParameter(name string, desc string, items Parameter) *ArrayParamete
}
// NewArrayParameterWithDefault is a convenience function for initializing a ArrayParameter with default value.
func NewArrayParameterWithDefault(name string, defaultV []any, desc string, items Parameter) *ArrayParameter {
func NewArrayParameterWithDefault(name string, defaultV any, desc string, items Parameter) *ArrayParameter {
return &ArrayParameter{
CommonParameter: CommonParameter{
Name: name,
Type: typeArray,
Default: defaultV,
Desc: desc,
AuthServices: nil,
},
Items: items,
Default: &defaultV,
Items: items,
}
}
@@ -844,21 +775,18 @@ var _ Parameter = &ArrayParameter{}
// ArrayParameter is a parameter representing the "array" type.
type ArrayParameter struct {
CommonParameter `yaml:",inline"`
Default *[]any `yaml:"default"`
Items Parameter `yaml:"items"`
}
func (p *ArrayParameter) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
var rawItem struct {
CommonParameter `yaml:",inline"`
Default *[]any `yaml:"default"`
Items util.DelayedUnmarshaler `yaml:"items"`
}
if err := unmarshal(&rawItem); err != nil {
return err
}
p.CommonParameter = rawItem.CommonParameter
p.Default = rawItem.Default
i, err := parseParamFromDelayedUnmarshaler(ctx, &rawItem.Items)
if err != nil {
return fmt.Errorf("unable to parse 'items' field: %w", err)
@@ -891,13 +819,6 @@ func (p *ArrayParameter) GetAuthServices() []ParamAuthService {
return p.AuthServices
}
func (p *ArrayParameter) GetDefault() any {
if p.Default == nil {
return nil
}
return *p.Default
}
// Manifest returns the manifest for the ArrayParameter.
func (p *ArrayParameter) Manifest() ParameterManifest {
// only list ParamAuthService names (without fields) in manifest
@@ -906,8 +827,11 @@ func (p *ArrayParameter) Manifest() ParameterManifest {
authNames[i] = a.Name
}
items := p.Items.Manifest()
required := p.Default == nil
items.Required = required
var required bool
if p.Default == nil {
required = true
items.Required = true
}
return ParameterManifest{
Name: p.Name,
Type: p.Type,

View File

@@ -187,7 +187,7 @@ func TestParametersMarshal(t *testing.T) {
{
"name": "my_array",
"type": "array",
"default": []any{"foo", "bar"},
"default": `["foo", "bar"]`,
"description": "this param is an array of strings",
"items": map[string]string{
"name": "my_string",
@@ -197,7 +197,7 @@ func TestParametersMarshal(t *testing.T) {
},
},
want: tools.Parameters{
tools.NewArrayParameterWithDefault("my_array", []any{"foo", "bar"}, "this param is an array of strings", tools.NewStringParameter("my_string", "string item")),
tools.NewArrayParameterWithDefault("my_array", `["foo", "bar"]`, "this param is an array of strings", tools.NewStringParameter("my_string", "string item")),
},
},
{
@@ -206,7 +206,7 @@ func TestParametersMarshal(t *testing.T) {
{
"name": "my_array",
"type": "array",
"default": []any{1.0, 1.1},
"default": "[1.0, 1.1]",
"description": "this param is an array of floats",
"items": map[string]string{
"name": "my_float",
@@ -216,7 +216,7 @@ func TestParametersMarshal(t *testing.T) {
},
},
want: tools.Parameters{
tools.NewArrayParameterWithDefault("my_array", []any{1.0, 1.1}, "this param is an array of floats", tools.NewFloatParameter("my_float", "float item")),
tools.NewArrayParameterWithDefault("my_array", "[1.0, 1.1]", "this param is an array of floats", tools.NewFloatParameter("my_float", "float item")),
},
},
}
@@ -993,14 +993,14 @@ func TestParamManifest(t *testing.T) {
},
{
name: "array default",
in: tools.NewArrayParameterWithDefault("foo-array", []any{"foo", "bar"}, "bar", tools.NewStringParameter("foo-string", "bar")),
in: tools.NewArrayParameterWithDefault("foo-array", `["foo", "bar"]`, "bar", tools.NewStringParameter("foo-string", "bar")),
want: tools.ParameterManifest{
Name: "foo-array",
Type: "array",
Required: false,
Description: "bar",
AuthServices: []string{},
Items: &tools.ParameterManifest{Name: "foo-string", Type: "string", Required: false, Description: "bar", AuthServices: []string{}},
Items: &tools.ParameterManifest{Name: "foo-string", Type: "string", Required: true, Description: "bar", AuthServices: []string{}},
},
},
}
@@ -1073,7 +1073,7 @@ func TestMcpManifest(t *testing.T) {
tools.NewStringParameter("foo-string2", "bar"),
tools.NewIntParameterWithDefault("foo-int", 1, "bar"),
tools.NewIntParameter("foo-int2", "bar"),
tools.NewArrayParameterWithDefault("foo-array", []any{"hello", "world"}, "bar", tools.NewStringParameter("foo-string", "bar")),
tools.NewArrayParameterWithDefault("foo-array", []string{"hello", "world"}, "bar", tools.NewStringParameter("foo-string", "bar")),
tools.NewArrayParameter("foo-array2", "bar", tools.NewStringParameter("foo-string", "bar")),
},
want: tools.McpToolsSchema{

View File

@@ -32,55 +32,55 @@ import (
)
var (
AlloyDBAINLSourceKind = "alloydb-postgres"
AlloyDBAINLToolKind = "alloydb-ai-nl"
AlloyDBAINLProject = os.Getenv("ALLOYDB_AI_NL_PROJECT")
AlloyDBAINLRegion = os.Getenv("ALLOYDB_AI_NL_REGION")
AlloyDBAINLCluster = os.Getenv("ALLOYDB_AI_NL_CLUSTER")
AlloyDBAINLInstance = os.Getenv("ALLOYDB_AI_NL_INSTANCE")
AlloyDBAINLDatabase = os.Getenv("ALLOYDB_AI_NL_DATABASE")
AlloyDBAINLUser = os.Getenv("ALLOYDB_AI_NL_USER")
AlloyDBAINLPass = os.Getenv("ALLOYDB_AI_NL_PASS")
ALLOYDB_AI_NL_SOURCE_KIND = "alloydb-postgres"
ALLOYDB_AI_NL_TOOL_KIND = "alloydb-ai-nl"
ALLOYDB_AI_NL_PROJECT = os.Getenv("ALLOYDB_AI_NL_PROJECT")
ALLOYDB_AI_NL_REGION = os.Getenv("ALLOYDB_AI_NL_REGION")
ALLOYDB_AI_NL_CLUSTER = os.Getenv("ALLOYDB_AI_NL_CLUSTER")
ALLOYDB_AI_NL_INSTANCE = os.Getenv("ALLOYDB_AI_NL_INSTANCE")
ALLOYDB_AI_NL_DATABASE = os.Getenv("ALLOYDB_AI_NL_DATABASE")
ALLOYDB_AI_NL_USER = os.Getenv("ALLOYDB_AI_NL_USER")
ALLOYDB_AI_NL_PASS = os.Getenv("ALLOYDB_AI_NL_PASS")
)
func getAlloyDBAINLVars(t *testing.T) map[string]any {
func getAlloyDBAiNlVars(t *testing.T) map[string]any {
switch "" {
case AlloyDBAINLProject:
case ALLOYDB_AI_NL_PROJECT:
t.Fatal("'ALLOYDB_AI_NL_PROJECT' not set")
case AlloyDBAINLRegion:
case ALLOYDB_AI_NL_REGION:
t.Fatal("'ALLOYDB_AI_NL_REGION' not set")
case AlloyDBAINLCluster:
case ALLOYDB_AI_NL_CLUSTER:
t.Fatal("'ALLOYDB_AI_NL_CLUSTER' not set")
case AlloyDBAINLInstance:
case ALLOYDB_AI_NL_INSTANCE:
t.Fatal("'ALLOYDB_AI_NL_INSTANCE' not set")
case AlloyDBAINLDatabase:
case ALLOYDB_AI_NL_DATABASE:
t.Fatal("'ALLOYDB_AI_NL_DATABASE' not set")
case AlloyDBAINLUser:
case ALLOYDB_AI_NL_USER:
t.Fatal("'ALLOYDB_AI_NL_USER' not set")
case AlloyDBAINLPass:
case ALLOYDB_AI_NL_PASS:
t.Fatal("'ALLOYDB_AI_NL_PASS' not set")
}
return map[string]any{
"kind": AlloyDBAINLSourceKind,
"project": AlloyDBAINLProject,
"cluster": AlloyDBAINLCluster,
"instance": AlloyDBAINLInstance,
"region": AlloyDBAINLRegion,
"database": AlloyDBAINLDatabase,
"user": AlloyDBAINLUser,
"password": AlloyDBAINLPass,
"kind": ALLOYDB_AI_NL_SOURCE_KIND,
"project": ALLOYDB_AI_NL_PROJECT,
"cluster": ALLOYDB_AI_NL_CLUSTER,
"instance": ALLOYDB_AI_NL_INSTANCE,
"region": ALLOYDB_AI_NL_REGION,
"database": ALLOYDB_AI_NL_DATABASE,
"user": ALLOYDB_AI_NL_USER,
"password": ALLOYDB_AI_NL_PASS,
}
}
func TestAlloyDBAINLToolEndpoints(t *testing.T) {
sourceConfig := getAlloyDBAINLVars(t)
func TestAlloyDBAiNlToolEndpoints(t *testing.T) {
sourceConfig := getAlloyDBAiNlVars(t)
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
var args []string
// Write config into a file and pass it to command
toolsFile := getAINLToolsConfig(sourceConfig)
toolsFile := getAiNlToolsConfig(sourceConfig)
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
if err != nil {
@@ -96,12 +96,12 @@ func TestAlloyDBAINLToolEndpoints(t *testing.T) {
t.Fatalf("toolbox didn't start successfully: %s", err)
}
runAINLToolGetTest(t)
runAINLToolInvokeTest(t)
runAINLMCPToolCallMethod(t)
runAiNlToolGetTest(t)
runAiNlToolInvokeTest(t)
runAiNlMCPToolCallMethod(t)
}
func runAINLToolGetTest(t *testing.T) {
func runAiNlToolGetTest(t *testing.T) {
// Test tool get endpoint
tcs := []struct {
name string
@@ -156,7 +156,7 @@ func runAINLToolGetTest(t *testing.T) {
}
}
func runAINLToolInvokeTest(t *testing.T) {
func runAiNlToolInvokeTest(t *testing.T) {
// Get ID token
idToken, err := tests.GetGoogleIdToken(tests.ClientId)
if err != nil {
@@ -276,7 +276,7 @@ func runAINLToolInvokeTest(t *testing.T) {
}
func getAINLToolsConfig(sourceConfig map[string]any) map[string]any {
func getAiNlToolsConfig(sourceConfig map[string]any) map[string]any {
// Write config into a file and pass it to command
toolsFile := map[string]any{
"sources": map[string]any{
@@ -290,13 +290,13 @@ func getAINLToolsConfig(sourceConfig map[string]any) map[string]any {
},
"tools": map[string]any{
"my-simple-tool": map[string]any{
"kind": AlloyDBAINLToolKind,
"kind": ALLOYDB_AI_NL_TOOL_KIND,
"source": "my-instance",
"description": "Simple tool to test end to end functionality.",
"nlConfig": "my_nl_config",
},
"my-auth-tool": map[string]any{
"kind": AlloyDBAINLToolKind,
"kind": ALLOYDB_AI_NL_TOOL_KIND,
"source": "my-instance",
"description": "Tool to test authenticated parameters.",
"nlConfig": "my_nl_config",
@@ -315,7 +315,7 @@ func getAINLToolsConfig(sourceConfig map[string]any) map[string]any {
},
},
"my-auth-required-tool": map[string]any{
"kind": AlloyDBAINLToolKind,
"kind": ALLOYDB_AI_NL_TOOL_KIND,
"source": "my-instance",
"description": "Tool to test auth required invocation.",
"nlConfig": "my_nl_config",
@@ -329,7 +329,7 @@ func getAINLToolsConfig(sourceConfig map[string]any) map[string]any {
return toolsFile
}
func runAINLMCPToolCallMethod(t *testing.T) {
func runAiNlMCPToolCallMethod(t *testing.T) {
sessionId := tests.RunInitialize(t, "2024-11-05")
header := map[string]string{}
if sessionId != "" {

View File

@@ -31,60 +31,60 @@ import (
)
var (
AlloyDBPostgresSourceKind = "alloydb-postgres"
AlloyDBPostgresToolKind = "postgres-sql"
AlloyDBPostgresProject = os.Getenv("ALLOYDB_POSTGRES_PROJECT")
AlloyDBPostgresRegion = os.Getenv("ALLOYDB_POSTGRES_REGION")
AlloyDBPostgresCluster = os.Getenv("ALLOYDB_POSTGRES_CLUSTER")
AlloyDBPostgresInstance = os.Getenv("ALLOYDB_POSTGRES_INSTANCE")
AlloyDBPostgresDatabase = os.Getenv("ALLOYDB_POSTGRES_DATABASE")
AlloyDBPostgresUser = os.Getenv("ALLOYDB_POSTGRES_USER")
AlloyDBPostgresPass = os.Getenv("ALLOYDB_POSTGRES_PASS")
ALLOYDB_POSTGRES_SOURCE_KIND = "alloydb-postgres"
ALLOYDB_POSTGRES_TOOL_KIND = "postgres-sql"
ALLOYDB_POSTGRES_PROJECT = os.Getenv("ALLOYDB_POSTGRES_PROJECT")
ALLOYDB_POSTGRES_REGION = os.Getenv("ALLOYDB_POSTGRES_REGION")
ALLOYDB_POSTGRES_CLUSTER = os.Getenv("ALLOYDB_POSTGRES_CLUSTER")
ALLOYDB_POSTGRES_INSTANCE = os.Getenv("ALLOYDB_POSTGRES_INSTANCE")
ALLOYDB_POSTGRES_DATABASE = os.Getenv("ALLOYDB_POSTGRES_DATABASE")
ALLOYDB_POSTGRES_USER = os.Getenv("ALLOYDB_POSTGRES_USER")
ALLOYDB_POSTGRES_PASS = os.Getenv("ALLOYDB_POSTGRES_PASS")
)
func getAlloyDBPgVars(t *testing.T) map[string]any {
switch "" {
case AlloyDBPostgresProject:
case ALLOYDB_POSTGRES_PROJECT:
t.Fatal("'ALLOYDB_POSTGRES_PROJECT' not set")
case AlloyDBPostgresRegion:
case ALLOYDB_POSTGRES_REGION:
t.Fatal("'ALLOYDB_POSTGRES_REGION' not set")
case AlloyDBPostgresCluster:
case ALLOYDB_POSTGRES_CLUSTER:
t.Fatal("'ALLOYDB_POSTGRES_CLUSTER' not set")
case AlloyDBPostgresInstance:
case ALLOYDB_POSTGRES_INSTANCE:
t.Fatal("'ALLOYDB_POSTGRES_INSTANCE' not set")
case AlloyDBPostgresDatabase:
case ALLOYDB_POSTGRES_DATABASE:
t.Fatal("'ALLOYDB_POSTGRES_DATABASE' not set")
case AlloyDBPostgresUser:
case ALLOYDB_POSTGRES_USER:
t.Fatal("'ALLOYDB_POSTGRES_USER' not set")
case AlloyDBPostgresPass:
case ALLOYDB_POSTGRES_PASS:
t.Fatal("'ALLOYDB_POSTGRES_PASS' not set")
}
return map[string]any{
"kind": AlloyDBPostgresSourceKind,
"project": AlloyDBPostgresProject,
"cluster": AlloyDBPostgresCluster,
"instance": AlloyDBPostgresInstance,
"region": AlloyDBPostgresRegion,
"database": AlloyDBPostgresDatabase,
"user": AlloyDBPostgresUser,
"password": AlloyDBPostgresPass,
"kind": ALLOYDB_POSTGRES_SOURCE_KIND,
"project": ALLOYDB_POSTGRES_PROJECT,
"cluster": ALLOYDB_POSTGRES_CLUSTER,
"instance": ALLOYDB_POSTGRES_INSTANCE,
"region": ALLOYDB_POSTGRES_REGION,
"database": ALLOYDB_POSTGRES_DATABASE,
"user": ALLOYDB_POSTGRES_USER,
"password": ALLOYDB_POSTGRES_PASS,
}
}
// Copied over from alloydb_pg.go
func getAlloyDBDialOpts(ipType string) ([]alloydbconn.DialOption, error) {
switch strings.ToLower(ipType) {
func getAlloyDBDialOpts(ip_type string) ([]alloydbconn.DialOption, error) {
switch strings.ToLower(ip_type) {
case "private":
return []alloydbconn.DialOption{alloydbconn.WithPrivateIP()}, nil
case "public":
return []alloydbconn.DialOption{alloydbconn.WithPublicIP()}, nil
default:
return nil, fmt.Errorf("invalid ipType %s", ipType)
return nil, fmt.Errorf("invalid ip_type %s", ip_type)
}
}
// Copied over from alloydb_pg.go
func initAlloyDBPgConnectionPool(project, region, cluster, instance, ipType, user, pass, dbname string) (*pgxpool.Pool, error) {
func initAlloyDBPgConnectionPool(project, region, cluster, instance, ip_type, user, pass, dbname string) (*pgxpool.Pool, error) {
// Configure the driver to connect to the database
dsn := fmt.Sprintf("user=%s password=%s dbname=%s sslmode=disable", user, pass, dbname)
config, err := pgxpool.ParseConfig(dsn)
@@ -93,7 +93,7 @@ func initAlloyDBPgConnectionPool(project, region, cluster, instance, ipType, use
}
// Create a new dialer with options
dialOpts, err := getAlloyDBDialOpts(ipType)
dialOpts, err := getAlloyDBDialOpts(ip_type)
if err != nil {
return nil, err
}
@@ -123,7 +123,7 @@ func TestAlloyDBPgToolEndpoints(t *testing.T) {
var args []string
pool, err := initAlloyDBPgConnectionPool(AlloyDBPostgresProject, AlloyDBPostgresRegion, AlloyDBPostgresCluster, AlloyDBPostgresInstance, "public", AlloyDBPostgresUser, AlloyDBPostgresPass, AlloyDBPostgresDatabase)
pool, err := initAlloyDBPgConnectionPool(ALLOYDB_POSTGRES_PROJECT, ALLOYDB_POSTGRES_REGION, ALLOYDB_POSTGRES_CLUSTER, ALLOYDB_POSTGRES_INSTANCE, "public", ALLOYDB_POSTGRES_USER, ALLOYDB_POSTGRES_PASS, ALLOYDB_POSTGRES_DATABASE)
if err != nil {
t.Fatalf("unable to create AlloyDB connection pool: %s", err)
}
@@ -134,20 +134,20 @@ func TestAlloyDBPgToolEndpoints(t *testing.T) {
tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
// set up data for param tool
createStatement1, insertStatement1, toolStatement1, params1 := tests.GetPostgresSQLParamToolInfo(tableNameParam)
teardownTable1 := tests.SetupPostgresSQLTable(t, ctx, pool, createStatement1, insertStatement1, tableNameParam, params1)
create_statement1, insert_statement1, tool_statement1, params1 := tests.GetPostgresSQLParamToolInfo(tableNameParam)
teardownTable1 := tests.SetupPostgresSQLTable(t, ctx, pool, create_statement1, insert_statement1, tableNameParam, params1)
defer teardownTable1(t)
// set up data for auth tool
createStatement2, insertStatement2, toolStatement2, params2 := tests.GetPostgresSQLAuthToolInfo(tableNameAuth)
teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createStatement2, insertStatement2, tableNameAuth, params2)
create_statement2, insert_statement2, tool_statement2, params2 := tests.GetPostgresSQLAuthToolInfo(tableNameAuth)
teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, create_statement2, insert_statement2, tableNameAuth, params2)
defer teardownTable2(t)
// Write config into a file and pass it to command
toolsFile := tests.GetToolsConfig(sourceConfig, AlloyDBPostgresToolKind, toolStatement1, toolStatement2)
toolsFile := tests.GetToolsConfig(sourceConfig, ALLOYDB_POSTGRES_TOOL_KIND, tool_statement1, tool_statement2)
toolsFile = tests.AddPgExecuteSqlConfig(t, toolsFile)
tmplSelectCombined, tmplSelectFilterCombined := tests.GetPostgresSQLTmplToolStatement()
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, AlloyDBPostgresToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, ALLOYDB_POSTGRES_TOOL_KIND, tmplSelectCombined, tmplSelectFilterCombined, "")
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
if err != nil {
@@ -193,7 +193,7 @@ func TestAlloyDBPgIpConnection(t *testing.T) {
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
sourceConfig["ipType"] = tc.ipType
err := tests.RunSourceConnectionTest(t, sourceConfig, AlloyDBPostgresToolKind)
err := tests.RunSourceConnectionTest(t, sourceConfig, ALLOYDB_POSTGRES_TOOL_KIND)
if err != nil {
t.Fatalf("Connection test failure: %s", err)
}
@@ -205,35 +205,35 @@ func TestAlloyDBPgIpConnection(t *testing.T) {
func TestAlloyDBPgIAMConnection(t *testing.T) {
getAlloyDBPgVars(t)
// service account email used for IAM should trim the suffix
serviceAccountEmail := strings.TrimSuffix(tests.ServiceAccountEmail, ".gserviceaccount.com")
serviceAccountEmail := strings.TrimSuffix(tests.SERVICE_ACCOUNT_EMAIL, ".gserviceaccount.com")
noPassSourceConfig := map[string]any{
"kind": AlloyDBPostgresSourceKind,
"project": AlloyDBPostgresProject,
"cluster": AlloyDBPostgresCluster,
"instance": AlloyDBPostgresInstance,
"region": AlloyDBPostgresRegion,
"database": AlloyDBPostgresDatabase,
"kind": ALLOYDB_POSTGRES_SOURCE_KIND,
"project": ALLOYDB_POSTGRES_PROJECT,
"cluster": ALLOYDB_POSTGRES_CLUSTER,
"instance": ALLOYDB_POSTGRES_INSTANCE,
"region": ALLOYDB_POSTGRES_REGION,
"database": ALLOYDB_POSTGRES_DATABASE,
"user": serviceAccountEmail,
}
noUserSourceConfig := map[string]any{
"kind": AlloyDBPostgresSourceKind,
"project": AlloyDBPostgresProject,
"cluster": AlloyDBPostgresCluster,
"instance": AlloyDBPostgresInstance,
"region": AlloyDBPostgresRegion,
"database": AlloyDBPostgresDatabase,
"kind": ALLOYDB_POSTGRES_SOURCE_KIND,
"project": ALLOYDB_POSTGRES_PROJECT,
"cluster": ALLOYDB_POSTGRES_CLUSTER,
"instance": ALLOYDB_POSTGRES_INSTANCE,
"region": ALLOYDB_POSTGRES_REGION,
"database": ALLOYDB_POSTGRES_DATABASE,
"password": "random",
}
noUserNoPassSourceConfig := map[string]any{
"kind": AlloyDBPostgresSourceKind,
"project": AlloyDBPostgresProject,
"cluster": AlloyDBPostgresCluster,
"instance": AlloyDBPostgresInstance,
"region": AlloyDBPostgresRegion,
"database": AlloyDBPostgresDatabase,
"kind": ALLOYDB_POSTGRES_SOURCE_KIND,
"project": ALLOYDB_POSTGRES_PROJECT,
"cluster": ALLOYDB_POSTGRES_CLUSTER,
"instance": ALLOYDB_POSTGRES_INSTANCE,
"region": ALLOYDB_POSTGRES_REGION,
"database": ALLOYDB_POSTGRES_DATABASE,
}
tcs := []struct {
name string
@@ -258,7 +258,7 @@ func TestAlloyDBPgIAMConnection(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
err := tests.RunSourceConnectionTest(t, tc.sourceConfig, AlloyDBPostgresToolKind)
err := tests.RunSourceConnectionTest(t, tc.sourceConfig, ALLOYDB_POSTGRES_TOOL_KIND)
if err != nil {
if tc.isErr {
return

View File

@@ -23,7 +23,7 @@ import (
"google.golang.org/api/idtoken"
)
var ServiceAccountEmail = os.Getenv("SERVICE_ACCOUNT_EMAIL")
var SERVICE_ACCOUNT_EMAIL = os.Getenv("SERVICE_ACCOUNT_EMAIL")
var ClientId = os.Getenv("CLIENT_ID")
// GetGoogleIdToken retrieve and return the Google ID token

View File

@@ -37,20 +37,20 @@ import (
)
var (
BigquerySourceKind = "bigquery"
BigqueryToolKind = "bigquery-sql"
BigqueryProject = os.Getenv("BIGQUERY_PROJECT")
BIGQUERY_SOURCE_KIND = "bigquery"
BIGQUERY_TOOL_KIND = "bigquery-sql"
BIGQUERY_PROJECT = os.Getenv("BIGQUERY_PROJECT")
)
func getBigQueryVars(t *testing.T) map[string]any {
switch "" {
case BigqueryProject:
case BIGQUERY_PROJECT:
t.Fatal("'BIGQUERY_PROJECT' not set")
}
return map[string]any{
"kind": BigquerySourceKind,
"project": BigqueryProject,
"kind": BIGQUERY_SOURCE_KIND,
"project": BIGQUERY_PROJECT,
}
}
@@ -76,7 +76,7 @@ func TestBigQueryToolEndpoints(t *testing.T) {
var args []string
client, err := initBigQueryConnection(BigqueryProject)
client, err := initBigQueryConnection(BIGQUERY_PROJECT)
if err != nil {
t.Fatalf("unable to create Cloud SQL connection pool: %s", err)
}
@@ -85,36 +85,36 @@ func TestBigQueryToolEndpoints(t *testing.T) {
datasetName := fmt.Sprintf("temp_toolbox_test_%s", strings.ReplaceAll(uuid.New().String(), "-", ""))
tableName := fmt.Sprintf("param_table_%s", strings.ReplaceAll(uuid.New().String(), "-", ""))
tableNameParam := fmt.Sprintf("`%s.%s.%s`",
BigqueryProject,
BIGQUERY_PROJECT,
datasetName,
tableName,
)
tableNameAuth := fmt.Sprintf("`%s.%s.auth_table_%s`",
BigqueryProject,
BIGQUERY_PROJECT,
datasetName,
strings.ReplaceAll(uuid.New().String(), "-", ""),
)
tableNameTemplateParam := fmt.Sprintf("`%s.%s.template_param_table_%s`",
BigqueryProject,
BIGQUERY_PROJECT,
datasetName,
strings.ReplaceAll(uuid.New().String(), "-", ""),
)
// set up data for param tool
createStatement1, insertStatement1, toolStatement1, params1 := getBigQueryParamToolInfo(tableNameParam)
teardownTable1 := setupBigQueryTable(t, ctx, client, createStatement1, insertStatement1, datasetName, tableNameParam, params1)
create_statement1, insert_statement1, tool_statement1, params1 := getBigQueryParamToolInfo(tableNameParam)
teardownTable1 := setupBigQueryTable(t, ctx, client, create_statement1, insert_statement1, datasetName, tableNameParam, params1)
defer teardownTable1(t)
// set up data for auth tool
createStatement2, insertStatement2, toolStatement2, params2 := getBigQueryAuthToolInfo(tableNameAuth)
teardownTable2 := setupBigQueryTable(t, ctx, client, createStatement2, insertStatement2, datasetName, tableNameAuth, params2)
create_statement2, insert_statement2, tool_statement2, params2 := getBigQueryAuthToolInfo(tableNameAuth)
teardownTable2 := setupBigQueryTable(t, ctx, client, create_statement2, insert_statement2, datasetName, tableNameAuth, params2)
defer teardownTable2(t)
// Write config into a file and pass it to command
toolsFile := tests.GetToolsConfig(sourceConfig, BigqueryToolKind, toolStatement1, toolStatement2)
toolsFile := tests.GetToolsConfig(sourceConfig, BIGQUERY_TOOL_KIND, tool_statement1, tool_statement2)
toolsFile = addBigQueryPrebuiltToolsConfig(t, toolsFile)
tmplSelectCombined, tmplSelectFilterCombined := getBigQueryTmplToolStatement()
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, BigqueryToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, BIGQUERY_TOOL_KIND, tmplSelectCombined, tmplSelectFilterCombined, "")
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
if err != nil {
@@ -176,7 +176,7 @@ func getBigQueryAuthToolInfo(tableName string) (string, string, string, []bigque
toolStatement := fmt.Sprintf(`
SELECT name FROM %s WHERE email = ?`, tableName)
params := []bigqueryapi.QueryParameter{
{Value: int64(1)}, {Value: "Alice"}, {Value: tests.ServiceAccountEmail},
{Value: int64(1)}, {Value: "Alice"}, {Value: tests.SERVICE_ACCOUNT_EMAIL},
{Value: int64(2)}, {Value: "Jane"}, {Value: "janedoe@gmail.com"},
}
return createStatement, insertStatement, toolStatement, params
@@ -189,7 +189,7 @@ func getBigQueryTmplToolStatement() (string, string) {
return tmplSelectCombined, tmplSelectFilterCombined
}
func setupBigQueryTable(t *testing.T, ctx context.Context, client *bigqueryapi.Client, createStatement, insertStatement, datasetName string, tableName string, params []bigqueryapi.QueryParameter) func(*testing.T) {
func setupBigQueryTable(t *testing.T, ctx context.Context, client *bigqueryapi.Client, create_statement, insert_statement, datasetName string, tableName string, params []bigqueryapi.QueryParameter) func(*testing.T) {
// Create dataset
dataset := client.Dataset(datasetName)
_, err := dataset.Metadata(ctx)
@@ -206,7 +206,7 @@ func setupBigQueryTable(t *testing.T, ctx context.Context, client *bigqueryapi.C
}
// Create table
createJob, err := client.Query(createStatement).Run(ctx)
createJob, err := client.Query(create_statement).Run(ctx)
if err != nil {
t.Fatalf("Failed to start create table job for %s: %v", tableName, err)
@@ -220,7 +220,7 @@ func setupBigQueryTable(t *testing.T, ctx context.Context, client *bigqueryapi.C
}
// Insert test data
insertQuery := client.Query(insertStatement)
insertQuery := client.Query(insert_statement)
insertQuery.Parameters = params
insertJob, err := insertQuery.Run(ctx)
if err != nil {
@@ -340,7 +340,7 @@ func addBigQueryPrebuiltToolsConfig(t *testing.T, config map[string]any) map[str
return config
}
func runBigQueryExecuteSqlToolInvokeTest(t *testing.T, select1Want, invokeParamWant, tableNameParam string) {
func runBigQueryExecuteSqlToolInvokeTest(t *testing.T, select_1_want, invokeParamWant, tableNameParam string) {
// Get ID token
idToken, err := tests.GetGoogleIdToken(tests.ClientId)
if err != nil {
@@ -368,7 +368,7 @@ func runBigQueryExecuteSqlToolInvokeTest(t *testing.T, select1Want, invokeParamW
api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke",
requestHeader: map[string]string{},
requestBody: bytes.NewBuffer([]byte(`{"sql":"SELECT 1"}`)),
want: select1Want,
want: select_1_want,
isErr: false,
},
{
@@ -414,7 +414,7 @@ func runBigQueryExecuteSqlToolInvokeTest(t *testing.T, select1Want, invokeParamW
requestHeader: map[string]string{"my-google-auth_token": idToken},
requestBody: bytes.NewBuffer([]byte(`{"sql":"SELECT 1"}`)),
isErr: false,
want: select1Want,
want: select_1_want,
},
{
name: "Invoke my-auth-exec-sql-tool with invalid auth token",

View File

@@ -34,24 +34,24 @@ import (
)
var (
BigtableSourceKind = "bigtable"
BigtableToolKind = "bigtable-sql"
BigtableProject = os.Getenv("BIGTABLE_PROJECT")
BigtableInstance = os.Getenv("BIGTABLE_INSTANCE")
BIGTABLE_SOURCE_KIND = "bigtable"
BIGTABLE_TOOL_KIND = "bigtable-sql"
BIGTABLE_PROJECT = os.Getenv("BIGTABLE_PROJECT")
BIGTABLE_INSTANCE = os.Getenv("BIGTABLE_INSTANCE")
)
func getBigtableVars(t *testing.T) map[string]any {
switch "" {
case BigtableProject:
case BIGTABLE_PROJECT:
t.Fatal("'BIGTABLE_PROJECT' not set")
case BigtableInstance:
case BIGTABLE_INSTANCE:
t.Fatal("'BIGTABLE_INSTANCE' not set")
}
return map[string]any{
"kind": BigtableSourceKind,
"project": BigtableProject,
"instance": BigtableInstance,
"kind": BIGTABLE_SOURCE_KIND,
"project": BIGTABLE_PROJECT,
"instance": BIGTABLE_INSTANCE,
}
}
@@ -77,13 +77,13 @@ func TestBigtableToolEndpoints(t *testing.T) {
// Do not change the shape of statement without checking tests/common_test.go.
// The structure and value of seed data has to match https://github.com/googleapis/genai-toolbox/blob/4dba0df12dc438eca3cb476ef52aa17cdf232c12/tests/common_test.go#L200-L251
paramTestStatement := fmt.Sprintf("SELECT TO_INT64(cf['id']) as id, CAST(cf['name'] AS string) as name, FROM %s WHERE TO_INT64(cf['id']) = @id OR CAST(cf['name'] AS string) = @name;", tableName)
param_test_statement := fmt.Sprintf("SELECT TO_INT64(cf['id']) as id, CAST(cf['name'] AS string) as name, FROM %s WHERE TO_INT64(cf['id']) = @id OR CAST(cf['name'] AS string) = @name;", tableName)
teardownTable1 := setupBtTable(t, ctx, sourceConfig["project"].(string), sourceConfig["instance"].(string), tableName, columnFamilyName, muts, rowKeys)
defer teardownTable1(t)
// Do not change the shape of statement without checking tests/common_test.go.
// The structure and value of seed data has to match https://github.com/googleapis/genai-toolbox/blob/4dba0df12dc438eca3cb476ef52aa17cdf232c12/tests/common_test.go#L200-L251
authToolStatement := fmt.Sprintf("SELECT CAST(cf['name'] AS string) as name FROM %s WHERE CAST(cf['email'] AS string) = @email;", tableNameAuth)
auth_tool_statement := fmt.Sprintf("SELECT CAST(cf['name'] AS string) as name FROM %s WHERE CAST(cf['email'] AS string) = @email;", tableNameAuth)
teardownTable2 := setupBtTable(t, ctx, sourceConfig["project"].(string), sourceConfig["instance"].(string), tableNameAuth, columnFamilyName, muts, rowKeys)
defer teardownTable2(t)
@@ -92,7 +92,7 @@ func TestBigtableToolEndpoints(t *testing.T) {
defer teardownTableTmpl(t)
// Write config into a file and pass it to command
toolsFile := tests.GetToolsConfig(sourceConfig, BigtableToolKind, paramTestStatement, authToolStatement)
toolsFile := tests.GetToolsConfig(sourceConfig, BIGTABLE_TOOL_KIND, param_test_statement, auth_tool_statement)
toolsFile = addTemplateParamConfig(t, toolsFile)
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
@@ -151,7 +151,7 @@ func getTestData(columnFamilyName string) ([]*bigtable.Mutation, []string) {
// Expected values are defined in https://github.com/googleapis/genai-toolbox/blob/52b09a67cb40ac0c5f461598b4673136699a3089/tests/tool_test.go#L229-L310
"row-01": {
"name": []byte("Alice"),
"email": []byte(tests.ServiceAccountEmail),
"email": []byte(tests.SERVICE_ACCOUNT_EMAIL),
"id": ids[0],
},
"row-02": {

View File

@@ -33,49 +33,49 @@ import (
)
var (
CloudSQLMSSQLSourceKind = "cloud-sql-mssql"
CloudSQLMSSQLToolKind = "mssql-sql"
CloudSQLMSSQLProject = os.Getenv("CLOUD_SQL_MSSQL_PROJECT")
CloudSQLMSSQLRegion = os.Getenv("CLOUD_SQL_MSSQL_REGION")
CloudSQLMSSQLInstance = os.Getenv("CLOUD_SQL_MSSQL_INSTANCE")
CloudSQLMSSQLDatabase = os.Getenv("CLOUD_SQL_MSSQL_DATABASE")
CloudSQLMSSQLIp = os.Getenv("CLOUD_SQL_MSSQL_IP")
CloudSQLMSSQLUser = os.Getenv("CLOUD_SQL_MSSQL_USER")
CloudSQLMSSQLPass = os.Getenv("CLOUD_SQL_MSSQL_PASS")
CLOUD_SQL_MSSQL_SOURCE_KIND = "cloud-sql-mssql"
CLOUD_SQL_MSSQL_TOOL_KIND = "mssql-sql"
CLOUD_SQL_MSSQL_PROJECT = os.Getenv("CLOUD_SQL_MSSQL_PROJECT")
CLOUD_SQL_MSSQL_REGION = os.Getenv("CLOUD_SQL_MSSQL_REGION")
CLOUD_SQL_MSSQL_INSTANCE = os.Getenv("CLOUD_SQL_MSSQL_INSTANCE")
CLOUD_SQL_MSSQL_DATABASE = os.Getenv("CLOUD_SQL_MSSQL_DATABASE")
CLOUD_SQL_MSSQL_IP = os.Getenv("CLOUD_SQL_MSSQL_IP")
CLOUD_SQL_MSSQL_USER = os.Getenv("CLOUD_SQL_MSSQL_USER")
CLOUD_SQL_MSSQL_PASS = os.Getenv("CLOUD_SQL_MSSQL_PASS")
)
func getCloudSQLMSSQLVars(t *testing.T) map[string]any {
func getCloudSQLMssqlVars(t *testing.T) map[string]any {
switch "" {
case CloudSQLMSSQLProject:
case CLOUD_SQL_MSSQL_PROJECT:
t.Fatal("'CLOUD_SQL_MSSQL_PROJECT' not set")
case CloudSQLMSSQLRegion:
case CLOUD_SQL_MSSQL_REGION:
t.Fatal("'CLOUD_SQL_MSSQL_REGION' not set")
case CloudSQLMSSQLInstance:
case CLOUD_SQL_MSSQL_INSTANCE:
t.Fatal("'CLOUD_SQL_MSSQL_INSTANCE' not set")
case CloudSQLMSSQLIp:
case CLOUD_SQL_MSSQL_IP:
t.Fatal("'CLOUD_SQL_MSSQL_IP' not set")
case CloudSQLMSSQLDatabase:
case CLOUD_SQL_MSSQL_DATABASE:
t.Fatal("'CLOUD_SQL_MSSQL_DATABASE' not set")
case CloudSQLMSSQLUser:
case CLOUD_SQL_MSSQL_USER:
t.Fatal("'CLOUD_SQL_MSSQL_USER' not set")
case CloudSQLMSSQLPass:
case CLOUD_SQL_MSSQL_PASS:
t.Fatal("'CLOUD_SQL_MSSQL_PASS' not set")
}
return map[string]any{
"kind": CloudSQLMSSQLSourceKind,
"project": CloudSQLMSSQLProject,
"instance": CloudSQLMSSQLInstance,
"ipAddress": CloudSQLMSSQLIp,
"region": CloudSQLMSSQLRegion,
"database": CloudSQLMSSQLDatabase,
"user": CloudSQLMSSQLUser,
"password": CloudSQLMSSQLPass,
"kind": CLOUD_SQL_MSSQL_SOURCE_KIND,
"project": CLOUD_SQL_MSSQL_PROJECT,
"instance": CLOUD_SQL_MSSQL_INSTANCE,
"ipAddress": CLOUD_SQL_MSSQL_IP,
"region": CLOUD_SQL_MSSQL_REGION,
"database": CLOUD_SQL_MSSQL_DATABASE,
"user": CLOUD_SQL_MSSQL_USER,
"password": CLOUD_SQL_MSSQL_PASS,
}
}
// Copied over from cloud_sql_mssql.go
func initCloudSQLMSSQLConnection(project, region, instance, ipAddress, ipType, user, pass, dbname string) (*sql.DB, error) {
func initCloudSQLMssqlConnection(project, region, instance, ipAddress, ipType, user, pass, dbname string) (*sql.DB, error) {
// Create dsn
query := fmt.Sprintf("database=%s&cloudsql=%s:%s:%s", dbname, project, region, instance)
url := &url.URL{
@@ -110,14 +110,14 @@ func initCloudSQLMSSQLConnection(project, region, instance, ipAddress, ipType, u
return db, nil
}
func TestCloudSQLMSSQLToolEndpoints(t *testing.T) {
sourceConfig := getCloudSQLMSSQLVars(t)
func TestCloudSQLMssqlToolEndpoints(t *testing.T) {
sourceConfig := getCloudSQLMssqlVars(t)
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
var args []string
db, err := initCloudSQLMSSQLConnection(CloudSQLMSSQLProject, CloudSQLMSSQLRegion, CloudSQLMSSQLInstance, CloudSQLMSSQLIp, "public", CloudSQLMSSQLUser, CloudSQLMSSQLPass, CloudSQLMSSQLDatabase)
db, err := initCloudSQLMssqlConnection(CLOUD_SQL_MSSQL_PROJECT, CLOUD_SQL_MSSQL_REGION, CLOUD_SQL_MSSQL_INSTANCE, CLOUD_SQL_MSSQL_IP, "public", CLOUD_SQL_MSSQL_USER, CLOUD_SQL_MSSQL_PASS, CLOUD_SQL_MSSQL_DATABASE)
if err != nil {
t.Fatalf("unable to create Cloud SQL connection pool: %s", err)
}
@@ -128,20 +128,20 @@ func TestCloudSQLMSSQLToolEndpoints(t *testing.T) {
tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
// set up data for param tool
createStatement1, insertStatement1, toolStatement1, params1 := tests.GetMSSQLParamToolInfo(tableNameParam)
teardownTable1 := tests.SetupMsSQLTable(t, ctx, db, createStatement1, insertStatement1, tableNameParam, params1)
create_statement1, insert_statement1, tool_statement1, params1 := tests.GetMssqlParamToolInfo(tableNameParam)
teardownTable1 := tests.SetupMsSQLTable(t, ctx, db, create_statement1, insert_statement1, tableNameParam, params1)
defer teardownTable1(t)
// set up data for auth tool
createStatement2, insertStatement2, toolStatement2, params2 := tests.GetMSSQLAuthToolInfo(tableNameAuth)
teardownTable2 := tests.SetupMsSQLTable(t, ctx, db, createStatement2, insertStatement2, tableNameAuth, params2)
create_statement2, insert_statement2, tool_statement2, params2 := tests.GetMssqlAuthToolInfo(tableNameAuth)
teardownTable2 := tests.SetupMsSQLTable(t, ctx, db, create_statement2, insert_statement2, tableNameAuth, params2)
defer teardownTable2(t)
// Write config into a file and pass it to command
toolsFile := tests.GetToolsConfig(sourceConfig, CloudSQLMSSQLToolKind, toolStatement1, toolStatement2)
toolsFile = tests.AddMSSQLExecuteSqlConfig(t, toolsFile)
tmplSelectCombined, tmplSelectFilterCombined := tests.GetMSSQLTmplToolStatement()
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, CloudSQLMSSQLToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
toolsFile := tests.GetToolsConfig(sourceConfig, CLOUD_SQL_MSSQL_TOOL_KIND, tool_statement1, tool_statement2)
toolsFile = tests.AddMssqlExecuteSqlConfig(t, toolsFile)
tmplSelectCombined, tmplSelectFilterCombined := tests.GetMssqlTmplToolStatement()
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, CLOUD_SQL_MSSQL_TOOL_KIND, tmplSelectCombined, tmplSelectFilterCombined, "")
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
if err != nil {
@@ -159,7 +159,7 @@ func TestCloudSQLMSSQLToolEndpoints(t *testing.T) {
tests.RunToolGetTest(t)
select1Want, failInvocationWant, createTableStatement := tests.GetMSSQLWants()
select1Want, failInvocationWant, createTableStatement := tests.GetMssqlWants()
invokeParamWant, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
tests.RunToolInvokeTest(t, select1Want, invokeParamWant)
tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want)
@@ -168,8 +168,8 @@ func TestCloudSQLMSSQLToolEndpoints(t *testing.T) {
}
// Test connection with different IP type
func TestCloudSQLMSSQLIpConnection(t *testing.T) {
sourceConfig := getCloudSQLMSSQLVars(t)
func TestCloudSQLMssqlIpConnection(t *testing.T) {
sourceConfig := getCloudSQLMssqlVars(t)
tcs := []struct {
name string
@@ -187,7 +187,7 @@ func TestCloudSQLMSSQLIpConnection(t *testing.T) {
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
sourceConfig["ipType"] = tc.ipType
err := tests.RunSourceConnectionTest(t, sourceConfig, CloudSQLMSSQLToolKind)
err := tests.RunSourceConnectionTest(t, sourceConfig, CLOUD_SQL_MSSQL_TOOL_KIND)
if err != nil {
t.Fatalf("Connection test failure: %s", err)
}

View File

@@ -32,40 +32,40 @@ import (
)
var (
CloudSQLMySQLSourceKind = "cloud-sql-mysql"
CloudSQLMySQLToolKind = "mysql-sql"
CloudSQLMySQLProject = os.Getenv("CLOUD_SQL_MYSQL_PROJECT")
CloudSQLMySQLRegion = os.Getenv("CLOUD_SQL_MYSQL_REGION")
CloudSQLMySQLInstance = os.Getenv("CLOUD_SQL_MYSQL_INSTANCE")
CloudSQLMySQLDatabase = os.Getenv("CLOUD_SQL_MYSQL_DATABASE")
CloudSQLMySQLUser = os.Getenv("CLOUD_SQL_MYSQL_USER")
CloudSQLMySQLPass = os.Getenv("CLOUD_SQL_MYSQL_PASS")
CLOUD_SQL_MYSQL_SOURCE_KIND = "cloud-sql-mysql"
CLOUD_SQL_MYSQL_TOOL_KIND = "mysql-sql"
CLOUD_SQL_MYSQL_PROJECT = os.Getenv("CLOUD_SQL_MYSQL_PROJECT")
CLOUD_SQL_MYSQL_REGION = os.Getenv("CLOUD_SQL_MYSQL_REGION")
CLOUD_SQL_MYSQL_INSTANCE = os.Getenv("CLOUD_SQL_MYSQL_INSTANCE")
CLOUD_SQL_MYSQL_DATABASE = os.Getenv("CLOUD_SQL_MYSQL_DATABASE")
CLOUD_SQL_MYSQL_USER = os.Getenv("CLOUD_SQL_MYSQL_USER")
CLOUD_SQL_MYSQL_PASS = os.Getenv("CLOUD_SQL_MYSQL_PASS")
)
func getCloudSQLMySQLVars(t *testing.T) map[string]any {
switch "" {
case CloudSQLMySQLProject:
case CLOUD_SQL_MYSQL_PROJECT:
t.Fatal("'CLOUD_SQL_MYSQL_PROJECT' not set")
case CloudSQLMySQLRegion:
case CLOUD_SQL_MYSQL_REGION:
t.Fatal("'CLOUD_SQL_MYSQL_REGION' not set")
case CloudSQLMySQLInstance:
case CLOUD_SQL_MYSQL_INSTANCE:
t.Fatal("'CLOUD_SQL_MYSQL_INSTANCE' not set")
case CloudSQLMySQLDatabase:
case CLOUD_SQL_MYSQL_DATABASE:
t.Fatal("'CLOUD_SQL_MYSQL_DATABASE' not set")
case CloudSQLMySQLUser:
case CLOUD_SQL_MYSQL_USER:
t.Fatal("'CLOUD_SQL_MYSQL_USER' not set")
case CloudSQLMySQLPass:
case CLOUD_SQL_MYSQL_PASS:
t.Fatal("'CLOUD_SQL_MYSQL_PASS' not set")
}
return map[string]any{
"kind": CloudSQLMySQLSourceKind,
"project": CloudSQLMySQLProject,
"instance": CloudSQLMySQLInstance,
"region": CloudSQLMySQLRegion,
"database": CloudSQLMySQLDatabase,
"user": CloudSQLMySQLUser,
"password": CloudSQLMySQLPass,
"kind": CLOUD_SQL_MYSQL_SOURCE_KIND,
"project": CLOUD_SQL_MYSQL_PROJECT,
"instance": CLOUD_SQL_MYSQL_INSTANCE,
"region": CLOUD_SQL_MYSQL_REGION,
"database": CLOUD_SQL_MYSQL_DATABASE,
"user": CLOUD_SQL_MYSQL_USER,
"password": CLOUD_SQL_MYSQL_PASS,
}
}
@@ -97,14 +97,14 @@ func initCloudSQLMySQLConnectionPool(project, region, instance, ipType, user, pa
return db, nil
}
func TestCloudSQLMySQLToolEndpoints(t *testing.T) {
func TestCloudSQLMysqlToolEndpoints(t *testing.T) {
sourceConfig := getCloudSQLMySQLVars(t)
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
var args []string
pool, err := initCloudSQLMySQLConnectionPool(CloudSQLMySQLProject, CloudSQLMySQLRegion, CloudSQLMySQLInstance, "public", CloudSQLMySQLUser, CloudSQLMySQLPass, CloudSQLMySQLDatabase)
pool, err := initCloudSQLMySQLConnectionPool(CLOUD_SQL_MYSQL_PROJECT, CLOUD_SQL_MYSQL_REGION, CLOUD_SQL_MYSQL_INSTANCE, "public", CLOUD_SQL_MYSQL_USER, CLOUD_SQL_MYSQL_PASS, CLOUD_SQL_MYSQL_DATABASE)
if err != nil {
t.Fatalf("unable to create Cloud SQL connection pool: %s", err)
}
@@ -115,20 +115,20 @@ func TestCloudSQLMySQLToolEndpoints(t *testing.T) {
tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
// set up data for param tool
createStatement1, insertStatement1, toolStatement1, params1 := tests.GetMySQLParamToolInfo(tableNameParam)
teardownTable1 := tests.SetupMySQLTable(t, ctx, pool, createStatement1, insertStatement1, tableNameParam, params1)
create_statement1, insert_statement1, tool_statement1, params1 := tests.GetMysqlParamToolInfo(tableNameParam)
teardownTable1 := tests.SetupMySQLTable(t, ctx, pool, create_statement1, insert_statement1, tableNameParam, params1)
defer teardownTable1(t)
// set up data for auth tool
createStatement2, insertStatement2, toolStatement2, params2 := tests.GetMySQLAuthToolInfo(tableNameAuth)
teardownTable2 := tests.SetupMySQLTable(t, ctx, pool, createStatement2, insertStatement2, tableNameAuth, params2)
create_statement2, insert_statement2, tool_statement2, params2 := tests.GetMysqlAuthToolInfo(tableNameAuth)
teardownTable2 := tests.SetupMySQLTable(t, ctx, pool, create_statement2, insert_statement2, tableNameAuth, params2)
defer teardownTable2(t)
// Write config into a file and pass it to command
toolsFile := tests.GetToolsConfig(sourceConfig, CloudSQLMySQLToolKind, toolStatement1, toolStatement2)
toolsFile := tests.GetToolsConfig(sourceConfig, CLOUD_SQL_MYSQL_TOOL_KIND, tool_statement1, tool_statement2)
toolsFile = tests.AddMySqlExecuteSqlConfig(t, toolsFile)
tmplSelectCombined, tmplSelectFilterCombined := tests.GetMySQLTmplToolStatement()
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, CloudSQLMySQLToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
tmplSelectCombined, tmplSelectFilterCombined := tests.GetMysqlTmplToolStatement()
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, CLOUD_SQL_MYSQL_TOOL_KIND, tmplSelectCombined, tmplSelectFilterCombined, "")
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
if err != nil {
@@ -146,7 +146,7 @@ func TestCloudSQLMySQLToolEndpoints(t *testing.T) {
tests.RunToolGetTest(t)
select1Want, failInvocationWant, createTableStatement := tests.GetMySQLWants()
select1Want, failInvocationWant, createTableStatement := tests.GetMysqlWants()
invokeParamWant, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
tests.RunToolInvokeTest(t, select1Want, invokeParamWant)
tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want)
@@ -155,7 +155,7 @@ func TestCloudSQLMySQLToolEndpoints(t *testing.T) {
}
// Test connection with different IP type
func TestCloudSQLMySQLIpConnection(t *testing.T) {
func TestCloudSQLMysqlIpConnection(t *testing.T) {
sourceConfig := getCloudSQLMySQLVars(t)
tcs := []struct {
@@ -174,7 +174,7 @@ func TestCloudSQLMySQLIpConnection(t *testing.T) {
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
sourceConfig["ipType"] = tc.ipType
err := tests.RunSourceConnectionTest(t, sourceConfig, CloudSQLMySQLToolKind)
err := tests.RunSourceConnectionTest(t, sourceConfig, CLOUD_SQL_MYSQL_TOOL_KIND)
if err != nil {
t.Fatalf("Connection test failure: %s", err)
}

View File

@@ -31,40 +31,40 @@ import (
)
var (
CloudSQLPostgresSourceKind = "cloud-sql-postgres"
CloudSQLPostgresToolKind = "postgres-sql"
CloudSQLPostgresProject = os.Getenv("CLOUD_SQL_POSTGRES_PROJECT")
CloudSQLPostgresRegion = os.Getenv("CLOUD_SQL_POSTGRES_REGION")
CloudSQLPostgresInstance = os.Getenv("CLOUD_SQL_POSTGRES_INSTANCE")
CloudSQLPostgresDatabase = os.Getenv("CLOUD_SQL_POSTGRES_DATABASE")
CloudSQLPostgresUser = os.Getenv("CLOUD_SQL_POSTGRES_USER")
CloudSQLPostgresPass = os.Getenv("CLOUD_SQL_POSTGRES_PASS")
CLOUD_SQL_POSTGRES_SOURCE_KIND = "cloud-sql-postgres"
CLOUD_SQL_POSTGRES_TOOL_KIND = "postgres-sql"
CLOUD_SQL_POSTGRES_PROJECT = os.Getenv("CLOUD_SQL_POSTGRES_PROJECT")
CLOUD_SQL_POSTGRES_REGION = os.Getenv("CLOUD_SQL_POSTGRES_REGION")
CLOUD_SQL_POSTGRES_INSTANCE = os.Getenv("CLOUD_SQL_POSTGRES_INSTANCE")
CLOUD_SQL_POSTGRES_DATABASE = os.Getenv("CLOUD_SQL_POSTGRES_DATABASE")
CLOUD_SQL_POSTGRES_USER = os.Getenv("CLOUD_SQL_POSTGRES_USER")
CLOUD_SQL_POSTGRES_PASS = os.Getenv("CLOUD_SQL_POSTGRES_PASS")
)
func getCloudSQLPgVars(t *testing.T) map[string]any {
switch "" {
case CloudSQLPostgresProject:
case CLOUD_SQL_POSTGRES_PROJECT:
t.Fatal("'CLOUD_SQL_POSTGRES_PROJECT' not set")
case CloudSQLPostgresRegion:
case CLOUD_SQL_POSTGRES_REGION:
t.Fatal("'CLOUD_SQL_POSTGRES_REGION' not set")
case CloudSQLPostgresInstance:
case CLOUD_SQL_POSTGRES_INSTANCE:
t.Fatal("'CLOUD_SQL_POSTGRES_INSTANCE' not set")
case CloudSQLPostgresDatabase:
case CLOUD_SQL_POSTGRES_DATABASE:
t.Fatal("'CLOUD_SQL_POSTGRES_DATABASE' not set")
case CloudSQLPostgresUser:
case CLOUD_SQL_POSTGRES_USER:
t.Fatal("'CLOUD_SQL_POSTGRES_USER' not set")
case CloudSQLPostgresPass:
case CLOUD_SQL_POSTGRES_PASS:
t.Fatal("'CLOUD_SQL_POSTGRES_PASS' not set")
}
return map[string]any{
"kind": CloudSQLPostgresSourceKind,
"project": CloudSQLPostgresProject,
"instance": CloudSQLPostgresInstance,
"region": CloudSQLPostgresRegion,
"database": CloudSQLPostgresDatabase,
"user": CloudSQLPostgresUser,
"password": CloudSQLPostgresPass,
"kind": CLOUD_SQL_POSTGRES_SOURCE_KIND,
"project": CLOUD_SQL_POSTGRES_PROJECT,
"instance": CLOUD_SQL_POSTGRES_INSTANCE,
"region": CLOUD_SQL_POSTGRES_REGION,
"database": CLOUD_SQL_POSTGRES_DATABASE,
"user": CLOUD_SQL_POSTGRES_USER,
"password": CLOUD_SQL_POSTGRES_PASS,
}
}
@@ -108,7 +108,7 @@ func TestCloudSQLPgSimpleToolEndpoints(t *testing.T) {
var args []string
pool, err := initCloudSQLPgConnectionPool(CloudSQLPostgresProject, CloudSQLPostgresRegion, CloudSQLPostgresInstance, "public", CloudSQLPostgresUser, CloudSQLPostgresPass, CloudSQLPostgresDatabase)
pool, err := initCloudSQLPgConnectionPool(CLOUD_SQL_POSTGRES_PROJECT, CLOUD_SQL_POSTGRES_REGION, CLOUD_SQL_POSTGRES_INSTANCE, "public", CLOUD_SQL_POSTGRES_USER, CLOUD_SQL_POSTGRES_PASS, CLOUD_SQL_POSTGRES_DATABASE)
if err != nil {
t.Fatalf("unable to create Cloud SQL connection pool: %s", err)
}
@@ -119,20 +119,20 @@ func TestCloudSQLPgSimpleToolEndpoints(t *testing.T) {
tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
// set up data for param tool
createStatement1, insertStatement1, toolStatement1, params1 := tests.GetPostgresSQLParamToolInfo(tableNameParam)
teardownTable1 := tests.SetupPostgresSQLTable(t, ctx, pool, createStatement1, insertStatement1, tableNameParam, params1)
create_statement1, insert_statement1, tool_statement1, params1 := tests.GetPostgresSQLParamToolInfo(tableNameParam)
teardownTable1 := tests.SetupPostgresSQLTable(t, ctx, pool, create_statement1, insert_statement1, tableNameParam, params1)
defer teardownTable1(t)
// set up data for auth tool
createStatement2, insertStatement2, toolStatement2, params2 := tests.GetPostgresSQLAuthToolInfo(tableNameAuth)
teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createStatement2, insertStatement2, tableNameAuth, params2)
create_statement2, insert_statement2, tool_statement2, params2 := tests.GetPostgresSQLAuthToolInfo(tableNameAuth)
teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, create_statement2, insert_statement2, tableNameAuth, params2)
defer teardownTable2(t)
// Write config into a file and pass it to command
toolsFile := tests.GetToolsConfig(sourceConfig, CloudSQLPostgresToolKind, toolStatement1, toolStatement2)
toolsFile := tests.GetToolsConfig(sourceConfig, CLOUD_SQL_POSTGRES_TOOL_KIND, tool_statement1, tool_statement2)
toolsFile = tests.AddPgExecuteSqlConfig(t, toolsFile)
tmplSelectCombined, tmplSelectFilterCombined := tests.GetPostgresSQLTmplToolStatement()
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, CloudSQLPostgresToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, CLOUD_SQL_POSTGRES_TOOL_KIND, tmplSelectCombined, tmplSelectFilterCombined, "")
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
if err != nil {
@@ -178,7 +178,7 @@ func TestCloudSQLPgIpConnection(t *testing.T) {
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
sourceConfig["ipType"] = tc.ipType
err := tests.RunSourceConnectionTest(t, sourceConfig, CloudSQLPostgresToolKind)
err := tests.RunSourceConnectionTest(t, sourceConfig, CLOUD_SQL_POSTGRES_TOOL_KIND)
if err != nil {
t.Fatalf("Connection test failure: %s", err)
}
@@ -189,32 +189,32 @@ func TestCloudSQLPgIpConnection(t *testing.T) {
func TestCloudSQLPgIAMConnection(t *testing.T) {
getCloudSQLPgVars(t)
// service account email used for IAM should trim the suffix
serviceAccountEmail := strings.TrimSuffix(tests.ServiceAccountEmail, ".gserviceaccount.com")
serviceAccountEmail := strings.TrimSuffix(tests.SERVICE_ACCOUNT_EMAIL, ".gserviceaccount.com")
noPassSourceConfig := map[string]any{
"kind": CloudSQLPostgresSourceKind,
"project": CloudSQLPostgresProject,
"instance": CloudSQLPostgresInstance,
"region": CloudSQLPostgresRegion,
"database": CloudSQLPostgresDatabase,
"kind": CLOUD_SQL_POSTGRES_SOURCE_KIND,
"project": CLOUD_SQL_POSTGRES_PROJECT,
"instance": CLOUD_SQL_POSTGRES_INSTANCE,
"region": CLOUD_SQL_POSTGRES_REGION,
"database": CLOUD_SQL_POSTGRES_DATABASE,
"user": serviceAccountEmail,
}
noUserSourceConfig := map[string]any{
"kind": CloudSQLPostgresSourceKind,
"project": CloudSQLPostgresProject,
"instance": CloudSQLPostgresInstance,
"region": CloudSQLPostgresRegion,
"database": CloudSQLPostgresDatabase,
"kind": CLOUD_SQL_POSTGRES_SOURCE_KIND,
"project": CLOUD_SQL_POSTGRES_PROJECT,
"instance": CLOUD_SQL_POSTGRES_INSTANCE,
"region": CLOUD_SQL_POSTGRES_REGION,
"database": CLOUD_SQL_POSTGRES_DATABASE,
"password": "random",
}
noUserNoPassSourceConfig := map[string]any{
"kind": CloudSQLPostgresSourceKind,
"project": CloudSQLPostgresProject,
"instance": CloudSQLPostgresInstance,
"region": CloudSQLPostgresRegion,
"database": CloudSQLPostgresDatabase,
"kind": CLOUD_SQL_POSTGRES_SOURCE_KIND,
"project": CLOUD_SQL_POSTGRES_PROJECT,
"instance": CLOUD_SQL_POSTGRES_INSTANCE,
"region": CLOUD_SQL_POSTGRES_REGION,
"database": CLOUD_SQL_POSTGRES_DATABASE,
}
tcs := []struct {
name string
@@ -239,7 +239,7 @@ func TestCloudSQLPgIAMConnection(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
err := tests.RunSourceConnectionTest(t, tc.sourceConfig, CloudSQLPostgresToolKind)
err := tests.RunSourceConnectionTest(t, tc.sourceConfig, CLOUD_SQL_POSTGRES_TOOL_KIND)
if err != nil {
if tc.isErr {
return

View File

@@ -28,7 +28,7 @@ import (
)
// GetToolsConfig returns a mock tools config file
func GetToolsConfig(sourceConfig map[string]any, toolKind, paramToolStatement, authToolStatement string) map[string]any {
func GetToolsConfig(sourceConfig map[string]any, toolKind, param_tool_statement, auth_tool_statement string) map[string]any {
// Write config into a file and pass it to command
toolsFile := map[string]any{
"sources": map[string]any{
@@ -51,7 +51,7 @@ func GetToolsConfig(sourceConfig map[string]any, toolKind, paramToolStatement, a
"kind": toolKind,
"source": "my-instance",
"description": "Tool to test invocation with params.",
"statement": paramToolStatement,
"statement": param_tool_statement,
"parameters": []any{
map[string]any{
"name": "id",
@@ -70,7 +70,7 @@ func GetToolsConfig(sourceConfig map[string]any, toolKind, paramToolStatement, a
"source": "my-instance",
"description": "Tool to test authenticated parameters.",
// statement to auto-fill authenticated parameter
"statement": authToolStatement,
"statement": auth_tool_statement,
"parameters": []map[string]any{
{
"name": "email",
@@ -237,8 +237,8 @@ func AddMySqlExecuteSqlConfig(t *testing.T, config map[string]any) map[string]an
return config
}
// AddMSSQLExecuteSqlConfig gets the tools config for `mssql-execute-sql`
func AddMSSQLExecuteSqlConfig(t *testing.T, config map[string]any) map[string]any {
// AddMssqlExecuteSqlConfig gets the tools config for `mssql-execute-sql`
func AddMssqlExecuteSqlConfig(t *testing.T, config map[string]any) map[string]any {
tools, ok := config["tools"].(map[string]any)
if !ok {
t.Fatalf("unable to get tools from config")
@@ -262,20 +262,20 @@ func AddMSSQLExecuteSqlConfig(t *testing.T, config map[string]any) map[string]an
// GetPostgresSQLParamToolInfo returns statements and param for my-param-tool postgres-sql kind
func GetPostgresSQLParamToolInfo(tableName string) (string, string, string, []any) {
createStatement := fmt.Sprintf("CREATE TABLE %s (id SERIAL PRIMARY KEY, name TEXT);", tableName)
insertStatement := fmt.Sprintf("INSERT INTO %s (name) VALUES ($1), ($2), ($3);", tableName)
toolStatement := fmt.Sprintf("SELECT * FROM %s WHERE id = $1 OR name = $2;", tableName)
create_statement := fmt.Sprintf("CREATE TABLE %s (id SERIAL PRIMARY KEY, name TEXT);", tableName)
insert_statement := fmt.Sprintf("INSERT INTO %s (name) VALUES ($1), ($2), ($3);", tableName)
tool_statement := fmt.Sprintf("SELECT * FROM %s WHERE id = $1 OR name = $2;", tableName)
params := []any{"Alice", "Jane", "Sid"}
return createStatement, insertStatement, toolStatement, params
return create_statement, insert_statement, tool_statement, params
}
// GetPostgresSQLAuthToolInfo returns statements and param of my-auth-tool for postgres-sql kind
func GetPostgresSQLAuthToolInfo(tableName string) (string, string, string, []any) {
createStatement := fmt.Sprintf("CREATE TABLE %s (id SERIAL PRIMARY KEY, name TEXT, email TEXT);", tableName)
insertStatement := fmt.Sprintf("INSERT INTO %s (name, email) VALUES ($1, $2), ($3, $4)", tableName)
toolStatement := fmt.Sprintf("SELECT name FROM %s WHERE email = $1;", tableName)
params := []any{"Alice", ServiceAccountEmail, "Jane", "janedoe@gmail.com"}
return createStatement, insertStatement, toolStatement, params
create_statement := fmt.Sprintf("CREATE TABLE %s (id SERIAL PRIMARY KEY, name TEXT, email TEXT);", tableName)
insert_statement := fmt.Sprintf("INSERT INTO %s (name, email) VALUES ($1, $2), ($3, $4)", tableName)
tool_statement := fmt.Sprintf("SELECT name FROM %s WHERE email = $1;", tableName)
params := []any{"Alice", SERVICE_ACCOUNT_EMAIL, "Jane", "janedoe@gmail.com"}
return create_statement, insert_statement, tool_statement, params
}
// GetPostgresSQLTmplToolStatement returns statements and param for template parameter test cases for postgres-sql kind
@@ -285,51 +285,51 @@ func GetPostgresSQLTmplToolStatement() (string, string) {
return tmplSelectCombined, tmplSelectFilterCombined
}
// GetMSSQLParamToolInfo returns statements and param for my-param-tool mssql-sql kind
func GetMSSQLParamToolInfo(tableName string) (string, string, string, []any) {
createStatement := fmt.Sprintf("CREATE TABLE %s (id INT IDENTITY(1,1) PRIMARY KEY, name VARCHAR(255));", tableName)
insertStatement := fmt.Sprintf("INSERT INTO %s (name) VALUES (@alice), (@jane), (@sid);", tableName)
toolStatement := fmt.Sprintf("SELECT * FROM %s WHERE id = @id OR name = @p2;", tableName)
// GetMssqlParamToolInfo returns statements and param for my-param-tool mssql-sql kind
func GetMssqlParamToolInfo(tableName string) (string, string, string, []any) {
create_statement := fmt.Sprintf("CREATE TABLE %s (id INT IDENTITY(1,1) PRIMARY KEY, name VARCHAR(255));", tableName)
insert_statement := fmt.Sprintf("INSERT INTO %s (name) VALUES (@alice), (@jane), (@sid);", tableName)
tool_statement := fmt.Sprintf("SELECT * FROM %s WHERE id = @id OR name = @p2;", tableName)
params := []any{sql.Named("alice", "Alice"), sql.Named("jane", "Jane"), sql.Named("sid", "Sid")}
return createStatement, insertStatement, toolStatement, params
return create_statement, insert_statement, tool_statement, params
}
// GetMSSQLAuthToolInfo returns statements and param of my-auth-tool for mssql-sql kind
func GetMSSQLAuthToolInfo(tableName string) (string, string, string, []any) {
createStatement := fmt.Sprintf("CREATE TABLE %s (id INT IDENTITY(1,1) PRIMARY KEY, name VARCHAR(255), email VARCHAR(255));", tableName)
insertStatement := fmt.Sprintf("INSERT INTO %s (name, email) VALUES (@alice, @aliceemail), (@jane, @janeemail);", tableName)
toolStatement := fmt.Sprintf("SELECT name FROM %s WHERE email = @email;", tableName)
params := []any{sql.Named("alice", "Alice"), sql.Named("aliceemail", ServiceAccountEmail), sql.Named("jane", "Jane"), sql.Named("janeemail", "janedoe@gmail.com")}
return createStatement, insertStatement, toolStatement, params
// GetMssqlAuthToolInfo returns statements and param of my-auth-tool for mssql-sql kind
func GetMssqlAuthToolInfo(tableName string) (string, string, string, []any) {
create_statement := fmt.Sprintf("CREATE TABLE %s (id INT IDENTITY(1,1) PRIMARY KEY, name VARCHAR(255), email VARCHAR(255));", tableName)
insert_statement := fmt.Sprintf("INSERT INTO %s (name, email) VALUES (@alice, @aliceemail), (@jane, @janeemail);", tableName)
tool_statement := fmt.Sprintf("SELECT name FROM %s WHERE email = @email;", tableName)
params := []any{sql.Named("alice", "Alice"), sql.Named("aliceemail", SERVICE_ACCOUNT_EMAIL), sql.Named("jane", "Jane"), sql.Named("janeemail", "janedoe@gmail.com")}
return create_statement, insert_statement, tool_statement, params
}
// GetMSSQLTmplToolStatement returns statements and param for template parameter test cases for mysql-sql kind
func GetMSSQLTmplToolStatement() (string, string) {
// GetMssqlTmplToolStatement returns statements and param for template parameter test cases for mysql-sql kind
func GetMssqlTmplToolStatement() (string, string) {
tmplSelectCombined := "SELECT * FROM {{.tableName}} WHERE id = @id"
tmplSelectFilterCombined := "SELECT * FROM {{.tableName}} WHERE {{.columnFilter}} = @name"
return tmplSelectCombined, tmplSelectFilterCombined
}
// GetMySQLParamToolInfo returns statements and param for my-param-tool mysql-sql kind
func GetMySQLParamToolInfo(tableName string) (string, string, string, []any) {
createStatement := fmt.Sprintf("CREATE TABLE %s (id INT NOT NULL AUTO_INCREMENT PRIMARY KEY, name VARCHAR(255));", tableName)
insertStatement := fmt.Sprintf("INSERT INTO %s (name) VALUES (?), (?), (?);", tableName)
toolStatement := fmt.Sprintf("SELECT * FROM %s WHERE id = ? OR name = ?;", tableName)
// GetMysqlParamToolInfo returns statements and param for my-param-tool mysql-sql kind
func GetMysqlParamToolInfo(tableName string) (string, string, string, []any) {
create_statement := fmt.Sprintf("CREATE TABLE %s (id INT NOT NULL AUTO_INCREMENT PRIMARY KEY, name VARCHAR(255));", tableName)
insert_statement := fmt.Sprintf("INSERT INTO %s (name) VALUES (?), (?), (?);", tableName)
tool_statement := fmt.Sprintf("SELECT * FROM %s WHERE id = ? OR name = ?;", tableName)
params := []any{"Alice", "Jane", "Sid"}
return createStatement, insertStatement, toolStatement, params
return create_statement, insert_statement, tool_statement, params
}
// GetMySQLAuthToolInfo returns statements and param of my-auth-tool for mysql-sql kind
func GetMySQLAuthToolInfo(tableName string) (string, string, string, []any) {
createStatement := fmt.Sprintf("CREATE TABLE %s (id INT NOT NULL AUTO_INCREMENT PRIMARY KEY, name VARCHAR(255), email VARCHAR(255));", tableName)
insertStatement := fmt.Sprintf("INSERT INTO %s (name, email) VALUES (?, ?), (?, ?)", tableName)
toolStatement := fmt.Sprintf("SELECT name FROM %s WHERE email = ?;", tableName)
params := []any{"Alice", ServiceAccountEmail, "Jane", "janedoe@gmail.com"}
return createStatement, insertStatement, toolStatement, params
// GetMysqlAuthToolInfo returns statements and param of my-auth-tool for mysql-sql kind
func GetMysqlAuthToolInfo(tableName string) (string, string, string, []any) {
create_statement := fmt.Sprintf("CREATE TABLE %s (id INT NOT NULL AUTO_INCREMENT PRIMARY KEY, name VARCHAR(255), email VARCHAR(255));", tableName)
insert_statement := fmt.Sprintf("INSERT INTO %s (name, email) VALUES (?, ?), (?, ?)", tableName)
tool_statement := fmt.Sprintf("SELECT name FROM %s WHERE email = ?;", tableName)
params := []any{"Alice", SERVICE_ACCOUNT_EMAIL, "Jane", "janedoe@gmail.com"}
return create_statement, insert_statement, tool_statement, params
}
// GetMySQLTmplToolStatement returns statements and param for template parameter test cases for mysql-sql kind
func GetMySQLTmplToolStatement() (string, string) {
// GetMysqlTmplToolStatement returns statements and param for template parameter test cases for mysql-sql kind
func GetMysqlTmplToolStatement() (string, string) {
tmplSelectCombined := "SELECT * FROM {{.tableName}} WHERE id = ?"
tmplSelectFilterCombined := "SELECT * FROM {{.tableName}} WHERE {{.columnFilter}} = ?"
return tmplSelectCombined, tmplSelectFilterCombined
@@ -349,16 +349,16 @@ func GetPostgresWants() (string, string, string) {
return select1Want, failInvocationWant, createTableStatement
}
// GetMSSQLWants return the expected wants for mssql
func GetMSSQLWants() (string, string, string) {
// GetMssqlWants return the expected wants for mssql
func GetMssqlWants() (string, string, string) {
select1Want := "[{\"\":1}]"
failInvocationWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: mssql: Could not find stored procedure 'SELEC'."}],"isError":true}}`
createTableStatement := `"CREATE TABLE t (id INT IDENTITY(1,1) PRIMARY KEY, name NVARCHAR(MAX))"`
return select1Want, failInvocationWant, createTableStatement
}
// GetMySQLWants return the expected wants for mysql
func GetMySQLWants() (string, string, string) {
// GetMysqlWants return the expected wants for mysql
func GetMysqlWants() (string, string, string) {
select1Want := "[{\"1\":1}]"
failInvocationWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: Error 1064 (42000): You have an error in your SQL syntax; check the manual that corresponds to your MySQL server version for the right syntax to use near 'SELEC 1' at line 1"}],"isError":true}}`
createTableStatement := `"CREATE TABLE t (id SERIAL PRIMARY KEY, name TEXT)"`
@@ -367,20 +367,20 @@ func GetMySQLWants() (string, string, string) {
// SetupPostgresSQLTable creates and inserts data into a table of tool
// compatible with postgres-sql tool
func SetupPostgresSQLTable(t *testing.T, ctx context.Context, pool *pgxpool.Pool, createStatement, insertStatement, tableName string, params []any) func(*testing.T) {
func SetupPostgresSQLTable(t *testing.T, ctx context.Context, pool *pgxpool.Pool, create_statement, insert_statement, tableName string, params []any) func(*testing.T) {
err := pool.Ping(ctx)
if err != nil {
t.Fatalf("unable to connect to test database: %s", err)
}
// Create table
_, err = pool.Query(ctx, createStatement)
_, err = pool.Query(ctx, create_statement)
if err != nil {
t.Fatalf("unable to create test table %s: %s", tableName, err)
}
// Insert test data
_, err = pool.Query(ctx, insertStatement, params...)
_, err = pool.Query(ctx, insert_statement, params...)
if err != nil {
t.Fatalf("unable to insert test data: %s", err)
}
@@ -396,20 +396,20 @@ func SetupPostgresSQLTable(t *testing.T, ctx context.Context, pool *pgxpool.Pool
// SetupMsSQLTable creates and inserts data into a table of tool
// compatible with mssql-sql tool
func SetupMsSQLTable(t *testing.T, ctx context.Context, pool *sql.DB, createStatement, insertStatement, tableName string, params []any) func(*testing.T) {
func SetupMsSQLTable(t *testing.T, ctx context.Context, pool *sql.DB, create_statement, insert_statement, tableName string, params []any) func(*testing.T) {
err := pool.PingContext(ctx)
if err != nil {
t.Fatalf("unable to connect to test database: %s", err)
}
// Create table
_, err = pool.QueryContext(ctx, createStatement)
_, err = pool.QueryContext(ctx, create_statement)
if err != nil {
t.Fatalf("unable to create test table %s: %s", tableName, err)
}
// Insert test data
_, err = pool.QueryContext(ctx, insertStatement, params...)
_, err = pool.QueryContext(ctx, insert_statement, params...)
if err != nil {
t.Fatalf("unable to insert test data: %s", err)
}
@@ -425,20 +425,20 @@ func SetupMsSQLTable(t *testing.T, ctx context.Context, pool *sql.DB, createStat
// SetupMySQLTable creates and inserts data into a table of tool
// compatible with mysql-sql tool
func SetupMySQLTable(t *testing.T, ctx context.Context, pool *sql.DB, createStatement, insertStatement, tableName string, params []any) func(*testing.T) {
func SetupMySQLTable(t *testing.T, ctx context.Context, pool *sql.DB, create_statement, insert_statement, tableName string, params []any) func(*testing.T) {
err := pool.PingContext(ctx)
if err != nil {
t.Fatalf("unable to connect to test database: %s", err)
}
// Create table
_, err = pool.QueryContext(ctx, createStatement)
_, err = pool.QueryContext(ctx, create_statement)
if err != nil {
t.Fatalf("unable to create test table %s: %s", tableName, err)
}
// Insert test data
_, err = pool.QueryContext(ctx, insertStatement, params...)
_, err = pool.QueryContext(ctx, insert_statement, params...)
if err != nil {
t.Fatalf("unable to insert test data: %s", err)
}

View File

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package couchbase
package tests
import (
"context"
@@ -34,11 +34,12 @@ const (
)
var (
couchbaseConnection = os.Getenv("COUCHBASE_CONNECTION")
couchbaseBucket = os.Getenv("COUCHBASE_BUCKET")
couchbaseScope = os.Getenv("COUCHBASE_SCOPE")
couchbaseUser = os.Getenv("COUCHBASE_USER")
couchbasePass = os.Getenv("COUCHBASE_PASS")
couchbaseConnection = os.Getenv("COUCHBASE_CONNECTION")
couchbaseBucket = os.Getenv("COUCHBASE_BUCKET")
couchbaseScope = os.Getenv("COUCHBASE_SCOPE")
couchbaseUser = os.Getenv("COUCHBASE_USER")
couchbasePass = os.Getenv("COUCHBASE_PASS")
SERVICE_ACCOUNT_EMAIL = os.Getenv("SERVICE_ACCOUNT_EMAIL")
)
// getCouchbaseVars validates and returns Couchbase configuration variables
@@ -249,7 +250,7 @@ func getCouchbaseAuthToolInfo(collectionName string) (string, []map[string]any)
toolStatement := fmt.Sprintf("SELECT name FROM %s WHERE email = $email", collectionName)
params := []map[string]any{
{"name": "Alice", "email": tests.ServiceAccountEmail},
{"name": "Alice", "email": SERVICE_ACCOUNT_EMAIL},
{"name": "Jane", "email": "janedoe@gmail.com"},
}
return toolStatement, params

View File

@@ -30,19 +30,20 @@ import (
)
var (
DgraphSourceKind = "dgraph"
DgraphApiKey = "api-key"
DgraphUrl = os.Getenv("DGRAPH_URL")
DGRAPH_SOURCE_KIND = "dgraph"
DGRAPH_TOOL_KIND = "dgraph-dql"
DGRAPH_API_KEY = "api-key"
DGRAPH_URL = os.Getenv("DGRAPH_URL")
)
func getDgraphVars(t *testing.T) map[string]any {
if DgraphUrl == "" {
if DGRAPH_URL == "" {
t.Fatal("'DGRAPH_URL' not set")
}
return map[string]any{
"kind": DgraphSourceKind,
"dgraphUrl": DgraphUrl,
"apiKey": DgraphApiKey,
"kind": DGRAPH_SOURCE_KIND,
"dgraphUrl": DGRAPH_URL,
"apiKey": DGRAPH_API_KEY,
}
}

View File

@@ -33,8 +33,8 @@ import (
)
var (
HttpSourceKind = "http"
HttpToolKind = "http"
HTTP_SOURCE_KIND = "http"
HTTP_TOOL_KIND = "http"
)
func getHTTPSourceConfig(t *testing.T) map[string]any {
@@ -44,7 +44,7 @@ func getHTTPSourceConfig(t *testing.T) map[string]any {
}
idToken = "Bearer " + idToken
return map[string]any{
"kind": HttpSourceKind,
"kind": HTTP_SOURCE_KIND,
"headers": map[string]string{"Authorization": idToken},
}
}
@@ -245,7 +245,7 @@ func TestHttpToolEndpoints(t *testing.T) {
var args []string
toolsFile := getHTTPToolsConfig(sourceConfig, HttpToolKind)
toolsFile := getHTTPToolsConfig(sourceConfig, HTTP_TOOL_KIND)
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
if err != nil {
t.Fatalf("command initialization returned an error: %s", err)

View File

@@ -30,41 +30,41 @@ import (
)
var (
MSSQLSourceKind = "mssql"
MSSQLToolKind = "mssql-sql"
MSSQLDatabase = os.Getenv("MSSQL_DATABASE")
MSSQLHost = os.Getenv("MSSQL_HOST")
MSSQLPort = os.Getenv("MSSQL_PORT")
MSSQLUser = os.Getenv("MSSQL_USER")
MSSQLPass = os.Getenv("MSSQL_PASS")
MSSQL_SOURCE_KIND = "mssql"
MSSQL_TOOL_KIND = "mssql-sql"
MSSQL_DATABASE = os.Getenv("MSSQL_DATABASE")
MSSQL_HOST = os.Getenv("MSSQL_HOST")
MSSQL_PORT = os.Getenv("MSSQL_PORT")
MSSQL_USER = os.Getenv("MSSQL_USER")
MSSQL_PASS = os.Getenv("MSSQL_PASS")
)
func getMsSQLVars(t *testing.T) map[string]any {
switch "" {
case MSSQLDatabase:
case MSSQL_DATABASE:
t.Fatal("'MSSQL_DATABASE' not set")
case MSSQLHost:
case MSSQL_HOST:
t.Fatal("'MSSQL_HOST' not set")
case MSSQLPort:
case MSSQL_PORT:
t.Fatal("'MSSQL_PORT' not set")
case MSSQLUser:
case MSSQL_USER:
t.Fatal("'MSSQL_USER' not set")
case MSSQLPass:
case MSSQL_PASS:
t.Fatal("'MSSQL_PASS' not set")
}
return map[string]any{
"kind": MSSQLSourceKind,
"host": MSSQLHost,
"port": MSSQLPort,
"database": MSSQLDatabase,
"user": MSSQLUser,
"password": MSSQLPass,
"kind": MSSQL_SOURCE_KIND,
"host": MSSQL_HOST,
"port": MSSQL_PORT,
"database": MSSQL_DATABASE,
"user": MSSQL_USER,
"password": MSSQL_PASS,
}
}
// Copied over from mssql.go
func initMSSQLConnection(host, port, user, pass, dbname string) (*sql.DB, error) {
func initMssqlConnection(host, port, user, pass, dbname string) (*sql.DB, error) {
// Create dsn
query := url.Values{}
query.Add("database", dbname)
@@ -83,14 +83,14 @@ func initMSSQLConnection(host, port, user, pass, dbname string) (*sql.DB, error)
return db, nil
}
func TestMSSQLToolEndpoints(t *testing.T) {
func TestMssqlToolEndpoints(t *testing.T) {
sourceConfig := getMsSQLVars(t)
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
var args []string
pool, err := initMSSQLConnection(MSSQLHost, MSSQLPort, MSSQLUser, MSSQLPass, MSSQLDatabase)
pool, err := initMssqlConnection(MSSQL_HOST, MSSQL_PORT, MSSQL_USER, MSSQL_PASS, MSSQL_DATABASE)
if err != nil {
t.Fatalf("unable to create SQL Server connection pool: %s", err)
}
@@ -101,20 +101,20 @@ func TestMSSQLToolEndpoints(t *testing.T) {
tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
// set up data for param tool
createStatement1, insertStatement1, toolStatement1, params1 := tests.GetMSSQLParamToolInfo(tableNameParam)
teardownTable1 := tests.SetupMsSQLTable(t, ctx, pool, createStatement1, insertStatement1, tableNameParam, params1)
create_statement1, insert_statement1, tool_statement1, params1 := tests.GetMssqlParamToolInfo(tableNameParam)
teardownTable1 := tests.SetupMsSQLTable(t, ctx, pool, create_statement1, insert_statement1, tableNameParam, params1)
defer teardownTable1(t)
// set up data for auth tool
createStatement2, insertStatement2, toolStatement2, params2 := tests.GetMSSQLAuthToolInfo(tableNameAuth)
teardownTable2 := tests.SetupMsSQLTable(t, ctx, pool, createStatement2, insertStatement2, tableNameAuth, params2)
create_statement2, insert_statement2, tool_statement2, params2 := tests.GetMssqlAuthToolInfo(tableNameAuth)
teardownTable2 := tests.SetupMsSQLTable(t, ctx, pool, create_statement2, insert_statement2, tableNameAuth, params2)
defer teardownTable2(t)
// Write config into a file and pass it to command
toolsFile := tests.GetToolsConfig(sourceConfig, MSSQLToolKind, toolStatement1, toolStatement2)
toolsFile = tests.AddMSSQLExecuteSqlConfig(t, toolsFile)
tmplSelectCombined, tmplSelectFilterCombined := tests.GetMSSQLTmplToolStatement()
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, MSSQLToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
toolsFile := tests.GetToolsConfig(sourceConfig, MSSQL_TOOL_KIND, tool_statement1, tool_statement2)
toolsFile = tests.AddMssqlExecuteSqlConfig(t, toolsFile)
tmplSelectCombined, tmplSelectFilterCombined := tests.GetMssqlTmplToolStatement()
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, MSSQL_TOOL_KIND, tmplSelectCombined, tmplSelectFilterCombined, "")
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
if err != nil {
@@ -132,7 +132,7 @@ func TestMSSQLToolEndpoints(t *testing.T) {
tests.RunToolGetTest(t)
select1Want, failInvocationWant, createTableStatement := tests.GetMSSQLWants()
select1Want, failInvocationWant, createTableStatement := tests.GetMssqlWants()
invokeParamWant, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
tests.RunToolInvokeTest(t, select1Want, invokeParamWant)
tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want)

View File

@@ -29,36 +29,36 @@ import (
)
var (
MySQLSourceKind = "mysql"
MySQLToolKind = "mysql-sql"
MySQLDatabase = os.Getenv("MYSQL_DATABASE")
MySQLHost = os.Getenv("MYSQL_HOST")
MySQLPort = os.Getenv("MYSQL_PORT")
MySQLUser = os.Getenv("MYSQL_USER")
MySQLPass = os.Getenv("MYSQL_PASS")
MYSQL_SOURCE_KIND = "mysql"
MYSQL_TOOL_KIND = "mysql-sql"
MYSQL_DATABASE = os.Getenv("MYSQL_DATABASE")
MYSQL_HOST = os.Getenv("MYSQL_HOST")
MYSQL_PORT = os.Getenv("MYSQL_PORT")
MYSQL_USER = os.Getenv("MYSQL_USER")
MYSQL_PASS = os.Getenv("MYSQL_PASS")
)
func getMySQLVars(t *testing.T) map[string]any {
switch "" {
case MySQLDatabase:
case MYSQL_DATABASE:
t.Fatal("'MYSQL_DATABASE' not set")
case MySQLHost:
case MYSQL_HOST:
t.Fatal("'MYSQL_HOST' not set")
case MySQLPort:
case MYSQL_PORT:
t.Fatal("'MYSQL_PORT' not set")
case MySQLUser:
case MYSQL_USER:
t.Fatal("'MYSQL_USER' not set")
case MySQLPass:
case MYSQL_PASS:
t.Fatal("'MYSQL_PASS' not set")
}
return map[string]any{
"kind": MySQLSourceKind,
"host": MySQLHost,
"port": MySQLPort,
"database": MySQLDatabase,
"user": MySQLUser,
"password": MySQLPass,
"kind": MYSQL_SOURCE_KIND,
"host": MYSQL_HOST,
"port": MYSQL_PORT,
"database": MYSQL_DATABASE,
"user": MYSQL_USER,
"password": MYSQL_PASS,
}
}
@@ -74,14 +74,14 @@ func initMySQLConnectionPool(host, port, user, pass, dbname string) (*sql.DB, er
return pool, nil
}
func TestMySQLToolEndpoints(t *testing.T) {
func TestMysqlToolEndpoints(t *testing.T) {
sourceConfig := getMySQLVars(t)
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
var args []string
pool, err := initMySQLConnectionPool(MySQLHost, MySQLPort, MySQLUser, MySQLPass, MySQLDatabase)
pool, err := initMySQLConnectionPool(MYSQL_HOST, MYSQL_PORT, MYSQL_USER, MYSQL_PASS, MYSQL_DATABASE)
if err != nil {
t.Fatalf("unable to create MySQL connection pool: %s", err)
}
@@ -92,20 +92,20 @@ func TestMySQLToolEndpoints(t *testing.T) {
tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
// set up data for param tool
createStatement1, insertStatement1, toolStatement1, params1 := tests.GetMySQLParamToolInfo(tableNameParam)
teardownTable1 := tests.SetupMySQLTable(t, ctx, pool, createStatement1, insertStatement1, tableNameParam, params1)
create_statement1, insert_statement1, tool_statement1, params1 := tests.GetMysqlParamToolInfo(tableNameParam)
teardownTable1 := tests.SetupMySQLTable(t, ctx, pool, create_statement1, insert_statement1, tableNameParam, params1)
defer teardownTable1(t)
// set up data for auth tool
createStatement2, insertStatement2, toolStatement2, params2 := tests.GetMySQLAuthToolInfo(tableNameAuth)
teardownTable2 := tests.SetupMySQLTable(t, ctx, pool, createStatement2, insertStatement2, tableNameAuth, params2)
create_statement2, insert_statement2, tool_statement2, params2 := tests.GetMysqlAuthToolInfo(tableNameAuth)
teardownTable2 := tests.SetupMySQLTable(t, ctx, pool, create_statement2, insert_statement2, tableNameAuth, params2)
defer teardownTable2(t)
// Write config into a file and pass it to command
toolsFile := tests.GetToolsConfig(sourceConfig, MySQLToolKind, toolStatement1, toolStatement2)
toolsFile := tests.GetToolsConfig(sourceConfig, MYSQL_TOOL_KIND, tool_statement1, tool_statement2)
toolsFile = tests.AddMySqlExecuteSqlConfig(t, toolsFile)
tmplSelectCombined, tmplSelectFilterCombined := tests.GetMySQLTmplToolStatement()
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, MySQLToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
tmplSelectCombined, tmplSelectFilterCombined := tests.GetMysqlTmplToolStatement()
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, MYSQL_TOOL_KIND, tmplSelectCombined, tmplSelectFilterCombined, "")
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
if err != nil {
@@ -123,7 +123,7 @@ func TestMySQLToolEndpoints(t *testing.T) {
tests.RunToolGetTest(t)
select1Want, failInvocationWant, createTableStatement := tests.GetMySQLWants()
select1Want, failInvocationWant, createTableStatement := tests.GetMysqlWants()
invokeParamWant, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
tests.RunToolInvokeTest(t, select1Want, invokeParamWant)
tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want)

View File

@@ -30,31 +30,32 @@ import (
)
var (
Neo4jSourceKind = "neo4j"
Neo4jDatabase = os.Getenv("NEO4J_DATABASE")
Neo4jUri = os.Getenv("NEO4J_URI")
Neo4jUser = os.Getenv("NEO4J_USER")
Neo4jPass = os.Getenv("NEO4J_PASS")
NEO4J_SOURCE_KIND = "neo4j"
NEO4J_TOOL_KIND = "neo4j-cypher"
NEO4J_DATABASE = os.Getenv("NEO4J_DATABASE")
NEO4J_URI = os.Getenv("NEO4J_URI")
NEO4J_USER = os.Getenv("NEO4J_USER")
NEO4J_PASS = os.Getenv("NEO4J_PASS")
)
func getNeo4jVars(t *testing.T) map[string]any {
switch "" {
case Neo4jDatabase:
case NEO4J_DATABASE:
t.Fatal("'NEO4J_DATABASE' not set")
case Neo4jUri:
case NEO4J_URI:
t.Fatal("'NEO4J_URI' not set")
case Neo4jUser:
case NEO4J_USER:
t.Fatal("'NEO4J_USER' not set")
case Neo4jPass:
case NEO4J_PASS:
t.Fatal("'NEO4J_PASS' not set")
}
return map[string]any{
"kind": Neo4jSourceKind,
"uri": Neo4jUri,
"database": Neo4jDatabase,
"user": Neo4jUser,
"password": Neo4jPass,
"kind": NEO4J_SOURCE_KIND,
"uri": NEO4J_URI,
"database": NEO4J_DATABASE,
"user": NEO4J_USER,
"password": NEO4J_PASS,
}
}

View File

@@ -30,36 +30,36 @@ import (
)
var (
PostgresSourceKind = "postgres"
PostgresToolKind = "postgres-sql"
PostgresDatabase = os.Getenv("POSTGRES_DATABASE")
PostgresHost = os.Getenv("POSTGRES_HOST")
PostgresPort = os.Getenv("POSTGRES_PORT")
PostgresUser = os.Getenv("POSTGRES_USER")
PostgresPass = os.Getenv("POSTGRES_PASS")
POSTGRES_SOURCE_KIND = "postgres"
POSTGRES_TOOL_KIND = "postgres-sql"
POSTGRES_DATABASE = os.Getenv("POSTGRES_DATABASE")
POSTGRES_HOST = os.Getenv("POSTGRES_HOST")
POSTGRES_PORT = os.Getenv("POSTGRES_PORT")
POSTGRES_USER = os.Getenv("POSTGRES_USER")
POSTGRES_PASS = os.Getenv("POSTGRES_PASS")
)
func getPostgresVars(t *testing.T) map[string]any {
switch "" {
case PostgresDatabase:
case POSTGRES_DATABASE:
t.Fatal("'POSTGRES_DATABASE' not set")
case PostgresHost:
case POSTGRES_HOST:
t.Fatal("'POSTGRES_HOST' not set")
case PostgresPort:
case POSTGRES_PORT:
t.Fatal("'POSTGRES_PORT' not set")
case PostgresUser:
case POSTGRES_USER:
t.Fatal("'POSTGRES_USER' not set")
case PostgresPass:
case POSTGRES_PASS:
t.Fatal("'POSTGRES_PASS' not set")
}
return map[string]any{
"kind": PostgresSourceKind,
"host": PostgresHost,
"port": PostgresPort,
"database": PostgresDatabase,
"user": PostgresUser,
"password": PostgresPass,
"kind": POSTGRES_SOURCE_KIND,
"host": POSTGRES_HOST,
"port": POSTGRES_PORT,
"database": POSTGRES_DATABASE,
"user": POSTGRES_USER,
"password": POSTGRES_PASS,
}
}
@@ -87,7 +87,7 @@ func TestPostgres(t *testing.T) {
var args []string
pool, err := initPostgresConnectionPool(PostgresHost, PostgresPort, PostgresUser, PostgresPass, PostgresDatabase)
pool, err := initPostgresConnectionPool(POSTGRES_HOST, POSTGRES_PORT, POSTGRES_USER, POSTGRES_PASS, POSTGRES_DATABASE)
if err != nil {
t.Fatalf("unable to create postgres connection pool: %s", err)
}
@@ -98,20 +98,20 @@ func TestPostgres(t *testing.T) {
tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
// set up data for param tool
createStatement1, insertStatement1, toolStatement1, params1 := tests.GetPostgresSQLParamToolInfo(tableNameParam)
teardownTable1 := tests.SetupPostgresSQLTable(t, ctx, pool, createStatement1, insertStatement1, tableNameParam, params1)
create_statement1, insert_statement1, tool_statement1, params1 := tests.GetPostgresSQLParamToolInfo(tableNameParam)
teardownTable1 := tests.SetupPostgresSQLTable(t, ctx, pool, create_statement1, insert_statement1, tableNameParam, params1)
defer teardownTable1(t)
// set up data for auth tool
createStatement2, insertStatement2, toolStatement2, params2 := tests.GetPostgresSQLAuthToolInfo(tableNameAuth)
teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createStatement2, insertStatement2, tableNameAuth, params2)
create_statement2, insert_statement2, tool_statement2, params2 := tests.GetPostgresSQLAuthToolInfo(tableNameAuth)
teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, create_statement2, insert_statement2, tableNameAuth, params2)
defer teardownTable2(t)
// Write config into a file and pass it to command
toolsFile := tests.GetToolsConfig(sourceConfig, PostgresToolKind, toolStatement1, toolStatement2)
toolsFile := tests.GetToolsConfig(sourceConfig, POSTGRES_TOOL_KIND, tool_statement1, tool_statement2)
toolsFile = tests.AddPgExecuteSqlConfig(t, toolsFile)
tmplSelectCombined, tmplSelectFilterCombined := tests.GetPostgresSQLTmplToolStatement()
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, PostgresToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, POSTGRES_TOOL_KIND, tmplSelectCombined, tmplSelectFilterCombined, "")
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
if err != nil {

View File

@@ -27,23 +27,23 @@ import (
)
var (
RedisSourceKind = "redis"
RedisToolKind = "redis"
RedisAddress = os.Getenv("REDIS_ADDRESS")
RedisPass = os.Getenv("REDIS_PASS")
REDIS_SOURCE_KIND = "redis"
REDIS_TOOL_KIND = "redis"
REDIS_ADDRESS = os.Getenv("REDIS_ADDRESS")
REDIS_PASS = os.Getenv("REDIS_PASS")
)
func getRedisVars(t *testing.T) map[string]any {
switch "" {
case RedisAddress:
case REDIS_ADDRESS:
t.Fatal("'REDIS_ADDRESS' not set")
case RedisPass:
case REDIS_PASS:
t.Fatal("'REDIS_PASS' not set")
}
return map[string]any{
"kind": RedisSourceKind,
"address": []string{RedisAddress},
"password": RedisPass,
"kind": REDIS_SOURCE_KIND,
"address": []string{REDIS_ADDRESS},
"password": REDIS_PASS,
}
}
@@ -70,7 +70,7 @@ func TestRedisToolEndpoints(t *testing.T) {
var args []string
client, err := initRedisClient(ctx, RedisAddress, RedisPass)
client, err := initRedisClient(ctx, REDIS_ADDRESS, REDIS_PASS)
if err != nil {
t.Fatalf("unable to create Redis connection: %s", err)
}
@@ -80,7 +80,7 @@ func TestRedisToolEndpoints(t *testing.T) {
defer teardownDB(t)
// Write config into a file and pass it to command
toolsFile := tests.GetRedisValkeyToolsConfig(sourceConfig, RedisToolKind)
toolsFile := tests.GetRedisValkeyToolsConfig(sourceConfig, REDIS_TOOL_KIND)
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
if err != nil {
@@ -109,7 +109,7 @@ func setupRedisDB(t *testing.T, ctx context.Context, client *redis.Client) func(
{"HSET", keys[0], "id", 1, "name", "Alice"},
{"HSET", keys[1], "id", 2, "name", "Jane"},
{"HSET", keys[2], "id", 3, "name", "Sid"},
{"HSET", tests.ServiceAccountEmail, "name", "Alice"},
{"HSET", tests.SERVICE_ACCOUNT_EMAIL, "name", "Alice"},
}
for _, c := range commands {
resp := client.Do(ctx, c...)

View File

@@ -36,28 +36,28 @@ import (
)
var (
SpannerSourceKind = "spanner"
SpannerToolKind = "spanner-sql"
SpannerProject = os.Getenv("SPANNER_PROJECT")
SpannerDatabase = os.Getenv("SPANNER_DATABASE")
SpannerInstance = os.Getenv("SPANNER_INSTANCE")
SPANNER_SOURCE_KIND = "spanner"
SPANNER_TOOL_KIND = "spanner-sql"
SPANNER_PROJECT = os.Getenv("SPANNER_PROJECT")
SPANNER_DATABASE = os.Getenv("SPANNER_DATABASE")
SPANNER_INSTANCE = os.Getenv("SPANNER_INSTANCE")
)
func getSpannerVars(t *testing.T) map[string]any {
switch "" {
case SpannerProject:
case SPANNER_PROJECT:
t.Fatal("'SPANNER_PROJECT' not set")
case SpannerDatabase:
case SPANNER_DATABASE:
t.Fatal("'SPANNER_DATABASE' not set")
case SpannerInstance:
case SPANNER_INSTANCE:
t.Fatal("'SPANNER_INSTANCE' not set")
}
return map[string]any{
"kind": SpannerSourceKind,
"project": SpannerProject,
"instance": SpannerInstance,
"database": SpannerDatabase,
"kind": SPANNER_SOURCE_KIND,
"project": SPANNER_PROJECT,
"instance": SPANNER_INSTANCE,
"database": SPANNER_DATABASE,
}
}
@@ -96,7 +96,7 @@ func TestSpannerToolEndpoints(t *testing.T) {
var args []string
// Create Spanner client
dataClient, adminClient, err := initSpannerClients(ctx, SpannerProject, SpannerInstance, SpannerDatabase)
dataClient, adminClient, err := initSpannerClients(ctx, SPANNER_PROJECT, SPANNER_INSTANCE, SPANNER_DATABASE)
if err != nil {
t.Fatalf("unable to create Spanner client: %s", err)
}
@@ -107,19 +107,19 @@ func TestSpannerToolEndpoints(t *testing.T) {
tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
// set up data for param tool
createStatement1, insertStatement1, toolStatement1, params1 := getSpannerParamToolInfo(tableNameParam)
create_statement1, insert_statement1, tool_statement1, params1 := getSpannerParamToolInfo(tableNameParam)
dbString := fmt.Sprintf(
"projects/%s/instances/%s/databases/%s",
SpannerProject,
SpannerInstance,
SpannerDatabase,
SPANNER_PROJECT,
SPANNER_INSTANCE,
SPANNER_DATABASE,
)
teardownTable1 := setupSpannerTable(t, ctx, adminClient, dataClient, createStatement1, insertStatement1, tableNameParam, dbString, params1)
teardownTable1 := setupSpannerTable(t, ctx, adminClient, dataClient, create_statement1, insert_statement1, tableNameParam, dbString, params1)
defer teardownTable1(t)
// set up data for auth tool
createStatement2, insertStatement2, toolStatement2, params2 := getSpannerAuthToolInfo(tableNameAuth)
teardownTable2 := setupSpannerTable(t, ctx, adminClient, dataClient, createStatement2, insertStatement2, tableNameAuth, dbString, params2)
create_statement2, insert_statement2, tool_statement2, params2 := getSpannerAuthToolInfo(tableNameAuth)
teardownTable2 := setupSpannerTable(t, ctx, adminClient, dataClient, create_statement2, insert_statement2, tableNameAuth, dbString, params2)
defer teardownTable2(t)
// set up data for template param tool
@@ -128,7 +128,7 @@ func TestSpannerToolEndpoints(t *testing.T) {
defer teardownTableTmpl(t)
// Write config into a file and pass it to command
toolsFile := tests.GetToolsConfig(sourceConfig, SpannerToolKind, toolStatement1, toolStatement2)
toolsFile := tests.GetToolsConfig(sourceConfig, SPANNER_TOOL_KIND, tool_statement1, tool_statement2)
toolsFile = addSpannerExecuteSqlConfig(t, toolsFile)
toolsFile = addSpannerReadOnlyConfig(t, toolsFile)
toolsFile = addTemplateParamConfig(t, toolsFile)
@@ -170,25 +170,25 @@ func TestSpannerToolEndpoints(t *testing.T) {
// getSpannerToolInfo returns statements and param for my-param-tool for spanner-sql kind
func getSpannerParamToolInfo(tableName string) (string, string, string, map[string]any) {
createStatement := fmt.Sprintf("CREATE TABLE %s (id INT64, name STRING(MAX)) PRIMARY KEY (id)", tableName)
insertStatement := fmt.Sprintf("INSERT INTO %s (id, name) VALUES (1, @name1), (2, @name2), (3, @name3)", tableName)
toolStatement := fmt.Sprintf("SELECT * FROM %s WHERE id = @id OR name = @name", tableName)
create_statement := fmt.Sprintf("CREATE TABLE %s (id INT64, name STRING(MAX)) PRIMARY KEY (id)", tableName)
insert_statement := fmt.Sprintf("INSERT INTO %s (id, name) VALUES (1, @name1), (2, @name2), (3, @name3)", tableName)
tool_statement := fmt.Sprintf("SELECT * FROM %s WHERE id = @id OR name = @name", tableName)
params := map[string]any{"name1": "Alice", "name2": "Jane", "name3": "Sid"}
return createStatement, insertStatement, toolStatement, params
return create_statement, insert_statement, tool_statement, params
}
// getSpannerAuthToolInfo returns statements and param of my-auth-tool for spanner-sql kind
func getSpannerAuthToolInfo(tableName string) (string, string, string, map[string]any) {
createStatement := fmt.Sprintf("CREATE TABLE %s (id INT64, name STRING(MAX), email STRING(MAX)) PRIMARY KEY (id)", tableName)
insertStatement := fmt.Sprintf("INSERT INTO %s (id, name, email) VALUES (1, @name1, @email1), (2, @name2, @email2)", tableName)
toolStatement := fmt.Sprintf("SELECT name FROM %s WHERE email = @email", tableName)
create_statement := fmt.Sprintf("CREATE TABLE %s (id INT64, name STRING(MAX), email STRING(MAX)) PRIMARY KEY (id)", tableName)
insert_statement := fmt.Sprintf("INSERT INTO %s (id, name, email) VALUES (1, @name1, @email1), (2, @name2, @email2)", tableName)
tool_statement := fmt.Sprintf("SELECT name FROM %s WHERE email = @email", tableName)
params := map[string]any{
"name1": "Alice",
"email1": tests.ServiceAccountEmail,
"email1": tests.SERVICE_ACCOUNT_EMAIL,
"name2": "Jane",
"email2": "janedoe@gmail.com",
}
return createStatement, insertStatement, toolStatement, params
return create_statement, insert_statement, tool_statement, params
}
// setupSpannerTable creates and inserts data into a table of tool
@@ -230,7 +230,7 @@ func setupSpannerTable(t *testing.T, ctx context.Context, adminClient *database.
Statements: []string{fmt.Sprintf("DROP TABLE %s", tableName)},
})
if err != nil {
t.Errorf("unable to start drop %s operation: %s", tableName, err)
t.Errorf("unable to start drop table operation: %s", err)
return
}
@@ -352,7 +352,7 @@ func addTemplateParamConfig(t *testing.T, config map[string]any) map[string]any
return config
}
func runSpannerExecuteSqlToolInvokeTest(t *testing.T, select1Want, invokeParamWant, tableNameParam, tableNameAuth string) {
func runSpannerExecuteSqlToolInvokeTest(t *testing.T, select_1_want, invokeParamWant, tableNameParam, tableNameAuth string) {
// Get ID token
idToken, err := tests.GetGoogleIdToken(tests.ClientId)
if err != nil {
@@ -373,7 +373,7 @@ func runSpannerExecuteSqlToolInvokeTest(t *testing.T, select1Want, invokeParamWa
api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool-read-only/invoke",
requestHeader: map[string]string{},
requestBody: bytes.NewBuffer([]byte(`{"sql":"SELECT 1"}`)),
want: select1Want,
want: select_1_want,
isErr: false,
},
{
@@ -417,7 +417,7 @@ func runSpannerExecuteSqlToolInvokeTest(t *testing.T, select1Want, invokeParamWa
api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke",
requestHeader: map[string]string{},
requestBody: bytes.NewBuffer([]byte(`{"sql":"SELECT 1"}`)),
want: select1Want,
want: select_1_want,
isErr: false,
},
{
@@ -455,7 +455,7 @@ func runSpannerExecuteSqlToolInvokeTest(t *testing.T, select1Want, invokeParamWa
requestHeader: map[string]string{"my-google-auth_token": idToken},
requestBody: bytes.NewBuffer([]byte(`{"sql":"SELECT 1"}`)),
isErr: false,
want: select1Want,
want: select_1_want,
},
{
name: "Invoke my-auth-exec-sql-tool with invalid auth token",

View File

@@ -29,15 +29,15 @@ import (
)
var (
SQLiteSourceKind = "sqlite"
SQLiteToolKind = "sqlite-sql"
SQLiteDatabase = os.Getenv("SQLITE_DATABASE")
SQLITE_SOURCE_KIND = "sqlite"
SQLITE_TOOL_KIND = "sqlite-sql"
SQLITE_DATABASE = os.Getenv("SQLITE_DATABASE")
)
func getSQLiteVars(t *testing.T) map[string]any {
return map[string]any{
"kind": SQLiteSourceKind,
"database": SQLiteDatabase,
"kind": SQLITE_SOURCE_KIND,
"database": SQLITE_DATABASE,
}
}
@@ -81,19 +81,19 @@ func setupSQLiteTestDB(t *testing.T, ctx context.Context, db *sql.DB, createStat
}
func getSQLiteParamToolInfo(tableName string) (string, string, string, []any) {
createStatement := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (id INTEGER PRIMARY KEY, name TEXT);", tableName)
insertStatement := fmt.Sprintf("INSERT INTO %s (name) VALUES (?), (?), (?);", tableName)
toolStatement := fmt.Sprintf("SELECT * FROM %s WHERE id = ? OR name = ?;", tableName)
create_statement := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (id INTEGER PRIMARY KEY, name TEXT);", tableName)
insert_statement := fmt.Sprintf("INSERT INTO %s (name) VALUES (?), (?), (?);", tableName)
tool_statement := fmt.Sprintf("SELECT * FROM %s WHERE id = ? OR name = ?;", tableName)
params := []any{"Alice", "Jane", "Sid"}
return createStatement, insertStatement, toolStatement, params
return create_statement, insert_statement, tool_statement, params
}
func getSQLiteAuthToolInfo(tableName string) (string, string, string, []any) {
createStatement := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (id INTEGER PRIMARY KEY, name TEXT NOT NULL, email TEXT)", tableName)
insertStatement := fmt.Sprintf("INSERT INTO %s (name, email) VALUES (?, ?), (?,?) RETURNING id, name, email;", tableName)
toolStatement := fmt.Sprintf("SELECT name FROM %s WHERE email = ?;", tableName)
params := []any{"Alice", tests.ServiceAccountEmail, "Jane", "janedoe@gmail.com"}
return createStatement, insertStatement, toolStatement, params
create_statement := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (id INTEGER PRIMARY KEY, name TEXT NOT NULL, email TEXT)", tableName)
insert_statement := fmt.Sprintf("INSERT INTO %s (name, email) VALUES (?, ?), (?,?) RETURNING id, name, email;", tableName)
tool_statement := fmt.Sprintf("SELECT name FROM %s WHERE email = ?;", tableName)
params := []any{"Alice", tests.SERVICE_ACCOUNT_EMAIL, "Jane", "janedoe@gmail.com"}
return create_statement, insert_statement, tool_statement, params
}
func getSQLiteTmplToolStatement() (string, string) {
@@ -103,7 +103,7 @@ func getSQLiteTmplToolStatement() (string, string) {
}
func TestSQLiteToolEndpoint(t *testing.T) {
db, teardownDb, sqliteDb, err := initSQLiteDb(t, SQLiteDatabase)
db, teardownDb, sqliteDb, err := initSQLiteDb(t, SQLITE_DATABASE)
if err != nil {
t.Fatal(err)
}
@@ -123,17 +123,17 @@ func TestSQLiteToolEndpoint(t *testing.T) {
tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
// set up data for param tool
createStatement1, insertStatement1, toolStatement1, params1 := getSQLiteParamToolInfo(tableNameParam)
setupSQLiteTestDB(t, ctx, db, createStatement1, insertStatement1, tableNameParam, params1)
create_statement1, insert_statement1, tool_statement1, params1 := getSQLiteParamToolInfo(tableNameParam)
setupSQLiteTestDB(t, ctx, db, create_statement1, insert_statement1, tableNameParam, params1)
// set up data for auth tool
createStatement2, insertStatement2, toolStatement2, params2 := getSQLiteAuthToolInfo(tableNameAuth)
setupSQLiteTestDB(t, ctx, db, createStatement2, insertStatement2, tableNameAuth, params2)
create_statement2, insert_statement2, tool_statement2, params2 := getSQLiteAuthToolInfo(tableNameAuth)
setupSQLiteTestDB(t, ctx, db, create_statement2, insert_statement2, tableNameAuth, params2)
// Write config into a file and pass it to command
toolsFile := tests.GetToolsConfig(sourceConfig, SQLiteToolKind, toolStatement1, toolStatement2)
toolsFile := tests.GetToolsConfig(sourceConfig, SQLITE_TOOL_KIND, tool_statement1, tool_statement2)
tmplSelectCombined, tmplSelectFilterCombined := getSQLiteTmplToolStatement()
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, SQLiteToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, SQLITE_TOOL_KIND, tmplSelectCombined, tmplSelectFilterCombined, "")
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
if err != nil {

View File

@@ -109,15 +109,15 @@ func RunToolInvokeTest(t *testing.T, select1Want, invokeParamWant string) {
isErr: false,
},
{
name: "Invoke my-param-tool without parameters",
api: "http://127.0.0.1:5000/api/tool/my-param-tool/invoke",
name: "Invoke my-tool without parameters",
api: "http://127.0.0.1:5000/api/tool/my-tool/invoke",
requestHeader: map[string]string{},
requestBody: bytes.NewBuffer([]byte(`{}`)),
isErr: true,
},
{
name: "Invoke my-param-tool with insufficient parameters",
api: "http://127.0.0.1:5000/api/tool/my-param-tool/invoke",
name: "Invoke my-tool with insufficient parameters",
api: "http://127.0.0.1:5000/api/tool/my-tool/invoke",
requestHeader: map[string]string{},
requestBody: bytes.NewBuffer([]byte(`{"id": 1}`)),
isErr: true,
@@ -428,7 +428,7 @@ func RunToolInvokeWithTemplateParameters(t *testing.T, tableName string, config
}
}
func RunExecuteSqlToolInvokeTest(t *testing.T, createTableStatement string, select1Want string) {
func RunExecuteSqlToolInvokeTest(t *testing.T, createTableStatement string, select_1_want string) {
// Get ID token
idToken, err := GetGoogleIdToken(ClientId)
if err != nil {
@@ -449,7 +449,7 @@ func RunExecuteSqlToolInvokeTest(t *testing.T, createTableStatement string, sele
api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke",
requestHeader: map[string]string{},
requestBody: bytes.NewBuffer([]byte(`{"sql":"SELECT 1"}`)),
want: select1Want,
want: select_1_want,
isErr: false,
},
{
@@ -489,7 +489,7 @@ func RunExecuteSqlToolInvokeTest(t *testing.T, createTableStatement string, sele
requestHeader: map[string]string{"my-google-auth_token": idToken},
requestBody: bytes.NewBuffer([]byte(`{"sql":"SELECT 1"}`)),
isErr: false,
want: select1Want,
want: select_1_want,
},
{
name: "Invoke my-auth-exec-sql-tool with invalid auth token",
@@ -597,7 +597,7 @@ func RunInitialize(t *testing.T, protocolVersion string) string {
}
// RunMCPToolCallMethod runs the tool/call for mcp endpoint
func RunMCPToolCallMethod(t *testing.T, invokeParamWant, failInvocationWant string) {
func RunMCPToolCallMethod(t *testing.T, invokeParamWant, fail_invocation_want string) {
sessionId := RunInitialize(t, "2024-11-05")
header := map[string]string{}
if sessionId != "" {
@@ -715,7 +715,7 @@ func RunMCPToolCallMethod(t *testing.T, invokeParamWant, failInvocationWant stri
"arguments": map[string]any{"id": 1},
},
},
want: failInvocationWant,
want: fail_invocation_want,
},
}
for _, tc := range invokeTcs {

View File

@@ -27,19 +27,19 @@ import (
)
var (
ValkeySourceKind = "valkey"
ValkeyToolKind = "valkey"
ValkeyAddress = os.Getenv("VALKEY_ADDRESS")
VALKEY_SOURCE_KIND = "valkey"
VALKEY_TOOL_KIND = "valkey"
VALKEY_ADDRESS = os.Getenv("VALKEY_ADDRESS")
)
func getValkeyVars(t *testing.T) map[string]any {
switch "" {
case ValkeyAddress:
case VALKEY_ADDRESS:
t.Fatal("'VALKEY_ADDRESS' not set")
}
return map[string]any{
"kind": ValkeySourceKind,
"address": []string{ValkeyAddress},
"kind": VALKEY_SOURCE_KIND,
"address": []string{VALKEY_ADDRESS},
"disableCache": true,
}
}
@@ -73,7 +73,7 @@ func TestValkeyToolEndpoints(t *testing.T) {
var args []string
client, err := initValkeyClient(ctx, []string{ValkeyAddress})
client, err := initValkeyClient(ctx, []string{VALKEY_ADDRESS})
if err != nil {
t.Fatalf("unable to create Valkey connection: %s", err)
}
@@ -83,7 +83,7 @@ func TestValkeyToolEndpoints(t *testing.T) {
defer teardownDB(t)
// Write config into a file and pass it to command
toolsFile := tests.GetRedisValkeyToolsConfig(sourceConfig, ValkeyToolKind)
toolsFile := tests.GetRedisValkeyToolsConfig(sourceConfig, VALKEY_TOOL_KIND)
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
if err != nil {
@@ -112,7 +112,7 @@ func setupValkeyDB(t *testing.T, ctx context.Context, client valkey.Client) func
{"HSET", keys[0], "name", "Alice", "id", "1"},
{"HSET", keys[1], "name", "Jane", "id", "2"},
{"HSET", keys[2], "name", "Sid", "id", "3"},
{"HSET", tests.ServiceAccountEmail, "name", "Alice"},
{"HSET", tests.SERVICE_ACCOUNT_EMAIL, "name", "Alice"},
}
builtCmds := make(valkey.Commands, len(commands))