mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-02-13 16:45:01 -05:00
Compare commits
31 Commits
adk-python
...
pr/dumians
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
251ef22839 | ||
|
|
8dc4bd7dd6 | ||
|
|
a00d0edcd5 | ||
|
|
d1eb1799a0 | ||
|
|
2b81d6099a | ||
|
|
53865b6e21 | ||
|
|
6843b46328 | ||
|
|
05152f732d | ||
|
|
a48101b3c5 | ||
|
|
418d6d791e | ||
|
|
452d686750 | ||
|
|
8fb74263a7 | ||
|
|
09f3bc7959 | ||
|
|
7970b8787e | ||
|
|
0b7a86ae58 | ||
|
|
0797142103 | ||
|
|
f26750e834 | ||
|
|
c46d7d6fa0 | ||
|
|
5e2034d146 | ||
|
|
e2272ccdbc | ||
|
|
97f68129f5 | ||
|
|
fea96fed03 | ||
|
|
195767bdcd | ||
|
|
c5524d32f5 | ||
|
|
e1739abd81 | ||
|
|
478a0bdb59 | ||
|
|
2d341acaa6 | ||
|
|
f032389a07 | ||
|
|
32610d71a3 | ||
|
|
32cb4db712 | ||
|
|
1fdd99a9b6 |
@@ -354,6 +354,30 @@ steps:
|
|||||||
postgressql \
|
postgressql \
|
||||||
postgresexecutesql
|
postgresexecutesql
|
||||||
|
|
||||||
|
- id: "cockroachdb"
|
||||||
|
name: golang:1
|
||||||
|
waitFor: ["compile-test-binary"]
|
||||||
|
entrypoint: /bin/bash
|
||||||
|
env:
|
||||||
|
- "GOPATH=/gopath"
|
||||||
|
- "COCKROACHDB_DATABASE=$_DATABASE_NAME"
|
||||||
|
- "COCKROACHDB_PORT=$_COCKROACHDB_PORT"
|
||||||
|
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
|
||||||
|
secretEnv: ["COCKROACHDB_USER", "COCKROACHDB_HOST","CLIENT_ID"]
|
||||||
|
volumes:
|
||||||
|
- name: "go"
|
||||||
|
path: "/gopath"
|
||||||
|
args:
|
||||||
|
- -c
|
||||||
|
- |
|
||||||
|
.ci/test_with_coverage.sh \
|
||||||
|
"CockroachDB" \
|
||||||
|
cockroachdb \
|
||||||
|
cockroachdbsql \
|
||||||
|
cockroachdbexecutesql \
|
||||||
|
cockroachdblisttables \
|
||||||
|
cockroachdblistschemas
|
||||||
|
|
||||||
- id: "spanner"
|
- id: "spanner"
|
||||||
name: golang:1
|
name: golang:1
|
||||||
waitFor: ["compile-test-binary"]
|
waitFor: ["compile-test-binary"]
|
||||||
@@ -919,7 +943,7 @@ steps:
|
|||||||
# Install the C compiler and Oracle SDK headers needed for cgo
|
# Install the C compiler and Oracle SDK headers needed for cgo
|
||||||
dnf install -y gcc oracle-instantclient-devel
|
dnf install -y gcc oracle-instantclient-devel
|
||||||
# Install Go
|
# Install Go
|
||||||
curl -L -o go.tar.gz "https://go.dev/dl/go1.25.1.linux-amd64.tar.gz"
|
curl -L -o go.tar.gz "https://go.dev/dl/go1.25.5.linux-amd64.tar.gz"
|
||||||
tar -C /usr/local -xzf go.tar.gz
|
tar -C /usr/local -xzf go.tar.gz
|
||||||
export PATH="/usr/local/go/bin:$$PATH"
|
export PATH="/usr/local/go/bin:$$PATH"
|
||||||
|
|
||||||
@@ -1129,6 +1153,11 @@ availableSecrets:
|
|||||||
env: MARIADB_HOST
|
env: MARIADB_HOST
|
||||||
- versionName: projects/$PROJECT_ID/secrets/mongodb_uri/versions/latest
|
- versionName: projects/$PROJECT_ID/secrets/mongodb_uri/versions/latest
|
||||||
env: MONGODB_URI
|
env: MONGODB_URI
|
||||||
|
- versionName: projects/$PROJECT_ID/secrets/cockroachdb_user/versions/latest
|
||||||
|
env: COCKROACHDB_USER
|
||||||
|
- versionName: projects/$PROJECT_ID/secrets/cockroachdb_host/versions/latest
|
||||||
|
env: COCKROACHDB_HOST
|
||||||
|
|
||||||
|
|
||||||
options:
|
options:
|
||||||
logging: CLOUD_LOGGING_ONLY
|
logging: CLOUD_LOGGING_ONLY
|
||||||
@@ -1189,6 +1218,9 @@ substitutions:
|
|||||||
_SINGLESTORE_PORT: "3308"
|
_SINGLESTORE_PORT: "3308"
|
||||||
_SINGLESTORE_DATABASE: "singlestore"
|
_SINGLESTORE_DATABASE: "singlestore"
|
||||||
_SINGLESTORE_USER: "root"
|
_SINGLESTORE_USER: "root"
|
||||||
|
_COCKROACHDB_HOST: 127.0.0.1
|
||||||
|
_COCKROACHDB_PORT: "26257"
|
||||||
|
_COCKROACHDB_USER: "root"
|
||||||
_MARIADB_PORT: "3307"
|
_MARIADB_PORT: "3307"
|
||||||
_MARIADB_DATABASE: test_database
|
_MARIADB_DATABASE: test_database
|
||||||
_SNOWFLAKE_DATABASE: "test"
|
_SNOWFLAKE_DATABASE: "test"
|
||||||
|
|||||||
1
.github/release-please.yml
vendored
1
.github/release-please.yml
vendored
@@ -37,6 +37,7 @@ extraFiles: [
|
|||||||
"docs/en/how-to/connect-ide/postgres_mcp.md",
|
"docs/en/how-to/connect-ide/postgres_mcp.md",
|
||||||
"docs/en/how-to/connect-ide/neo4j_mcp.md",
|
"docs/en/how-to/connect-ide/neo4j_mcp.md",
|
||||||
"docs/en/how-to/connect-ide/sqlite_mcp.md",
|
"docs/en/how-to/connect-ide/sqlite_mcp.md",
|
||||||
|
"docs/en/how-to/connect-ide/oracle_mcp.md",
|
||||||
"gemini-extension.json",
|
"gemini-extension.json",
|
||||||
{
|
{
|
||||||
"type": "json",
|
"type": "json",
|
||||||
|
|||||||
@@ -51,6 +51,10 @@ ignoreFiles = ["quickstart/shared", "quickstart/python", "quickstart/js", "quick
|
|||||||
# Add a new version block here before every release
|
# Add a new version block here before every release
|
||||||
# The order of versions in this file is mirrored into the dropdown
|
# The order of versions in this file is mirrored into the dropdown
|
||||||
|
|
||||||
|
[[params.versions]]
|
||||||
|
version = "v0.27.0"
|
||||||
|
url = "https://googleapis.github.io/genai-toolbox/v0.27.0/"
|
||||||
|
|
||||||
[[params.versions]]
|
[[params.versions]]
|
||||||
version = "v0.26.0"
|
version = "v0.26.0"
|
||||||
url = "https://googleapis.github.io/genai-toolbox/v0.26.0/"
|
url = "https://googleapis.github.io/genai-toolbox/v0.26.0/"
|
||||||
|
|||||||
26
CHANGELOG.md
26
CHANGELOG.md
@@ -1,5 +1,31 @@
|
|||||||
# Changelog
|
# Changelog
|
||||||
|
|
||||||
|
## [0.27.0](https://github.com/googleapis/genai-toolbox/compare/v0.26.0...v0.27.0) (2026-02-12)
|
||||||
|
|
||||||
|
|
||||||
|
### ⚠ BREAKING CHANGES
|
||||||
|
|
||||||
|
* Update configuration file v2 ([#2369](https://github.com/googleapis/genai-toolbox/issues/2369))([293c1d6](https://github.com/googleapis/genai-toolbox/commit/293c1d6889c39807855ba5e01d4c13ba2a4c50ce))
|
||||||
|
* Update/add detailed telemetry for mcp endpoint compliant with OTEL semantic convention ([#1987](https://github.com/googleapis/genai-toolbox/issues/1987)) ([478a0bd](https://github.com/googleapis/genai-toolbox/commit/478a0bdb59288c1213f83862f95a698b4c2c0aab))
|
||||||
|
|
||||||
|
### Features
|
||||||
|
|
||||||
|
* **cli/invoke:** Add support for direct tool invocation from CLI ([#2353](https://github.com/googleapis/genai-toolbox/issues/2353)) ([6e49ba4](https://github.com/googleapis/genai-toolbox/commit/6e49ba436ef2390c13feaf902b29f5907acffb57))
|
||||||
|
* **cli/skills:** Add support for generating agent skills from toolset ([#2392](https://github.com/googleapis/genai-toolbox/issues/2392)) ([80ef346](https://github.com/googleapis/genai-toolbox/commit/80ef34621453b77bdf6a6016c354f102a17ada04))
|
||||||
|
* **cloud-logging-admin:** Add source, tools, integration test and docs ([#2137](https://github.com/googleapis/genai-toolbox/issues/2137)) ([252fc30](https://github.com/googleapis/genai-toolbox/commit/252fc3091af10d25d8d7af7e047b5ac87a5dd041))
|
||||||
|
* **cockroachdb:** Add CockroachDB integration with cockroach-go ([#2006](https://github.com/googleapis/genai-toolbox/issues/2006)) ([1fdd99a](https://github.com/googleapis/genai-toolbox/commit/1fdd99a9b609a5e906acce414226ff44d75d5975))
|
||||||
|
* **prebuiltconfigs/alloydb-omni:** Implement Alloydb omni dataplane tools ([#2340](https://github.com/googleapis/genai-toolbox/issues/2340)) ([e995349](https://github.com/googleapis/genai-toolbox/commit/e995349ea0756c700d188b8f04e9459121219f0c))
|
||||||
|
* **server:** Add Tool call error categories ([#2387](https://github.com/googleapis/genai-toolbox/issues/2387)) ([32cb4db](https://github.com/googleapis/genai-toolbox/commit/32cb4db712d27579c1bf29e61cbd0bed02286c28))
|
||||||
|
* **tools/looker:** support `looker-validate-project` tool ([#2430](https://github.com/googleapis/genai-toolbox/issues/2430)) ([a15a128](https://github.com/googleapis/genai-toolbox/commit/a15a12873f936b0102aeb9500cc3bcd71bb38c34))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
### Bug Fixes
|
||||||
|
|
||||||
|
* **dataplex:** Capture GCP HTTP errors in MCP Toolbox ([#2347](https://github.com/googleapis/genai-toolbox/issues/2347)) ([1d7c498](https://github.com/googleapis/genai-toolbox/commit/1d7c4981164c34b4d7bc8edecfd449f57ad11e15))
|
||||||
|
* **sources/cockroachdb:** Update kind to type ([#2465](https://github.com/googleapis/genai-toolbox/issues/2465)) ([2d341ac](https://github.com/googleapis/genai-toolbox/commit/2d341acaa61c3c1fe908fceee8afbd90fb646d3a))
|
||||||
|
* Surface Dataplex API errors in MCP results ([#2347](https://github.com/googleapis/genai-toolbox/pull/2347))([1d7c498](https://github.com/googleapis/genai-toolbox/commit/1d7c4981164c34b4d7bc8edecfd449f57ad11e15))
|
||||||
|
|
||||||
## [0.26.0](https://github.com/googleapis/genai-toolbox/compare/v0.25.0...v0.26.0) (2026-01-22)
|
## [0.26.0](https://github.com/googleapis/genai-toolbox/compare/v0.25.0...v0.26.0) (2026-01-22)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
14
README.md
14
README.md
@@ -142,7 +142,7 @@ To install Toolbox as a binary:
|
|||||||
>
|
>
|
||||||
> ```sh
|
> ```sh
|
||||||
> # see releases page for other versions
|
> # see releases page for other versions
|
||||||
> export VERSION=0.26.0
|
> export VERSION=0.27.0
|
||||||
> curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/linux/amd64/toolbox
|
> curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/linux/amd64/toolbox
|
||||||
> chmod +x toolbox
|
> chmod +x toolbox
|
||||||
> ```
|
> ```
|
||||||
@@ -155,7 +155,7 @@ To install Toolbox as a binary:
|
|||||||
>
|
>
|
||||||
> ```sh
|
> ```sh
|
||||||
> # see releases page for other versions
|
> # see releases page for other versions
|
||||||
> export VERSION=0.26.0
|
> export VERSION=0.27.0
|
||||||
> curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/darwin/arm64/toolbox
|
> curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/darwin/arm64/toolbox
|
||||||
> chmod +x toolbox
|
> chmod +x toolbox
|
||||||
> ```
|
> ```
|
||||||
@@ -168,7 +168,7 @@ To install Toolbox as a binary:
|
|||||||
>
|
>
|
||||||
> ```sh
|
> ```sh
|
||||||
> # see releases page for other versions
|
> # see releases page for other versions
|
||||||
> export VERSION=0.26.0
|
> export VERSION=0.27.0
|
||||||
> curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/darwin/amd64/toolbox
|
> curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/darwin/amd64/toolbox
|
||||||
> chmod +x toolbox
|
> chmod +x toolbox
|
||||||
> ```
|
> ```
|
||||||
@@ -181,7 +181,7 @@ To install Toolbox as a binary:
|
|||||||
>
|
>
|
||||||
> ```cmd
|
> ```cmd
|
||||||
> :: see releases page for other versions
|
> :: see releases page for other versions
|
||||||
> set VERSION=0.26.0
|
> set VERSION=0.27.0
|
||||||
> curl -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v%VERSION%/windows/amd64/toolbox.exe"
|
> curl -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v%VERSION%/windows/amd64/toolbox.exe"
|
||||||
> ```
|
> ```
|
||||||
>
|
>
|
||||||
@@ -193,7 +193,7 @@ To install Toolbox as a binary:
|
|||||||
>
|
>
|
||||||
> ```powershell
|
> ```powershell
|
||||||
> # see releases page for other versions
|
> # see releases page for other versions
|
||||||
> $VERSION = "0.26.0"
|
> $VERSION = "0.27.0"
|
||||||
> curl.exe -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v$VERSION/windows/amd64/toolbox.exe"
|
> curl.exe -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v$VERSION/windows/amd64/toolbox.exe"
|
||||||
> ```
|
> ```
|
||||||
>
|
>
|
||||||
@@ -206,7 +206,7 @@ You can also install Toolbox as a container:
|
|||||||
|
|
||||||
```sh
|
```sh
|
||||||
# see releases page for other versions
|
# see releases page for other versions
|
||||||
export VERSION=0.26.0
|
export VERSION=0.27.0
|
||||||
docker pull us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:$VERSION
|
docker pull us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:$VERSION
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -230,7 +230,7 @@ To install from source, ensure you have the latest version of
|
|||||||
[Go installed](https://go.dev/doc/install), and then run the following command:
|
[Go installed](https://go.dev/doc/install), and then run the following command:
|
||||||
|
|
||||||
```sh
|
```sh
|
||||||
go install github.com/googleapis/genai-toolbox@v0.26.0
|
go install github.com/googleapis/genai-toolbox@v0.27.0
|
||||||
```
|
```
|
||||||
<!-- {x-release-please-end} -->
|
<!-- {x-release-please-end} -->
|
||||||
|
|
||||||
|
|||||||
257
cmd/internal/imports.go
Normal file
257
cmd/internal/imports.go
Normal file
@@ -0,0 +1,257 @@
|
|||||||
|
// Copyright 2026 Google LLC
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
// Import prompt packages for side effect of registration
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/prompts/custom"
|
||||||
|
|
||||||
|
// Import tool packages for side effect of registration
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydbcreatecluster"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydbcreateinstance"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydbcreateuser"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydbgetcluster"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydbgetinstance"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydbgetuser"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydblistclusters"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydblistinstances"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydblistusers"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydbwaitforoperation"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/alloydbainl"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigqueryanalyzecontribution"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigqueryconversationalanalytics"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigqueryexecutesql"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigqueryforecast"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerygetdatasetinfo"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerygettableinfo"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerylistdatasetids"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerylisttableids"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerysearchcatalog"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerysql"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/bigtable"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/cassandra/cassandracql"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouseexecutesql"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouselistdatabases"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouselisttables"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhousesql"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudgda"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcarefhirfetchpage"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcarefhirpatienteverything"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcarefhirpatientsearch"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaregetdataset"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaregetdicomstore"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaregetdicomstoremetrics"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaregetfhirresource"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaregetfhirstore"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaregetfhirstoremetrics"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcarelistdicomstores"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcarelistfhirstores"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcareretrieverendereddicominstance"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaresearchdicominstances"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaresearchdicomseries"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaresearchdicomstudies"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudloggingadmin/cloudloggingadminlistlognames"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudloggingadmin/cloudloggingadminlistresourcetypes"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudloggingadmin/cloudloggingadminquerylogs"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudmonitoring"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcloneinstance"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcreatebackup"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcreatedatabase"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcreateusers"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlgetinstances"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqllistdatabases"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqllistinstances"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlrestorebackup"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlwaitforoperation"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsqlmssql/cloudsqlmssqlcreateinstance"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsqlmysql/cloudsqlmysqlcreateinstance"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsqlpg/cloudsqlpgcreateinstances"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsqlpg/cloudsqlpgupgradeprecheck"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/cockroachdb/cockroachdbexecutesql"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/cockroachdb/cockroachdblistschemas"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/cockroachdb/cockroachdblisttables"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/cockroachdb/cockroachdbsql"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/couchbase"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/dataform/dataformcompilelocal"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/dataplex/dataplexlookupentry"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/dataplex/dataplexsearchaspecttypes"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/dataplex/dataplexsearchentries"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/dgraph"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/elasticsearch/elasticsearchesql"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/firebird/firebirdexecutesql"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/firebird/firebirdsql"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestoreadddocuments"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestoredeletedocuments"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestoregetdocuments"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestoregetrules"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestorelistcollections"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestorequery"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestorequerycollection"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestoreupdatedocument"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestorevalidaterules"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/http"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookeradddashboardelement"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookeradddashboardfilter"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerconversationalanalytics"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercreateprojectfile"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerdeleteprojectfile"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerdevmode"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergenerateembedurl"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnectiondatabases"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnections"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnectionschemas"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnectiontablecolumns"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnectiontables"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetdashboards"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetdimensions"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetexplores"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetfilters"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetlooks"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetmeasures"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetmodels"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetparameters"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetprojectfile"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetprojectfiles"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetprojects"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerhealthanalyze"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerhealthpulse"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerhealthvacuum"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookermakedashboard"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookermakelook"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerquery"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerquerysql"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerqueryurl"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerrundashboard"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerrunlook"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerupdateprojectfile"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookervalidateproject"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/mindsdb/mindsdbexecutesql"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/mindsdb/mindsdbsql"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbaggregate"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbdeletemany"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbdeleteone"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbfind"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbfindone"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbinsertmany"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbinsertone"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbupdatemany"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbupdateone"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/mssql/mssqlexecutesql"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/mssql/mssqllisttables"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/mssql/mssqlsql"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlexecutesql"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlgetqueryplan"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqllistactivequeries"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqllisttablefragmentation"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqllisttables"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqllisttablesmissinguniqueindexes"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlsql"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jcypher"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jexecutecypher"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jschema"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/oceanbase/oceanbaseexecutesql"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/oceanbase/oceanbasesql"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/oracle/oracleexecutesql"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/oracle/oraclesql"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgresdatabaseoverview"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgresexecutesql"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgresgetcolumncardinality"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistactivequeries"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistavailableextensions"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistdatabasestats"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistindexes"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistinstalledextensions"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistlocks"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistpgsettings"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistpublicationtables"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistquerystats"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistroles"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistschemas"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistsequences"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgresliststoredprocedure"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslisttables"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslisttablespaces"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslisttablestats"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslisttriggers"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistviews"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslongrunningtransactions"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgresreplicationstats"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgressql"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/redis"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkcancelbatch"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkcreatepysparkbatch"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkcreatesparkbatch"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkgetbatch"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparklistbatches"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/singlestore/singlestoreexecutesql"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/singlestore/singlestoresql"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/snowflake/snowflakeexecutesql"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/snowflake/snowflakesql"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/spanner/spannerexecutesql"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/spanner/spannerlistgraphs"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/spanner/spannerlisttables"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/spanner/spannersql"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/sqlite/sqliteexecutesql"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/sqlite/sqlitesql"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/tidb/tidbexecutesql"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/tidb/tidbsql"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/trino/trinoexecutesql"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/trino/trinosql"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/utility/wait"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/valkey"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/yugabytedbsql"
|
||||||
|
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/sources/bigtable"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/sources/cassandra"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/sources/clickhouse"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/sources/cloudgda"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/sources/cloudloggingadmin"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/sources/cloudmonitoring"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmssql"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmysql"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/sources/cockroachdb"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/sources/couchbase"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/sources/dataplex"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/sources/dgraph"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/sources/elasticsearch"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/sources/firebird"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/sources/firestore"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/sources/http"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/sources/looker"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/sources/mindsdb"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/sources/mongodb"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/sources/mssql"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/sources/mysql"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/sources/neo4j"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/sources/oceanbase"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/sources/oracle"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/sources/postgres"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/sources/redis"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/sources/serverlessspark"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/sources/singlestore"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/sources/snowflake"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/sources/spanner"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/sources/sqlite"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/sources/tidb"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/sources/trino"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/sources/valkey"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/sources/yugabytedb"
|
||||||
|
)
|
||||||
@@ -18,37 +18,15 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/log"
|
"github.com/googleapis/genai-toolbox/cmd/internal"
|
||||||
"github.com/googleapis/genai-toolbox/internal/server"
|
"github.com/googleapis/genai-toolbox/internal/server"
|
||||||
"github.com/googleapis/genai-toolbox/internal/server/resources"
|
"github.com/googleapis/genai-toolbox/internal/server/resources"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
// RootCommand defines the interface for required by invoke subcommand.
|
func NewCommand(opts *internal.ToolboxOptions) *cobra.Command {
|
||||||
// This allows subcommands to access shared resources and functionality without
|
|
||||||
// direct coupling to the root command's implementation.
|
|
||||||
type RootCommand interface {
|
|
||||||
// Config returns a copy of the current server configuration.
|
|
||||||
Config() server.ServerConfig
|
|
||||||
|
|
||||||
// Out returns the writer used for standard output.
|
|
||||||
Out() io.Writer
|
|
||||||
|
|
||||||
// LoadConfig loads and merges the configuration from files, folders, and prebuilts.
|
|
||||||
LoadConfig(ctx context.Context) error
|
|
||||||
|
|
||||||
// Setup initializes the runtime environment, including logging and telemetry.
|
|
||||||
// It returns the updated context and a shutdown function to be called when finished.
|
|
||||||
Setup(ctx context.Context) (context.Context, func(context.Context) error, error)
|
|
||||||
|
|
||||||
// Logger returns the logger instance.
|
|
||||||
Logger() log.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewCommand(rootCmd RootCommand) *cobra.Command {
|
|
||||||
cmd := &cobra.Command{
|
cmd := &cobra.Command{
|
||||||
Use: "invoke <tool-name> [params]",
|
Use: "invoke <tool-name> [params]",
|
||||||
Short: "Execute a tool directly",
|
Short: "Execute a tool directly",
|
||||||
@@ -58,17 +36,17 @@ Example:
|
|||||||
toolbox invoke my-tool '{"param1": "value1"}'`,
|
toolbox invoke my-tool '{"param1": "value1"}'`,
|
||||||
Args: cobra.MinimumNArgs(1),
|
Args: cobra.MinimumNArgs(1),
|
||||||
RunE: func(c *cobra.Command, args []string) error {
|
RunE: func(c *cobra.Command, args []string) error {
|
||||||
return runInvoke(c, args, rootCmd)
|
return runInvoke(c, args, opts)
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
return cmd
|
return cmd
|
||||||
}
|
}
|
||||||
|
|
||||||
func runInvoke(cmd *cobra.Command, args []string, rootCmd RootCommand) error {
|
func runInvoke(cmd *cobra.Command, args []string, opts *internal.ToolboxOptions) error {
|
||||||
ctx, cancel := context.WithCancel(cmd.Context())
|
ctx, cancel := context.WithCancel(cmd.Context())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
ctx, shutdown, err := rootCmd.Setup(ctx)
|
ctx, shutdown, err := opts.Setup(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -76,16 +54,16 @@ func runInvoke(cmd *cobra.Command, args []string, rootCmd RootCommand) error {
|
|||||||
_ = shutdown(ctx)
|
_ = shutdown(ctx)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Load and merge tool configurations
|
_, err = opts.LoadConfig(ctx)
|
||||||
if err := rootCmd.LoadConfig(ctx); err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize Resources
|
// Initialize Resources
|
||||||
sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := server.InitializeConfigs(ctx, rootCmd.Config())
|
sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := server.InitializeConfigs(ctx, opts.Cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errMsg := fmt.Errorf("failed to initialize resources: %w", err)
|
errMsg := fmt.Errorf("failed to initialize resources: %w", err)
|
||||||
rootCmd.Logger().ErrorContext(ctx, errMsg.Error())
|
opts.Logger.ErrorContext(ctx, errMsg.Error())
|
||||||
return errMsg
|
return errMsg
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -96,7 +74,7 @@ func runInvoke(cmd *cobra.Command, args []string, rootCmd RootCommand) error {
|
|||||||
tool, ok := resourceMgr.GetTool(toolName)
|
tool, ok := resourceMgr.GetTool(toolName)
|
||||||
if !ok {
|
if !ok {
|
||||||
errMsg := fmt.Errorf("tool %q not found", toolName)
|
errMsg := fmt.Errorf("tool %q not found", toolName)
|
||||||
rootCmd.Logger().ErrorContext(ctx, errMsg.Error())
|
opts.Logger.ErrorContext(ctx, errMsg.Error())
|
||||||
return errMsg
|
return errMsg
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -109,7 +87,7 @@ func runInvoke(cmd *cobra.Command, args []string, rootCmd RootCommand) error {
|
|||||||
if paramsInput != "" {
|
if paramsInput != "" {
|
||||||
if err := json.Unmarshal([]byte(paramsInput), ¶ms); err != nil {
|
if err := json.Unmarshal([]byte(paramsInput), ¶ms); err != nil {
|
||||||
errMsg := fmt.Errorf("params must be a valid JSON string: %w", err)
|
errMsg := fmt.Errorf("params must be a valid JSON string: %w", err)
|
||||||
rootCmd.Logger().ErrorContext(ctx, errMsg.Error())
|
opts.Logger.ErrorContext(ctx, errMsg.Error())
|
||||||
return errMsg
|
return errMsg
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -117,14 +95,14 @@ func runInvoke(cmd *cobra.Command, args []string, rootCmd RootCommand) error {
|
|||||||
parsedParams, err := parameters.ParseParams(tool.GetParameters(), params, nil)
|
parsedParams, err := parameters.ParseParams(tool.GetParameters(), params, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errMsg := fmt.Errorf("invalid parameters: %w", err)
|
errMsg := fmt.Errorf("invalid parameters: %w", err)
|
||||||
rootCmd.Logger().ErrorContext(ctx, errMsg.Error())
|
opts.Logger.ErrorContext(ctx, errMsg.Error())
|
||||||
return errMsg
|
return errMsg
|
||||||
}
|
}
|
||||||
|
|
||||||
parsedParams, err = tool.EmbedParams(ctx, parsedParams, resourceMgr.GetEmbeddingModelMap())
|
parsedParams, err = tool.EmbedParams(ctx, parsedParams, resourceMgr.GetEmbeddingModelMap())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errMsg := fmt.Errorf("error embedding parameters: %w", err)
|
errMsg := fmt.Errorf("error embedding parameters: %w", err)
|
||||||
rootCmd.Logger().ErrorContext(ctx, errMsg.Error())
|
opts.Logger.ErrorContext(ctx, errMsg.Error())
|
||||||
return errMsg
|
return errMsg
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -132,19 +110,19 @@ func runInvoke(cmd *cobra.Command, args []string, rootCmd RootCommand) error {
|
|||||||
requiresAuth, err := tool.RequiresClientAuthorization(resourceMgr)
|
requiresAuth, err := tool.RequiresClientAuthorization(resourceMgr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errMsg := fmt.Errorf("failed to check auth requirements: %w", err)
|
errMsg := fmt.Errorf("failed to check auth requirements: %w", err)
|
||||||
rootCmd.Logger().ErrorContext(ctx, errMsg.Error())
|
opts.Logger.ErrorContext(ctx, errMsg.Error())
|
||||||
return errMsg
|
return errMsg
|
||||||
}
|
}
|
||||||
if requiresAuth {
|
if requiresAuth {
|
||||||
errMsg := fmt.Errorf("client authorization is not supported")
|
errMsg := fmt.Errorf("client authorization is not supported")
|
||||||
rootCmd.Logger().ErrorContext(ctx, errMsg.Error())
|
opts.Logger.ErrorContext(ctx, errMsg.Error())
|
||||||
return errMsg
|
return errMsg
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := tool.Invoke(ctx, resourceMgr, parsedParams, "")
|
result, err := tool.Invoke(ctx, resourceMgr, parsedParams, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errMsg := fmt.Errorf("tool execution failed: %w", err)
|
errMsg := fmt.Errorf("tool execution failed: %w", err)
|
||||||
rootCmd.Logger().ErrorContext(ctx, errMsg.Error())
|
opts.Logger.ErrorContext(ctx, errMsg.Error())
|
||||||
return errMsg
|
return errMsg
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -152,10 +130,10 @@ func runInvoke(cmd *cobra.Command, args []string, rootCmd RootCommand) error {
|
|||||||
output, err := json.MarshalIndent(result, "", " ")
|
output, err := json.MarshalIndent(result, "", " ")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errMsg := fmt.Errorf("failed to marshal result: %w", err)
|
errMsg := fmt.Errorf("failed to marshal result: %w", err)
|
||||||
rootCmd.Logger().ErrorContext(ctx, errMsg.Error())
|
opts.Logger.ErrorContext(ctx, errMsg.Error())
|
||||||
return errMsg
|
return errMsg
|
||||||
}
|
}
|
||||||
fmt.Fprintln(rootCmd.Out(), string(output))
|
fmt.Fprintln(opts.IOStreams.Out, string(output))
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -12,16 +12,38 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
package cmd
|
package invoke
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"bytes"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/googleapis/genai-toolbox/cmd/internal"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/sources/sqlite"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerysql"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/sqlite/sqlitesql"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func invokeCommand(args []string) (string, error) {
|
||||||
|
parentCmd := &cobra.Command{Use: "toolbox"}
|
||||||
|
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
opts := internal.NewToolboxOptions(internal.WithIOStreams(buf, buf))
|
||||||
|
internal.PersistentFlags(parentCmd, opts)
|
||||||
|
|
||||||
|
cmd := NewCommand(opts)
|
||||||
|
parentCmd.AddCommand(cmd)
|
||||||
|
parentCmd.SetArgs(args)
|
||||||
|
|
||||||
|
err := parentCmd.Execute()
|
||||||
|
return buf.String(), err
|
||||||
|
}
|
||||||
|
|
||||||
func TestInvokeTool(t *testing.T) {
|
func TestInvokeTool(t *testing.T) {
|
||||||
// Create a temporary tools file
|
// Create a temporary tools file
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
@@ -86,7 +108,7 @@ tools:
|
|||||||
|
|
||||||
for _, tc := range tcs {
|
for _, tc := range tcs {
|
||||||
t.Run(tc.desc, func(t *testing.T) {
|
t.Run(tc.desc, func(t *testing.T) {
|
||||||
_, got, err := invokeCommandWithContext(context.Background(), tc.args)
|
got, err := invokeCommand(tc.args)
|
||||||
if (err != nil) != tc.wantErr {
|
if (err != nil) != tc.wantErr {
|
||||||
t.Fatalf("got error %v, wantErr %v", err, tc.wantErr)
|
t.Fatalf("got error %v, wantErr %v", err, tc.wantErr)
|
||||||
}
|
}
|
||||||
@@ -121,7 +143,7 @@ tools:
|
|||||||
}
|
}
|
||||||
|
|
||||||
args := []string{"invoke", "bq-tool", "--tools-file", toolsFilePath}
|
args := []string{"invoke", "bq-tool", "--tools-file", toolsFilePath}
|
||||||
_, _, err := invokeCommandWithContext(context.Background(), args)
|
_, err := invokeCommand(args)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected error for tool requiring client auth, but got nil")
|
t.Fatal("expected error for tool requiring client auth, but got nil")
|
||||||
}
|
}
|
||||||
251
cmd/internal/options.go
Normal file
251
cmd/internal/options.go
Normal file
@@ -0,0 +1,251 @@
|
|||||||
|
// Copyright 2026 Google LLC
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/log"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/prebuiltconfigs"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/server"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/telemetry"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
type IOStreams struct {
|
||||||
|
In io.Reader
|
||||||
|
Out io.Writer
|
||||||
|
ErrOut io.Writer
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToolboxOptions holds dependencies shared by all commands.
|
||||||
|
type ToolboxOptions struct {
|
||||||
|
IOStreams IOStreams
|
||||||
|
Logger log.Logger
|
||||||
|
Cfg server.ServerConfig
|
||||||
|
ToolsFile string
|
||||||
|
ToolsFiles []string
|
||||||
|
ToolsFolder string
|
||||||
|
PrebuiltConfigs []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Option defines a function that modifies the ToolboxOptions struct.
|
||||||
|
type Option func(*ToolboxOptions)
|
||||||
|
|
||||||
|
// NewToolboxOptions creates a new instance with defaults, then applies any
|
||||||
|
// provided options.
|
||||||
|
func NewToolboxOptions(opts ...Option) *ToolboxOptions {
|
||||||
|
o := &ToolboxOptions{
|
||||||
|
IOStreams: IOStreams{
|
||||||
|
In: os.Stdin,
|
||||||
|
Out: os.Stdout,
|
||||||
|
ErrOut: os.Stderr,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(o)
|
||||||
|
}
|
||||||
|
return o
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply allows you to update an EXISTING ToolboxOptions instance.
|
||||||
|
// This is useful for "late binding".
|
||||||
|
func (o *ToolboxOptions) Apply(opts ...Option) {
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(o)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithIOStreams updates the IO streams.
|
||||||
|
func WithIOStreams(out, err io.Writer) Option {
|
||||||
|
return func(o *ToolboxOptions) {
|
||||||
|
o.IOStreams.Out = out
|
||||||
|
o.IOStreams.ErrOut = err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup create logger and telemetry instrumentations.
|
||||||
|
func (opts *ToolboxOptions) Setup(ctx context.Context) (context.Context, func(context.Context) error, error) {
|
||||||
|
// If stdio, set logger's out stream (usually DEBUG and INFO logs) to
|
||||||
|
// errStream
|
||||||
|
loggerOut := opts.IOStreams.Out
|
||||||
|
if opts.Cfg.Stdio {
|
||||||
|
loggerOut = opts.IOStreams.ErrOut
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle logger separately from config
|
||||||
|
logger, err := log.NewLogger(opts.Cfg.LoggingFormat.String(), opts.Cfg.LogLevel.String(), loggerOut, opts.IOStreams.ErrOut)
|
||||||
|
if err != nil {
|
||||||
|
return ctx, nil, fmt.Errorf("unable to initialize logger: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx = util.WithLogger(ctx, logger)
|
||||||
|
opts.Logger = logger
|
||||||
|
|
||||||
|
// Set up OpenTelemetry
|
||||||
|
otelShutdown, err := telemetry.SetupOTel(ctx, opts.Cfg.Version, opts.Cfg.TelemetryOTLP, opts.Cfg.TelemetryGCP, opts.Cfg.TelemetryServiceName)
|
||||||
|
if err != nil {
|
||||||
|
errMsg := fmt.Errorf("error setting up OpenTelemetry: %w", err)
|
||||||
|
logger.ErrorContext(ctx, errMsg.Error())
|
||||||
|
return ctx, nil, errMsg
|
||||||
|
}
|
||||||
|
|
||||||
|
shutdownFunc := func(ctx context.Context) error {
|
||||||
|
err := otelShutdown(ctx)
|
||||||
|
if err != nil {
|
||||||
|
errMsg := fmt.Errorf("error shutting down OpenTelemetry: %w", err)
|
||||||
|
logger.ErrorContext(ctx, errMsg.Error())
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
instrumentation, err := telemetry.CreateTelemetryInstrumentation(opts.Cfg.Version)
|
||||||
|
if err != nil {
|
||||||
|
errMsg := fmt.Errorf("unable to create telemetry instrumentation: %w", err)
|
||||||
|
logger.ErrorContext(ctx, errMsg.Error())
|
||||||
|
return ctx, shutdownFunc, errMsg
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx = util.WithInstrumentation(ctx, instrumentation)
|
||||||
|
|
||||||
|
return ctx, shutdownFunc, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadConfig checks and merge files that should be loaded into the server
|
||||||
|
func (opts *ToolboxOptions) LoadConfig(ctx context.Context) (bool, error) {
|
||||||
|
// Determine if Custom Files should be loaded
|
||||||
|
// Check for explicit custom flags
|
||||||
|
isCustomConfigured := opts.ToolsFile != "" || len(opts.ToolsFiles) > 0 || opts.ToolsFolder != ""
|
||||||
|
|
||||||
|
// Determine if default 'tools.yaml' should be used (No prebuilt AND No custom flags)
|
||||||
|
useDefaultToolsFile := len(opts.PrebuiltConfigs) == 0 && !isCustomConfigured
|
||||||
|
|
||||||
|
if useDefaultToolsFile {
|
||||||
|
opts.ToolsFile = "tools.yaml"
|
||||||
|
isCustomConfigured = true
|
||||||
|
}
|
||||||
|
|
||||||
|
logger, err := util.LoggerFromContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return isCustomConfigured, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var allToolsFiles []ToolsFile
|
||||||
|
|
||||||
|
// Load Prebuilt Configuration
|
||||||
|
|
||||||
|
if len(opts.PrebuiltConfigs) > 0 {
|
||||||
|
slices.Sort(opts.PrebuiltConfigs)
|
||||||
|
sourcesList := strings.Join(opts.PrebuiltConfigs, ", ")
|
||||||
|
logMsg := fmt.Sprintf("Using prebuilt tool configurations for: %s", sourcesList)
|
||||||
|
logger.InfoContext(ctx, logMsg)
|
||||||
|
|
||||||
|
for _, configName := range opts.PrebuiltConfigs {
|
||||||
|
buf, err := prebuiltconfigs.Get(configName)
|
||||||
|
if err != nil {
|
||||||
|
logger.ErrorContext(ctx, err.Error())
|
||||||
|
return isCustomConfigured, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse into ToolsFile struct
|
||||||
|
parsed, err := parseToolsFile(ctx, buf)
|
||||||
|
if err != nil {
|
||||||
|
errMsg := fmt.Errorf("unable to parse prebuilt tool configuration for '%s': %w", configName, err)
|
||||||
|
logger.ErrorContext(ctx, errMsg.Error())
|
||||||
|
return isCustomConfigured, errMsg
|
||||||
|
}
|
||||||
|
allToolsFiles = append(allToolsFiles, parsed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load Custom Configurations
|
||||||
|
if isCustomConfigured {
|
||||||
|
// Enforce exclusivity among custom flags (tools-file vs tools-files vs tools-folder)
|
||||||
|
if (opts.ToolsFile != "" && len(opts.ToolsFiles) > 0) ||
|
||||||
|
(opts.ToolsFile != "" && opts.ToolsFolder != "") ||
|
||||||
|
(len(opts.ToolsFiles) > 0 && opts.ToolsFolder != "") {
|
||||||
|
errMsg := fmt.Errorf("--tools-file, --tools-files, and --tools-folder flags cannot be used simultaneously")
|
||||||
|
logger.ErrorContext(ctx, errMsg.Error())
|
||||||
|
return isCustomConfigured, errMsg
|
||||||
|
}
|
||||||
|
|
||||||
|
var customTools ToolsFile
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if len(opts.ToolsFiles) > 0 {
|
||||||
|
// Use tools-files
|
||||||
|
logger.InfoContext(ctx, fmt.Sprintf("Loading and merging %d tool configuration files", len(opts.ToolsFiles)))
|
||||||
|
customTools, err = LoadAndMergeToolsFiles(ctx, opts.ToolsFiles)
|
||||||
|
} else if opts.ToolsFolder != "" {
|
||||||
|
// Use tools-folder
|
||||||
|
logger.InfoContext(ctx, fmt.Sprintf("Loading and merging all YAML files from directory: %s", opts.ToolsFolder))
|
||||||
|
customTools, err = LoadAndMergeToolsFolder(ctx, opts.ToolsFolder)
|
||||||
|
} else {
|
||||||
|
// Use single file (tools-file or default `tools.yaml`)
|
||||||
|
buf, readFileErr := os.ReadFile(opts.ToolsFile)
|
||||||
|
if readFileErr != nil {
|
||||||
|
errMsg := fmt.Errorf("unable to read tool file at %q: %w", opts.ToolsFile, readFileErr)
|
||||||
|
logger.ErrorContext(ctx, errMsg.Error())
|
||||||
|
return isCustomConfigured, errMsg
|
||||||
|
}
|
||||||
|
customTools, err = parseToolsFile(ctx, buf)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("unable to parse tool file at %q: %w", opts.ToolsFile, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
logger.ErrorContext(ctx, err.Error())
|
||||||
|
return isCustomConfigured, err
|
||||||
|
}
|
||||||
|
allToolsFiles = append(allToolsFiles, customTools)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Modify version string based on loaded configurations
|
||||||
|
if len(opts.PrebuiltConfigs) > 0 {
|
||||||
|
tag := "prebuilt"
|
||||||
|
if isCustomConfigured {
|
||||||
|
tag = "custom"
|
||||||
|
}
|
||||||
|
// prebuiltConfigs is already sorted above
|
||||||
|
for _, configName := range opts.PrebuiltConfigs {
|
||||||
|
opts.Cfg.Version += fmt.Sprintf("+%s.%s", tag, configName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge Everything
|
||||||
|
// This will error if custom tools collide with prebuilt tools
|
||||||
|
finalToolsFile, err := mergeToolsFiles(allToolsFiles...)
|
||||||
|
if err != nil {
|
||||||
|
logger.ErrorContext(ctx, err.Error())
|
||||||
|
return isCustomConfigured, err
|
||||||
|
}
|
||||||
|
|
||||||
|
opts.Cfg.SourceConfigs = finalToolsFile.Sources
|
||||||
|
opts.Cfg.AuthServiceConfigs = finalToolsFile.AuthServices
|
||||||
|
opts.Cfg.EmbeddingModelConfigs = finalToolsFile.EmbeddingModels
|
||||||
|
opts.Cfg.ToolConfigs = finalToolsFile.Tools
|
||||||
|
opts.Cfg.ToolsetConfigs = finalToolsFile.Toolsets
|
||||||
|
opts.Cfg.PromptConfigs = finalToolsFile.Prompts
|
||||||
|
|
||||||
|
return isCustomConfigured, nil
|
||||||
|
}
|
||||||
@@ -12,57 +12,38 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
package cmd
|
package internal
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/spf13/cobra"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCommandOptions(t *testing.T) {
|
func TestToolboxOptions(t *testing.T) {
|
||||||
w := io.Discard
|
w := io.Discard
|
||||||
tcs := []struct {
|
tcs := []struct {
|
||||||
desc string
|
desc string
|
||||||
isValid func(*Command) error
|
isValid func(*ToolboxOptions) error
|
||||||
option Option
|
option Option
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
desc: "with logger",
|
desc: "with logger",
|
||||||
isValid: func(c *Command) error {
|
isValid: func(o *ToolboxOptions) error {
|
||||||
if c.outStream != w || c.errStream != w {
|
if o.IOStreams.Out != w || o.IOStreams.ErrOut != w {
|
||||||
return errors.New("loggers do not match")
|
return errors.New("loggers do not match")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
option: WithStreams(w, w),
|
option: WithIOStreams(w, w),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tc := range tcs {
|
for _, tc := range tcs {
|
||||||
t.Run(tc.desc, func(t *testing.T) {
|
t.Run(tc.desc, func(t *testing.T) {
|
||||||
got, err := invokeProxyWithOption(tc.option)
|
got := NewToolboxOptions(tc.option)
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if err := tc.isValid(got); err != nil {
|
if err := tc.isValid(got); err != nil {
|
||||||
t.Errorf("option did not initialize command correctly: %v", err)
|
t.Errorf("option did not initialize command correctly: %v", err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func invokeProxyWithOption(o Option) (*Command, error) {
|
|
||||||
c := NewCommand(o)
|
|
||||||
// Keep the test output quiet
|
|
||||||
c.SilenceUsage = true
|
|
||||||
c.SilenceErrors = true
|
|
||||||
// Disable execute behavior
|
|
||||||
c.RunE = func(*cobra.Command, []string) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
err := c.Execute()
|
|
||||||
return c, err
|
|
||||||
}
|
|
||||||
46
cmd/internal/persistent_flags.go
Normal file
46
cmd/internal/persistent_flags.go
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
// Copyright 2026 Google LLC
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/prebuiltconfigs"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PersistentFlags sets up flags that are available for all commands and
|
||||||
|
// subcommands
|
||||||
|
// It is also used to set up persistent flags during subcommand unit tests
|
||||||
|
func PersistentFlags(parentCmd *cobra.Command, opts *ToolboxOptions) {
|
||||||
|
persistentFlags := parentCmd.PersistentFlags()
|
||||||
|
|
||||||
|
persistentFlags.StringVar(&opts.ToolsFile, "tools-file", "", "File path specifying the tool configuration. Cannot be used with --tools-files, or --tools-folder.")
|
||||||
|
persistentFlags.StringSliceVar(&opts.ToolsFiles, "tools-files", []string{}, "Multiple file paths specifying tool configurations. Files will be merged. Cannot be used with --tools-file, or --tools-folder.")
|
||||||
|
persistentFlags.StringVar(&opts.ToolsFolder, "tools-folder", "", "Directory path containing YAML tool configuration files. All .yaml and .yml files in the directory will be loaded and merged. Cannot be used with --tools-file, or --tools-files.")
|
||||||
|
persistentFlags.Var(&opts.Cfg.LogLevel, "log-level", "Specify the minimum level logged. Allowed: 'DEBUG', 'INFO', 'WARN', 'ERROR'.")
|
||||||
|
persistentFlags.Var(&opts.Cfg.LoggingFormat, "logging-format", "Specify logging format to use. Allowed: 'standard' or 'JSON'.")
|
||||||
|
persistentFlags.BoolVar(&opts.Cfg.TelemetryGCP, "telemetry-gcp", false, "Enable exporting directly to Google Cloud Monitoring.")
|
||||||
|
persistentFlags.StringVar(&opts.Cfg.TelemetryOTLP, "telemetry-otlp", "", "Enable exporting using OpenTelemetry Protocol (OTLP) to the specified endpoint (e.g. 'http://127.0.0.1:4318')")
|
||||||
|
persistentFlags.StringVar(&opts.Cfg.TelemetryServiceName, "telemetry-service-name", "toolbox", "Sets the value of the service.name resource attribute for telemetry data.")
|
||||||
|
// Fetch prebuilt tools sources to customize the help description
|
||||||
|
prebuiltHelp := fmt.Sprintf(
|
||||||
|
"Use a prebuilt tool configuration by source type. Allowed: '%s'. Can be specified multiple times.",
|
||||||
|
strings.Join(prebuiltconfigs.GetPrebuiltSources(), "', '"),
|
||||||
|
)
|
||||||
|
persistentFlags.StringSliceVar(&opts.PrebuiltConfigs, "prebuilt", []string{}, prebuiltHelp)
|
||||||
|
persistentFlags.StringSliceVar(&opts.Cfg.UserAgentMetadata, "user-agent-metadata", []string{}, "Appends additional metadata to the User-Agent.")
|
||||||
|
}
|
||||||
@@ -22,7 +22,7 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"sort"
|
"sort"
|
||||||
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/log"
|
"github.com/googleapis/genai-toolbox/cmd/internal"
|
||||||
"github.com/googleapis/genai-toolbox/internal/server"
|
"github.com/googleapis/genai-toolbox/internal/server"
|
||||||
"github.com/googleapis/genai-toolbox/internal/server/resources"
|
"github.com/googleapis/genai-toolbox/internal/server/resources"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
@@ -30,28 +30,9 @@ import (
|
|||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
// RootCommand defines the interface for required by skills-generate subcommand.
|
// skillsCmd is the command for generating skills.
|
||||||
// This allows subcommands to access shared resources and functionality without
|
type skillsCmd struct {
|
||||||
// direct coupling to the root command's implementation.
|
|
||||||
type RootCommand interface {
|
|
||||||
// Config returns a copy of the current server configuration.
|
|
||||||
Config() server.ServerConfig
|
|
||||||
|
|
||||||
// LoadConfig loads and merges the configuration from files, folders, and prebuilts.
|
|
||||||
LoadConfig(ctx context.Context) error
|
|
||||||
|
|
||||||
// Setup initializes the runtime environment, including logging and telemetry.
|
|
||||||
// It returns the updated context and a shutdown function to be called when finished.
|
|
||||||
Setup(ctx context.Context) (context.Context, func(context.Context) error, error)
|
|
||||||
|
|
||||||
// Logger returns the logger instance.
|
|
||||||
Logger() log.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
// Command is the command for generating skills.
|
|
||||||
type Command struct {
|
|
||||||
*cobra.Command
|
*cobra.Command
|
||||||
rootCmd RootCommand
|
|
||||||
name string
|
name string
|
||||||
description string
|
description string
|
||||||
toolset string
|
toolset string
|
||||||
@@ -59,15 +40,13 @@ type Command struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewCommand creates a new Command.
|
// NewCommand creates a new Command.
|
||||||
func NewCommand(rootCmd RootCommand) *cobra.Command {
|
func NewCommand(opts *internal.ToolboxOptions) *cobra.Command {
|
||||||
cmd := &Command{
|
cmd := &skillsCmd{}
|
||||||
rootCmd: rootCmd,
|
|
||||||
}
|
|
||||||
cmd.Command = &cobra.Command{
|
cmd.Command = &cobra.Command{
|
||||||
Use: "skills-generate",
|
Use: "skills-generate",
|
||||||
Short: "Generate skills from tool configurations",
|
Short: "Generate skills from tool configurations",
|
||||||
RunE: func(c *cobra.Command, args []string) error {
|
RunE: func(c *cobra.Command, args []string) error {
|
||||||
return cmd.run(c)
|
return run(cmd, opts)
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -81,11 +60,11 @@ func NewCommand(rootCmd RootCommand) *cobra.Command {
|
|||||||
return cmd.Command
|
return cmd.Command
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Command) run(cmd *cobra.Command) error {
|
func run(cmd *skillsCmd, opts *internal.ToolboxOptions) error {
|
||||||
ctx, cancel := context.WithCancel(cmd.Context())
|
ctx, cancel := context.WithCancel(cmd.Context())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
ctx, shutdown, err := c.rootCmd.Setup(ctx)
|
ctx, shutdown, err := opts.Setup(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -93,39 +72,37 @@ func (c *Command) run(cmd *cobra.Command) error {
|
|||||||
_ = shutdown(ctx)
|
_ = shutdown(ctx)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
logger := c.rootCmd.Logger()
|
_, err = opts.LoadConfig(ctx)
|
||||||
|
if err != nil {
|
||||||
// Load and merge tool configurations
|
|
||||||
if err := c.rootCmd.LoadConfig(ctx); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := os.MkdirAll(c.outputDir, 0755); err != nil {
|
if err := os.MkdirAll(cmd.outputDir, 0755); err != nil {
|
||||||
errMsg := fmt.Errorf("error creating output directory: %w", err)
|
errMsg := fmt.Errorf("error creating output directory: %w", err)
|
||||||
logger.ErrorContext(ctx, errMsg.Error())
|
opts.Logger.ErrorContext(ctx, errMsg.Error())
|
||||||
return errMsg
|
return errMsg
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.InfoContext(ctx, fmt.Sprintf("Generating skill '%s'...", c.name))
|
opts.Logger.InfoContext(ctx, fmt.Sprintf("Generating skill '%s'...", cmd.name))
|
||||||
|
|
||||||
// Initialize toolbox and collect tools
|
// Initialize toolbox and collect tools
|
||||||
allTools, err := c.collectTools(ctx)
|
allTools, err := cmd.collectTools(ctx, opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errMsg := fmt.Errorf("error collecting tools: %w", err)
|
errMsg := fmt.Errorf("error collecting tools: %w", err)
|
||||||
logger.ErrorContext(ctx, errMsg.Error())
|
opts.Logger.ErrorContext(ctx, errMsg.Error())
|
||||||
return errMsg
|
return errMsg
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(allTools) == 0 {
|
if len(allTools) == 0 {
|
||||||
logger.InfoContext(ctx, "No tools found to generate.")
|
opts.Logger.InfoContext(ctx, "No tools found to generate.")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate the combined skill directory
|
// Generate the combined skill directory
|
||||||
skillPath := filepath.Join(c.outputDir, c.name)
|
skillPath := filepath.Join(cmd.outputDir, cmd.name)
|
||||||
if err := os.MkdirAll(skillPath, 0755); err != nil {
|
if err := os.MkdirAll(skillPath, 0755); err != nil {
|
||||||
errMsg := fmt.Errorf("error creating skill directory: %w", err)
|
errMsg := fmt.Errorf("error creating skill directory: %w", err)
|
||||||
logger.ErrorContext(ctx, errMsg.Error())
|
opts.Logger.ErrorContext(ctx, errMsg.Error())
|
||||||
return errMsg
|
return errMsg
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -133,7 +110,7 @@ func (c *Command) run(cmd *cobra.Command) error {
|
|||||||
assetsPath := filepath.Join(skillPath, "assets")
|
assetsPath := filepath.Join(skillPath, "assets")
|
||||||
if err := os.MkdirAll(assetsPath, 0755); err != nil {
|
if err := os.MkdirAll(assetsPath, 0755); err != nil {
|
||||||
errMsg := fmt.Errorf("error creating assets dir: %w", err)
|
errMsg := fmt.Errorf("error creating assets dir: %w", err)
|
||||||
logger.ErrorContext(ctx, errMsg.Error())
|
opts.Logger.ErrorContext(ctx, errMsg.Error())
|
||||||
return errMsg
|
return errMsg
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -141,7 +118,7 @@ func (c *Command) run(cmd *cobra.Command) error {
|
|||||||
scriptsPath := filepath.Join(skillPath, "scripts")
|
scriptsPath := filepath.Join(skillPath, "scripts")
|
||||||
if err := os.MkdirAll(scriptsPath, 0755); err != nil {
|
if err := os.MkdirAll(scriptsPath, 0755); err != nil {
|
||||||
errMsg := fmt.Errorf("error creating scripts dir: %w", err)
|
errMsg := fmt.Errorf("error creating scripts dir: %w", err)
|
||||||
logger.ErrorContext(ctx, errMsg.Error())
|
opts.Logger.ErrorContext(ctx, errMsg.Error())
|
||||||
return errMsg
|
return errMsg
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -154,10 +131,10 @@ func (c *Command) run(cmd *cobra.Command) error {
|
|||||||
|
|
||||||
for _, toolName := range toolNames {
|
for _, toolName := range toolNames {
|
||||||
// Generate YAML config in asset directory
|
// Generate YAML config in asset directory
|
||||||
minimizedContent, err := generateToolConfigYAML(c.rootCmd.Config(), toolName)
|
minimizedContent, err := generateToolConfigYAML(opts.Cfg, toolName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errMsg := fmt.Errorf("error generating filtered config for %s: %w", toolName, err)
|
errMsg := fmt.Errorf("error generating filtered config for %s: %w", toolName, err)
|
||||||
logger.ErrorContext(ctx, errMsg.Error())
|
opts.Logger.ErrorContext(ctx, errMsg.Error())
|
||||||
return errMsg
|
return errMsg
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -166,7 +143,7 @@ func (c *Command) run(cmd *cobra.Command) error {
|
|||||||
destPath := filepath.Join(assetsPath, specificToolsFileName)
|
destPath := filepath.Join(assetsPath, specificToolsFileName)
|
||||||
if err := os.WriteFile(destPath, minimizedContent, 0644); err != nil {
|
if err := os.WriteFile(destPath, minimizedContent, 0644); err != nil {
|
||||||
errMsg := fmt.Errorf("error writing filtered config for %s: %w", toolName, err)
|
errMsg := fmt.Errorf("error writing filtered config for %s: %w", toolName, err)
|
||||||
logger.ErrorContext(ctx, errMsg.Error())
|
opts.Logger.ErrorContext(ctx, errMsg.Error())
|
||||||
return errMsg
|
return errMsg
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -175,40 +152,40 @@ func (c *Command) run(cmd *cobra.Command) error {
|
|||||||
scriptContent, err := generateScriptContent(toolName, specificToolsFileName)
|
scriptContent, err := generateScriptContent(toolName, specificToolsFileName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errMsg := fmt.Errorf("error generating script content for %s: %w", toolName, err)
|
errMsg := fmt.Errorf("error generating script content for %s: %w", toolName, err)
|
||||||
logger.ErrorContext(ctx, errMsg.Error())
|
opts.Logger.ErrorContext(ctx, errMsg.Error())
|
||||||
return errMsg
|
return errMsg
|
||||||
}
|
}
|
||||||
|
|
||||||
scriptFilename := filepath.Join(scriptsPath, fmt.Sprintf("%s.js", toolName))
|
scriptFilename := filepath.Join(scriptsPath, fmt.Sprintf("%s.js", toolName))
|
||||||
if err := os.WriteFile(scriptFilename, []byte(scriptContent), 0755); err != nil {
|
if err := os.WriteFile(scriptFilename, []byte(scriptContent), 0755); err != nil {
|
||||||
errMsg := fmt.Errorf("error writing script %s: %w", scriptFilename, err)
|
errMsg := fmt.Errorf("error writing script %s: %w", scriptFilename, err)
|
||||||
logger.ErrorContext(ctx, errMsg.Error())
|
opts.Logger.ErrorContext(ctx, errMsg.Error())
|
||||||
return errMsg
|
return errMsg
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate SKILL.md
|
// Generate SKILL.md
|
||||||
skillContent, err := generateSkillMarkdown(c.name, c.description, allTools)
|
skillContent, err := generateSkillMarkdown(cmd.name, cmd.description, allTools)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errMsg := fmt.Errorf("error generating SKILL.md content: %w", err)
|
errMsg := fmt.Errorf("error generating SKILL.md content: %w", err)
|
||||||
logger.ErrorContext(ctx, errMsg.Error())
|
opts.Logger.ErrorContext(ctx, errMsg.Error())
|
||||||
return errMsg
|
return errMsg
|
||||||
}
|
}
|
||||||
skillMdPath := filepath.Join(skillPath, "SKILL.md")
|
skillMdPath := filepath.Join(skillPath, "SKILL.md")
|
||||||
if err := os.WriteFile(skillMdPath, []byte(skillContent), 0644); err != nil {
|
if err := os.WriteFile(skillMdPath, []byte(skillContent), 0644); err != nil {
|
||||||
errMsg := fmt.Errorf("error writing SKILL.md: %w", err)
|
errMsg := fmt.Errorf("error writing SKILL.md: %w", err)
|
||||||
logger.ErrorContext(ctx, errMsg.Error())
|
opts.Logger.ErrorContext(ctx, errMsg.Error())
|
||||||
return errMsg
|
return errMsg
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.InfoContext(ctx, fmt.Sprintf("Successfully generated skill '%s' with %d tools.", c.name, len(allTools)))
|
opts.Logger.InfoContext(ctx, fmt.Sprintf("Successfully generated skill '%s' with %d tools.", cmd.name, len(allTools)))
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Command) collectTools(ctx context.Context) (map[string]tools.Tool, error) {
|
func (c *skillsCmd) collectTools(ctx context.Context, opts *internal.ToolboxOptions) (map[string]tools.Tool, error) {
|
||||||
// Initialize Resources
|
// Initialize Resources
|
||||||
sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := server.InitializeConfigs(ctx, c.rootCmd.Config())
|
sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := server.InitializeConfigs(ctx, opts.Cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to initialize resources: %w", err)
|
return nil, fmt.Errorf("failed to initialize resources: %w", err)
|
||||||
}
|
}
|
||||||
@@ -12,17 +12,36 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
package cmd
|
package skills
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"bytes"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
"github.com/googleapis/genai-toolbox/cmd/internal"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/sources/sqlite"
|
||||||
|
_ "github.com/googleapis/genai-toolbox/internal/tools/sqlite/sqlitesql"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func invokeCommand(args []string) (string, error) {
|
||||||
|
parentCmd := &cobra.Command{Use: "toolbox"}
|
||||||
|
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
opts := internal.NewToolboxOptions(internal.WithIOStreams(buf, buf))
|
||||||
|
internal.PersistentFlags(parentCmd, opts)
|
||||||
|
|
||||||
|
cmd := NewCommand(opts)
|
||||||
|
parentCmd.AddCommand(cmd)
|
||||||
|
parentCmd.SetArgs(args)
|
||||||
|
|
||||||
|
err := parentCmd.Execute()
|
||||||
|
return buf.String(), err
|
||||||
|
}
|
||||||
|
|
||||||
func TestGenerateSkill(t *testing.T) {
|
func TestGenerateSkill(t *testing.T) {
|
||||||
// Create a temporary directory for tests
|
// Create a temporary directory for tests
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
@@ -55,10 +74,7 @@ tools:
|
|||||||
"--description", "hello tool",
|
"--description", "hello tool",
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
got, err := invokeCommand(args)
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
_, got, err := invokeCommandWithContext(ctx, args)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("command failed: %v\nOutput: %s", err, got)
|
t.Fatalf("command failed: %v\nOutput: %s", err, got)
|
||||||
}
|
}
|
||||||
@@ -136,7 +152,7 @@ func TestGenerateSkill_NoConfig(t *testing.T) {
|
|||||||
"--description", "test",
|
"--description", "test",
|
||||||
}
|
}
|
||||||
|
|
||||||
_, _, err := invokeCommandWithContext(context.Background(), args)
|
_, err := invokeCommand(args)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected command to fail when no configuration is provided and tools.yaml is missing")
|
t.Fatal("expected command to fail when no configuration is provided and tools.yaml is missing")
|
||||||
}
|
}
|
||||||
@@ -170,7 +186,7 @@ func TestGenerateSkill_MissingArguments(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
_, got, err := invokeCommandWithContext(context.Background(), tt.args)
|
got, err := invokeCommand(tt.args)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatalf("expected command to fail due to missing arguments, but it succeeded\nOutput: %s", got)
|
t.Fatalf("expected command to fail due to missing arguments, but it succeeded\nOutput: %s", got)
|
||||||
}
|
}
|
||||||
349
cmd/internal/tools_file.go
Normal file
349
cmd/internal/tools_file.go
Normal file
@@ -0,0 +1,349 @@
|
|||||||
|
// Copyright 2026 Google LLC
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"regexp"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/goccy/go-yaml"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ToolsFile struct {
|
||||||
|
Sources server.SourceConfigs `yaml:"sources"`
|
||||||
|
AuthServices server.AuthServiceConfigs `yaml:"authServices"`
|
||||||
|
EmbeddingModels server.EmbeddingModelConfigs `yaml:"embeddingModels"`
|
||||||
|
Tools server.ToolConfigs `yaml:"tools"`
|
||||||
|
Toolsets server.ToolsetConfigs `yaml:"toolsets"`
|
||||||
|
Prompts server.PromptConfigs `yaml:"prompts"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseEnv replaces environment variables ${ENV_NAME} with their values.
|
||||||
|
// also support ${ENV_NAME:default_value}.
|
||||||
|
func parseEnv(input string) (string, error) {
|
||||||
|
re := regexp.MustCompile(`\$\{(\w+)(:([^}]*))?\}`)
|
||||||
|
|
||||||
|
var err error
|
||||||
|
output := re.ReplaceAllStringFunc(input, func(match string) string {
|
||||||
|
parts := re.FindStringSubmatch(match)
|
||||||
|
|
||||||
|
// extract the variable name
|
||||||
|
variableName := parts[1]
|
||||||
|
if value, found := os.LookupEnv(variableName); found {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
if len(parts) >= 4 && parts[2] != "" {
|
||||||
|
return parts[3]
|
||||||
|
}
|
||||||
|
err = fmt.Errorf("environment variable not found: %q", variableName)
|
||||||
|
return ""
|
||||||
|
})
|
||||||
|
return output, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseToolsFile parses the provided yaml into appropriate configs.
|
||||||
|
func parseToolsFile(ctx context.Context, raw []byte) (ToolsFile, error) {
|
||||||
|
var toolsFile ToolsFile
|
||||||
|
// Replace environment variables if found
|
||||||
|
output, err := parseEnv(string(raw))
|
||||||
|
if err != nil {
|
||||||
|
return toolsFile, fmt.Errorf("error parsing environment variables: %s", err)
|
||||||
|
}
|
||||||
|
raw = []byte(output)
|
||||||
|
|
||||||
|
raw, err = convertToolsFile(raw)
|
||||||
|
if err != nil {
|
||||||
|
return toolsFile, fmt.Errorf("error converting tools file: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse contents
|
||||||
|
toolsFile.Sources, toolsFile.AuthServices, toolsFile.EmbeddingModels, toolsFile.Tools, toolsFile.Toolsets, toolsFile.Prompts, err = server.UnmarshalResourceConfig(ctx, raw)
|
||||||
|
if err != nil {
|
||||||
|
return toolsFile, err
|
||||||
|
}
|
||||||
|
return toolsFile, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertToolsFile(raw []byte) ([]byte, error) {
|
||||||
|
var input yaml.MapSlice
|
||||||
|
decoder := yaml.NewDecoder(bytes.NewReader(raw), yaml.UseOrderedMap())
|
||||||
|
|
||||||
|
// convert to tools file v2
|
||||||
|
var buf bytes.Buffer
|
||||||
|
encoder := yaml.NewEncoder(&buf)
|
||||||
|
|
||||||
|
v1keys := []string{"sources", "authSources", "authServices", "embeddingModels", "tools", "toolsets", "prompts"}
|
||||||
|
for {
|
||||||
|
if err := decoder.Decode(&input); err != nil {
|
||||||
|
if err == io.EOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for _, item := range input {
|
||||||
|
key, ok := item.Key.(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("unexpected non-string key in input: %v", item.Key)
|
||||||
|
}
|
||||||
|
// check if the key is config file v1's key
|
||||||
|
if slices.Contains(v1keys, key) {
|
||||||
|
// check if value conversion to yaml.MapSlice successfully
|
||||||
|
// fields such as "tools" in toolsets might pass the first check but
|
||||||
|
// fail to convert to MapSlice
|
||||||
|
if slice, ok := item.Value.(yaml.MapSlice); ok {
|
||||||
|
// Deprecated: convert authSources to authServices
|
||||||
|
if key == "authSources" {
|
||||||
|
key = "authServices"
|
||||||
|
}
|
||||||
|
transformed, err := transformDocs(key, slice)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// encode per-doc
|
||||||
|
for _, doc := range transformed {
|
||||||
|
if err := encoder.Encode(doc); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// invalid input will be ignored
|
||||||
|
// we don't want to throw error here since the config could
|
||||||
|
// be valid but with a different order such as:
|
||||||
|
// ---
|
||||||
|
// tools:
|
||||||
|
// - tool_a
|
||||||
|
// kind: toolsets
|
||||||
|
// ---
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// this doc is already v2, encode to buf
|
||||||
|
if err := encoder.Encode(input); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return buf.Bytes(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// transformDocs transforms the configuration file from v1 format to v2
|
||||||
|
// yaml.MapSlice will preserve the order in a map
|
||||||
|
func transformDocs(kind string, input yaml.MapSlice) ([]yaml.MapSlice, error) {
|
||||||
|
var transformed []yaml.MapSlice
|
||||||
|
for _, entry := range input {
|
||||||
|
entryName, ok := entry.Key.(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("unexpected non-string key for entry in '%s': %v", kind, entry.Key)
|
||||||
|
}
|
||||||
|
entryBody := ProcessValue(entry.Value, kind == "toolsets")
|
||||||
|
|
||||||
|
currentTransformed := yaml.MapSlice{
|
||||||
|
{Key: "kind", Value: kind},
|
||||||
|
{Key: "name", Value: entryName},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge the transformed body into our result
|
||||||
|
if bodySlice, ok := entryBody.(yaml.MapSlice); ok {
|
||||||
|
currentTransformed = append(currentTransformed, bodySlice...)
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("unable to convert entryBody to MapSlice")
|
||||||
|
}
|
||||||
|
transformed = append(transformed, currentTransformed)
|
||||||
|
}
|
||||||
|
return transformed, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProcessValue recursively looks for MapSlices to rename 'kind' -> 'type'
|
||||||
|
func ProcessValue(v any, isToolset bool) any {
|
||||||
|
switch val := v.(type) {
|
||||||
|
case yaml.MapSlice:
|
||||||
|
// creating a new MapSlice is safer for recursive transformation
|
||||||
|
newVal := make(yaml.MapSlice, len(val))
|
||||||
|
for i, item := range val {
|
||||||
|
// Perform renaming
|
||||||
|
if item.Key == "kind" {
|
||||||
|
item.Key = "type"
|
||||||
|
}
|
||||||
|
// Recursive call for nested values (e.g., nested objects or lists)
|
||||||
|
item.Value = ProcessValue(item.Value, false)
|
||||||
|
newVal[i] = item
|
||||||
|
}
|
||||||
|
return newVal
|
||||||
|
case []any:
|
||||||
|
// Process lists: If it's a toolset top-level list, wrap it.
|
||||||
|
if isToolset {
|
||||||
|
return yaml.MapSlice{{Key: "tools", Value: val}}
|
||||||
|
}
|
||||||
|
// Otherwise, recurse into list items (to catch nested objects)
|
||||||
|
newVal := make([]any, len(val))
|
||||||
|
for i := range val {
|
||||||
|
newVal[i] = ProcessValue(val[i], false)
|
||||||
|
}
|
||||||
|
return newVal
|
||||||
|
default:
|
||||||
|
return val
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// mergeToolsFiles merges multiple ToolsFile structs into one.
|
||||||
|
// Detects and raises errors for resource conflicts in sources, authServices, tools, and toolsets.
|
||||||
|
// All resource names (sources, authServices, tools, toolsets) must be unique across all files.
|
||||||
|
func mergeToolsFiles(files ...ToolsFile) (ToolsFile, error) {
|
||||||
|
merged := ToolsFile{
|
||||||
|
Sources: make(server.SourceConfigs),
|
||||||
|
AuthServices: make(server.AuthServiceConfigs),
|
||||||
|
EmbeddingModels: make(server.EmbeddingModelConfigs),
|
||||||
|
Tools: make(server.ToolConfigs),
|
||||||
|
Toolsets: make(server.ToolsetConfigs),
|
||||||
|
Prompts: make(server.PromptConfigs),
|
||||||
|
}
|
||||||
|
|
||||||
|
var conflicts []string
|
||||||
|
|
||||||
|
for fileIndex, file := range files {
|
||||||
|
// Check for conflicts and merge sources
|
||||||
|
for name, source := range file.Sources {
|
||||||
|
if _, exists := merged.Sources[name]; exists {
|
||||||
|
conflicts = append(conflicts, fmt.Sprintf("source '%s' (file #%d)", name, fileIndex+1))
|
||||||
|
} else {
|
||||||
|
merged.Sources[name] = source
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for conflicts and merge authServices
|
||||||
|
for name, authService := range file.AuthServices {
|
||||||
|
if _, exists := merged.AuthServices[name]; exists {
|
||||||
|
conflicts = append(conflicts, fmt.Sprintf("authService '%s' (file #%d)", name, fileIndex+1))
|
||||||
|
} else {
|
||||||
|
merged.AuthServices[name] = authService
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for conflicts and merge embeddingModels
|
||||||
|
for name, em := range file.EmbeddingModels {
|
||||||
|
if _, exists := merged.EmbeddingModels[name]; exists {
|
||||||
|
conflicts = append(conflicts, fmt.Sprintf("embedding model '%s' (file #%d)", name, fileIndex+1))
|
||||||
|
} else {
|
||||||
|
merged.EmbeddingModels[name] = em
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for conflicts and merge tools
|
||||||
|
for name, tool := range file.Tools {
|
||||||
|
if _, exists := merged.Tools[name]; exists {
|
||||||
|
conflicts = append(conflicts, fmt.Sprintf("tool '%s' (file #%d)", name, fileIndex+1))
|
||||||
|
} else {
|
||||||
|
merged.Tools[name] = tool
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for conflicts and merge toolsets
|
||||||
|
for name, toolset := range file.Toolsets {
|
||||||
|
if _, exists := merged.Toolsets[name]; exists {
|
||||||
|
conflicts = append(conflicts, fmt.Sprintf("toolset '%s' (file #%d)", name, fileIndex+1))
|
||||||
|
} else {
|
||||||
|
merged.Toolsets[name] = toolset
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for conflicts and merge prompts
|
||||||
|
for name, prompt := range file.Prompts {
|
||||||
|
if _, exists := merged.Prompts[name]; exists {
|
||||||
|
conflicts = append(conflicts, fmt.Sprintf("prompt '%s' (file #%d)", name, fileIndex+1))
|
||||||
|
} else {
|
||||||
|
merged.Prompts[name] = prompt
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If conflicts were detected, return an error
|
||||||
|
if len(conflicts) > 0 {
|
||||||
|
return ToolsFile{}, fmt.Errorf("resource conflicts detected:\n - %s\n\nPlease ensure each source, authService, tool, toolset and prompt has a unique name across all files", strings.Join(conflicts, "\n - "))
|
||||||
|
}
|
||||||
|
|
||||||
|
return merged, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadAndMergeToolsFiles loads multiple YAML files and merges them
|
||||||
|
func LoadAndMergeToolsFiles(ctx context.Context, filePaths []string) (ToolsFile, error) {
|
||||||
|
var toolsFiles []ToolsFile
|
||||||
|
|
||||||
|
for _, filePath := range filePaths {
|
||||||
|
buf, err := os.ReadFile(filePath)
|
||||||
|
if err != nil {
|
||||||
|
return ToolsFile{}, fmt.Errorf("unable to read tool file at %q: %w", filePath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
toolsFile, err := parseToolsFile(ctx, buf)
|
||||||
|
if err != nil {
|
||||||
|
return ToolsFile{}, fmt.Errorf("unable to parse tool file at %q: %w", filePath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
toolsFiles = append(toolsFiles, toolsFile)
|
||||||
|
}
|
||||||
|
|
||||||
|
mergedFile, err := mergeToolsFiles(toolsFiles...)
|
||||||
|
if err != nil {
|
||||||
|
return ToolsFile{}, fmt.Errorf("unable to merge tools files: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return mergedFile, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadAndMergeToolsFolder loads all YAML files from a directory and merges them
|
||||||
|
func LoadAndMergeToolsFolder(ctx context.Context, folderPath string) (ToolsFile, error) {
|
||||||
|
// Check if directory exists
|
||||||
|
info, err := os.Stat(folderPath)
|
||||||
|
if err != nil {
|
||||||
|
return ToolsFile{}, fmt.Errorf("unable to access tools folder at %q: %w", folderPath, err)
|
||||||
|
}
|
||||||
|
if !info.IsDir() {
|
||||||
|
return ToolsFile{}, fmt.Errorf("path %q is not a directory", folderPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find all YAML files in the directory
|
||||||
|
pattern := filepath.Join(folderPath, "*.yaml")
|
||||||
|
yamlFiles, err := filepath.Glob(pattern)
|
||||||
|
if err != nil {
|
||||||
|
return ToolsFile{}, fmt.Errorf("error finding YAML files in %q: %w", folderPath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Also find .yml files
|
||||||
|
ymlPattern := filepath.Join(folderPath, "*.yml")
|
||||||
|
ymlFiles, err := filepath.Glob(ymlPattern)
|
||||||
|
if err != nil {
|
||||||
|
return ToolsFile{}, fmt.Errorf("error finding YML files in %q: %w", folderPath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Combine both file lists
|
||||||
|
allFiles := append(yamlFiles, ymlFiles...)
|
||||||
|
|
||||||
|
if len(allFiles) == 0 {
|
||||||
|
return ToolsFile{}, fmt.Errorf("no YAML files found in directory %q", folderPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use existing LoadAndMergeToolsFiles function
|
||||||
|
return LoadAndMergeToolsFiles(ctx, allFiles)
|
||||||
|
}
|
||||||
2141
cmd/internal/tools_file_test.go
Normal file
2141
cmd/internal/tools_file_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,30 +0,0 @@
|
|||||||
// Copyright 2024 Google LLC
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package cmd
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Option is a function that configures a Command.
|
|
||||||
type Option func(*Command)
|
|
||||||
|
|
||||||
// WithStreams overrides the default writer.
|
|
||||||
func WithStreams(out, err io.Writer) Option {
|
|
||||||
return func(c *Command) {
|
|
||||||
c.outStream = out
|
|
||||||
c.errStream = err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
886
cmd/root.go
886
cmd/root.go
File diff suppressed because it is too large
Load Diff
1743
cmd/root_test.go
1743
cmd/root_test.go
File diff suppressed because it is too large
Load Diff
@@ -1 +1 @@
|
|||||||
0.26.0
|
0.27.0
|
||||||
|
|||||||
@@ -234,7 +234,7 @@
|
|||||||
},
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"version = \"0.26.0\" # x-release-please-version\n",
|
"version = \"0.27.0\" # x-release-please-version\n",
|
||||||
"! curl -O https://storage.googleapis.com/genai-toolbox/v{version}/linux/amd64/toolbox\n",
|
"! curl -O https://storage.googleapis.com/genai-toolbox/v{version}/linux/amd64/toolbox\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Make the binary executable\n",
|
"# Make the binary executable\n",
|
||||||
|
|||||||
@@ -109,7 +109,7 @@ To install Toolbox as a binary on Linux (AMD64):
|
|||||||
|
|
||||||
```sh
|
```sh
|
||||||
# see releases page for other versions
|
# see releases page for other versions
|
||||||
export VERSION=0.26.0
|
export VERSION=0.27.0
|
||||||
curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/linux/amd64/toolbox
|
curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/linux/amd64/toolbox
|
||||||
chmod +x toolbox
|
chmod +x toolbox
|
||||||
```
|
```
|
||||||
@@ -120,7 +120,7 @@ To install Toolbox as a binary on macOS (Apple Silicon):
|
|||||||
|
|
||||||
```sh
|
```sh
|
||||||
# see releases page for other versions
|
# see releases page for other versions
|
||||||
export VERSION=0.26.0
|
export VERSION=0.27.0
|
||||||
curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/darwin/arm64/toolbox
|
curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/darwin/arm64/toolbox
|
||||||
chmod +x toolbox
|
chmod +x toolbox
|
||||||
```
|
```
|
||||||
@@ -131,7 +131,7 @@ To install Toolbox as a binary on macOS (Intel):
|
|||||||
|
|
||||||
```sh
|
```sh
|
||||||
# see releases page for other versions
|
# see releases page for other versions
|
||||||
export VERSION=0.26.0
|
export VERSION=0.27.0
|
||||||
curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/darwin/amd64/toolbox
|
curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/darwin/amd64/toolbox
|
||||||
chmod +x toolbox
|
chmod +x toolbox
|
||||||
```
|
```
|
||||||
@@ -142,7 +142,7 @@ To install Toolbox as a binary on Windows (Command Prompt):
|
|||||||
|
|
||||||
```cmd
|
```cmd
|
||||||
:: see releases page for other versions
|
:: see releases page for other versions
|
||||||
set VERSION=0.26.0
|
set VERSION=0.27.0
|
||||||
curl -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v%VERSION%/windows/amd64/toolbox.exe"
|
curl -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v%VERSION%/windows/amd64/toolbox.exe"
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -152,7 +152,7 @@ To install Toolbox as a binary on Windows (PowerShell):
|
|||||||
|
|
||||||
```powershell
|
```powershell
|
||||||
# see releases page for other versions
|
# see releases page for other versions
|
||||||
$VERSION = "0.26.0"
|
$VERSION = "0.27.0"
|
||||||
curl.exe -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v$VERSION/windows/amd64/toolbox.exe"
|
curl.exe -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v$VERSION/windows/amd64/toolbox.exe"
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -164,7 +164,7 @@ You can also install Toolbox as a container:
|
|||||||
|
|
||||||
```sh
|
```sh
|
||||||
# see releases page for other versions
|
# see releases page for other versions
|
||||||
export VERSION=0.26.0
|
export VERSION=0.27.0
|
||||||
docker pull us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:$VERSION
|
docker pull us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:$VERSION
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -183,7 +183,7 @@ To install from source, ensure you have the latest version of
|
|||||||
[Go installed](https://go.dev/doc/install), and then run the following command:
|
[Go installed](https://go.dev/doc/install), and then run the following command:
|
||||||
|
|
||||||
```sh
|
```sh
|
||||||
go install github.com/googleapis/genai-toolbox@v0.26.0
|
go install github.com/googleapis/genai-toolbox@v0.27.0
|
||||||
```
|
```
|
||||||
|
|
||||||
{{% /tab %}}
|
{{% /tab %}}
|
||||||
|
|||||||
@@ -105,7 +105,7 @@ In this section, we will download Toolbox, configure our tools in a
|
|||||||
<!-- {x-release-please-start-version} -->
|
<!-- {x-release-please-start-version} -->
|
||||||
```bash
|
```bash
|
||||||
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
||||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/$OS/toolbox
|
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/$OS/toolbox
|
||||||
```
|
```
|
||||||
<!-- {x-release-please-end} -->
|
<!-- {x-release-please-end} -->
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ In this section, we will download Toolbox, configure our tools in a
|
|||||||
<!-- {x-release-please-start-version} -->
|
<!-- {x-release-please-start-version} -->
|
||||||
```bash
|
```bash
|
||||||
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
||||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/$OS/toolbox
|
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/$OS/toolbox
|
||||||
```
|
```
|
||||||
<!-- {x-release-please-end} -->
|
<!-- {x-release-please-end} -->
|
||||||
|
|
||||||
|
|||||||
@@ -100,19 +100,19 @@ After you install Looker in the MCP Store, resources and tools from the server a
|
|||||||
|
|
||||||
{{< tabpane persist=header >}}
|
{{< tabpane persist=header >}}
|
||||||
{{< tab header="linux/amd64" lang="bash" >}}
|
{{< tab header="linux/amd64" lang="bash" >}}
|
||||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/linux/amd64/toolbox
|
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/linux/amd64/toolbox
|
||||||
{{< /tab >}}
|
{{< /tab >}}
|
||||||
|
|
||||||
{{< tab header="darwin/arm64" lang="bash" >}}
|
{{< tab header="darwin/arm64" lang="bash" >}}
|
||||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/darwin/arm64/toolbox
|
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/darwin/arm64/toolbox
|
||||||
{{< /tab >}}
|
{{< /tab >}}
|
||||||
|
|
||||||
{{< tab header="darwin/amd64" lang="bash" >}}
|
{{< tab header="darwin/amd64" lang="bash" >}}
|
||||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/darwin/amd64/toolbox
|
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/darwin/amd64/toolbox
|
||||||
{{< /tab >}}
|
{{< /tab >}}
|
||||||
|
|
||||||
{{< tab header="windows/amd64" lang="bash" >}}
|
{{< tab header="windows/amd64" lang="bash" >}}
|
||||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/windows/amd64/toolbox.exe
|
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/windows/amd64/toolbox.exe
|
||||||
{{< /tab >}}
|
{{< /tab >}}
|
||||||
{{< /tabpane >}}
|
{{< /tabpane >}}
|
||||||
<!-- {x-release-please-end} -->
|
<!-- {x-release-please-end} -->
|
||||||
|
|||||||
@@ -45,19 +45,19 @@ instance:
|
|||||||
<!-- {x-release-please-start-version} -->
|
<!-- {x-release-please-start-version} -->
|
||||||
{{< tabpane persist=header >}}
|
{{< tabpane persist=header >}}
|
||||||
{{< tab header="linux/amd64" lang="bash" >}}
|
{{< tab header="linux/amd64" lang="bash" >}}
|
||||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/linux/amd64/toolbox
|
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/linux/amd64/toolbox
|
||||||
{{< /tab >}}
|
{{< /tab >}}
|
||||||
|
|
||||||
{{< tab header="darwin/arm64" lang="bash" >}}
|
{{< tab header="darwin/arm64" lang="bash" >}}
|
||||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/darwin/arm64/toolbox
|
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/darwin/arm64/toolbox
|
||||||
{{< /tab >}}
|
{{< /tab >}}
|
||||||
|
|
||||||
{{< tab header="darwin/amd64" lang="bash" >}}
|
{{< tab header="darwin/amd64" lang="bash" >}}
|
||||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/darwin/amd64/toolbox
|
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/darwin/amd64/toolbox
|
||||||
{{< /tab >}}
|
{{< /tab >}}
|
||||||
|
|
||||||
{{< tab header="windows/amd64" lang="bash" >}}
|
{{< tab header="windows/amd64" lang="bash" >}}
|
||||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/windows/amd64/toolbox.exe
|
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/windows/amd64/toolbox.exe
|
||||||
{{< /tab >}}
|
{{< /tab >}}
|
||||||
{{< /tabpane >}}
|
{{< /tabpane >}}
|
||||||
<!-- {x-release-please-end} -->
|
<!-- {x-release-please-end} -->
|
||||||
|
|||||||
@@ -43,19 +43,19 @@ expose your developer assistant tools to a MySQL instance:
|
|||||||
<!-- {x-release-please-start-version} -->
|
<!-- {x-release-please-start-version} -->
|
||||||
{{< tabpane persist=header >}}
|
{{< tabpane persist=header >}}
|
||||||
{{< tab header="linux/amd64" lang="bash" >}}
|
{{< tab header="linux/amd64" lang="bash" >}}
|
||||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/linux/amd64/toolbox
|
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/linux/amd64/toolbox
|
||||||
{{< /tab >}}
|
{{< /tab >}}
|
||||||
|
|
||||||
{{< tab header="darwin/arm64" lang="bash" >}}
|
{{< tab header="darwin/arm64" lang="bash" >}}
|
||||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/darwin/arm64/toolbox
|
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/darwin/arm64/toolbox
|
||||||
{{< /tab >}}
|
{{< /tab >}}
|
||||||
|
|
||||||
{{< tab header="darwin/amd64" lang="bash" >}}
|
{{< tab header="darwin/amd64" lang="bash" >}}
|
||||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/darwin/amd64/toolbox
|
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/darwin/amd64/toolbox
|
||||||
{{< /tab >}}
|
{{< /tab >}}
|
||||||
|
|
||||||
{{< tab header="windows/amd64" lang="bash" >}}
|
{{< tab header="windows/amd64" lang="bash" >}}
|
||||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/windows/amd64/toolbox.exe
|
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/windows/amd64/toolbox.exe
|
||||||
{{< /tab >}}
|
{{< /tab >}}
|
||||||
{{< /tabpane >}}
|
{{< /tabpane >}}
|
||||||
<!-- {x-release-please-end} -->
|
<!-- {x-release-please-end} -->
|
||||||
|
|||||||
@@ -44,19 +44,19 @@ expose your developer assistant tools to a Neo4j instance:
|
|||||||
<!-- {x-release-please-start-version} -->
|
<!-- {x-release-please-start-version} -->
|
||||||
{{< tabpane persist=header >}}
|
{{< tabpane persist=header >}}
|
||||||
{{< tab header="linux/amd64" lang="bash" >}}
|
{{< tab header="linux/amd64" lang="bash" >}}
|
||||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/linux/amd64/toolbox
|
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/linux/amd64/toolbox
|
||||||
{{< /tab >}}
|
{{< /tab >}}
|
||||||
|
|
||||||
{{< tab header="darwin/arm64" lang="bash" >}}
|
{{< tab header="darwin/arm64" lang="bash" >}}
|
||||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/darwin/arm64/toolbox
|
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/darwin/arm64/toolbox
|
||||||
{{< /tab >}}
|
{{< /tab >}}
|
||||||
|
|
||||||
{{< tab header="darwin/amd64" lang="bash" >}}
|
{{< tab header="darwin/amd64" lang="bash" >}}
|
||||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/darwin/amd64/toolbox
|
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/darwin/amd64/toolbox
|
||||||
{{< /tab >}}
|
{{< /tab >}}
|
||||||
|
|
||||||
{{< tab header="windows/amd64" lang="bash" >}}
|
{{< tab header="windows/amd64" lang="bash" >}}
|
||||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/windows/amd64/toolbox.exe
|
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/windows/amd64/toolbox.exe
|
||||||
{{< /tab >}}
|
{{< /tab >}}
|
||||||
{{< /tabpane >}}
|
{{< /tabpane >}}
|
||||||
<!-- {x-release-please-end} -->
|
<!-- {x-release-please-end} -->
|
||||||
|
|||||||
335
docs/en/how-to/connect-ide/oracle_mcp.md
Normal file
335
docs/en/how-to/connect-ide/oracle_mcp.md
Normal file
@@ -0,0 +1,335 @@
|
|||||||
|
---
|
||||||
|
title: "Oracle using MCP"
|
||||||
|
type: docs
|
||||||
|
weight: 2
|
||||||
|
description: >
|
||||||
|
Connect your IDE to Oracle using Toolbox.
|
||||||
|
---
|
||||||
|
|
||||||
|
[Model Context Protocol (MCP)](https://modelcontextprotocol.io/introduction) is
|
||||||
|
an open protocol for connecting Large Language Models (LLMs) to data sources
|
||||||
|
like Oracle. This guide covers how to use [MCP Toolbox for Databases][toolbox]
|
||||||
|
to expose your developer assistant tools to an Oracle instance:
|
||||||
|
|
||||||
|
* [Cursor][cursor]
|
||||||
|
* [Windsurf][windsurf] (Codium)
|
||||||
|
* [Visual Studio Code][vscode] (Copilot)
|
||||||
|
* [Cline][cline] (VS Code extension)
|
||||||
|
* [Claude desktop][claudedesktop]
|
||||||
|
* [Claude code][claudecode]
|
||||||
|
* [Gemini CLI][geminicli]
|
||||||
|
* [Gemini Code Assist][geminicodeassist]
|
||||||
|
|
||||||
|
[toolbox]: https://github.com/googleapis/genai-toolbox
|
||||||
|
[cursor]: #configure-your-mcp-client
|
||||||
|
[windsurf]: #configure-your-mcp-client
|
||||||
|
[vscode]: #configure-your-mcp-client
|
||||||
|
[cline]: #configure-your-mcp-client
|
||||||
|
[claudedesktop]: #configure-your-mcp-client
|
||||||
|
[claudecode]: #configure-your-mcp-client
|
||||||
|
[geminicli]: #configure-your-mcp-client
|
||||||
|
[geminicodeassist]: #configure-your-mcp-client
|
||||||
|
|
||||||
|
## Set up the database
|
||||||
|
|
||||||
|
1. Create or select an Oracle instance.
|
||||||
|
|
||||||
|
1. Create or reuse a database user and have the username and password ready.
|
||||||
|
|
||||||
|
## Install MCP Toolbox
|
||||||
|
|
||||||
|
1. Download the latest version of Toolbox as a binary. Select the [correct
|
||||||
|
binary](https://github.com/googleapis/genai-toolbox/releases) corresponding
|
||||||
|
to your OS and CPU architecture. You are required to use Toolbox version
|
||||||
|
V0.26.0+:
|
||||||
|
|
||||||
|
<!-- {x-release-please-start-version} -->
|
||||||
|
{{< tabpane persist=header >}}
|
||||||
|
{{< tab header="linux/amd64" lang="bash" >}}
|
||||||
|
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/linux/amd64/toolbox
|
||||||
|
{{< /tab >}}
|
||||||
|
|
||||||
|
{{< tab header="darwin/arm64" lang="bash" >}}
|
||||||
|
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/darwin/arm64/toolbox
|
||||||
|
{{< /tab >}}
|
||||||
|
|
||||||
|
{{< tab header="darwin/amd64" lang="bash" >}}
|
||||||
|
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/darwin/amd64/toolbox
|
||||||
|
{{< /tab >}}
|
||||||
|
|
||||||
|
{{< tab header="windows/amd64" lang="bash" >}}
|
||||||
|
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/windows/amd64/toolbox.exe
|
||||||
|
{{< /tab >}}
|
||||||
|
{{< /tabpane >}}
|
||||||
|
<!-- {x-release-please-end} -->
|
||||||
|
|
||||||
|
1. Make the binary executable:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
chmod +x toolbox
|
||||||
|
```
|
||||||
|
|
||||||
|
1. Verify the installation:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./toolbox --version
|
||||||
|
```
|
||||||
|
|
||||||
|
## Configure your MCP Client
|
||||||
|
|
||||||
|
{{< tabpane text=true >}}
|
||||||
|
{{% tab header="Claude code" lang="en" %}}
|
||||||
|
|
||||||
|
1. Install [Claude
|
||||||
|
Code](https://docs.anthropic.com/en/docs/agents-and-tools/claude-code/overview).
|
||||||
|
1. Create a `.mcp.json` file in your project root if it doesn't exist.
|
||||||
|
1. Add the following configuration, replace the environment variables with your
|
||||||
|
values, and save:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"mcpServers": {
|
||||||
|
"oracle": {
|
||||||
|
"command": "./PATH/TO/toolbox",
|
||||||
|
"args": ["--prebuilt","oracledb","--stdio"],
|
||||||
|
"env": {
|
||||||
|
"ORACLE_HOST": "",
|
||||||
|
"ORACLE_PORT": "1521",
|
||||||
|
"ORACLE_SERVICE": "",
|
||||||
|
"ORACLE_USER": "",
|
||||||
|
"ORACLE_PASSWORD": "",
|
||||||
|
"ORACLE_WALLET_LOCATION": "",
|
||||||
|
"ORACLE_USE_OCI": ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
1. Restart Claude code to apply the new configuration.
|
||||||
|
{{% /tab %}}
|
||||||
|
|
||||||
|
{{% tab header="Claude desktop" lang="en" %}}
|
||||||
|
|
||||||
|
1. Open Claude desktop and navigate to Settings.
|
||||||
|
1. Under the Developer tab, tap Edit Config to open the configuration file.
|
||||||
|
1. Add the following configuration, replace the environment variables with your
|
||||||
|
values, and save:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"mcpServers": {
|
||||||
|
"oracle": {
|
||||||
|
"command": "./PATH/TO/toolbox",
|
||||||
|
"args": ["--prebuilt","oracledb","--stdio"],
|
||||||
|
"env": {
|
||||||
|
"ORACLE_HOST": "",
|
||||||
|
"ORACLE_PORT": "1521",
|
||||||
|
"ORACLE_SERVICE": "",
|
||||||
|
"ORACLE_USER": "",
|
||||||
|
"ORACLE_PASSWORD": "",
|
||||||
|
"ORACLE_WALLET_LOCATION": "",
|
||||||
|
"ORACLE_USE_OCI": ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
1. Restart Claude desktop.
|
||||||
|
1. From the new chat screen, you should see a hammer (MCP) icon appear with the
|
||||||
|
new MCP server available.
|
||||||
|
{{% /tab %}}
|
||||||
|
|
||||||
|
{{% tab header="Cline" lang="en" %}}
|
||||||
|
|
||||||
|
1. Open the Cline extension in VS Code and tap
|
||||||
|
the **MCP Servers** icon.
|
||||||
|
1. Tap Configure MCP Servers to open the configuration file.
|
||||||
|
1. Add the following configuration, replace the environment variables with your
|
||||||
|
values, and save:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"mcpServers": {
|
||||||
|
"oracle": {
|
||||||
|
"command": "./PATH/TO/toolbox",
|
||||||
|
"args": ["--prebuilt","oracledb","--stdio"],
|
||||||
|
"env": {
|
||||||
|
"ORACLE_HOST": "",
|
||||||
|
"ORACLE_PORT": "1521",
|
||||||
|
"ORACLE_SERVICE": "",
|
||||||
|
"ORACLE_USER": "",
|
||||||
|
"ORACLE_PASSWORD": "",
|
||||||
|
"ORACLE_WALLET_LOCATION": "",
|
||||||
|
"ORACLE_USE_OCI": ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
1. You should see a green active status after the server is successfully
|
||||||
|
connected.
|
||||||
|
{{% /tab %}}
|
||||||
|
|
||||||
|
{{% tab header="Cursor" lang="en" %}}
|
||||||
|
|
||||||
|
1. Create a `.cursor` directory in your project root if it doesn't exist.
|
||||||
|
1. Create a `.cursor/mcp.json` file if it doesn't exist and open it.
|
||||||
|
1. Add the following configuration, replace the environment variables with your
|
||||||
|
values, and save:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"mcpServers": {
|
||||||
|
"oracle": {
|
||||||
|
"command": "./PATH/TO/toolbox",
|
||||||
|
"args": ["--prebuilt","oracledb","--stdio"],
|
||||||
|
"env": {
|
||||||
|
"ORACLE_HOST": "",
|
||||||
|
"ORACLE_PORT": "1521",
|
||||||
|
"ORACLE_SERVICE": "",
|
||||||
|
"ORACLE_USER": "",
|
||||||
|
"ORACLE_PASSWORD": "",
|
||||||
|
"ORACLE_WALLET_LOCATION": "",
|
||||||
|
"ORACLE_USE_OCI": ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
1. Cursor and navigate to **Settings > Cursor
|
||||||
|
Settings > MCP**. You should see a green active status after the server is
|
||||||
|
successfully connected.
|
||||||
|
{{% /tab %}}
|
||||||
|
|
||||||
|
{{% tab header="Visual Studio Code (Copilot)" lang="en" %}}
|
||||||
|
|
||||||
|
1. Open VS Code and
|
||||||
|
create a `.vscode` directory in your project root if it doesn't exist.
|
||||||
|
1. Create a `.vscode/mcp.json` file if it doesn't exist and open it.
|
||||||
|
1. Add the following configuration, replace the environment variables with your
|
||||||
|
values, and save:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"servers": {
|
||||||
|
"oracle": {
|
||||||
|
"command": "./PATH/TO/toolbox",
|
||||||
|
"args": ["--prebuilt","oracle","--stdio"],
|
||||||
|
"env": {
|
||||||
|
"ORACLE_HOST": "",
|
||||||
|
"ORACLE_PORT": "1521",
|
||||||
|
"ORACLE_SERVICE": "",
|
||||||
|
"ORACLE_USER": "",
|
||||||
|
"ORACLE_PASSWORD": ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
{{% /tab %}}
|
||||||
|
|
||||||
|
{{% tab header="Windsurf" lang="en" %}}
|
||||||
|
|
||||||
|
1. Open Windsurf and navigate to the
|
||||||
|
Cascade assistant.
|
||||||
|
1. Tap on the hammer (MCP) icon, then Configure to open the configuration file.
|
||||||
|
1. Add the following configuration, replace the environment variables with your
|
||||||
|
values, and save:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"mcpServers": {
|
||||||
|
"oracle": {
|
||||||
|
"command": "./PATH/TO/toolbox",
|
||||||
|
"args": ["--prebuilt","oracledb","--stdio"],
|
||||||
|
"env": {
|
||||||
|
"ORACLE_HOST": "",
|
||||||
|
"ORACLE_PORT": "1521",
|
||||||
|
"ORACLE_SERVICE": "",
|
||||||
|
"ORACLE_USER": "",
|
||||||
|
"ORACLE_PASSWORD": "",
|
||||||
|
"ORACLE_WALLET": "",
|
||||||
|
"ORACLE_WALLET_PASSWORD": ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
{{% /tab %}}
|
||||||
|
|
||||||
|
{{% tab header="Gemini CLI" lang="en" %}}
|
||||||
|
|
||||||
|
1. Install the Gemini CLI.
|
||||||
|
1. In your working directory, create a folder named `.gemini`. Within it, create a `settings.json` file.
|
||||||
|
1. Add the following configuration, replace the environment variables with your values, and then save:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"mcpServers": {
|
||||||
|
"oracle": {
|
||||||
|
"command": "./PATH/TO/toolbox",
|
||||||
|
"args": ["--prebuilt","oracledb","--stdio"],
|
||||||
|
"env": {
|
||||||
|
"ORACLE_HOST": "",
|
||||||
|
"ORACLE_PORT": "1521",
|
||||||
|
"ORACLE_SERVICE": "",
|
||||||
|
"ORACLE_USER": "",
|
||||||
|
"ORACLE_PASSWORD": ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
{{% /tab %}}
|
||||||
|
|
||||||
|
{{% tab header="Gemini Code Assist" lang="en" %}}
|
||||||
|
|
||||||
|
1. Install the Gemini Code Assist extension in Visual Studio Code.
|
||||||
|
1. Enable Agent Mode in Gemini Code Assist chat.
|
||||||
|
1. In your working directory, create a folder named `.gemini`. Within it, create a `settings.json` file.
|
||||||
|
1. Add the following configuration, replace the environment variables with your values, and then save:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"mcpServers": {
|
||||||
|
"oracle": {
|
||||||
|
"command": "./PATH/TO/toolbox",
|
||||||
|
"args": ["--prebuilt","oracledb","--stdio"],
|
||||||
|
"env": {
|
||||||
|
"ORACLE_HOST": "",
|
||||||
|
"ORACLE_PORT": "1521",
|
||||||
|
"ORACLE_SERVICE": "",
|
||||||
|
"ORACLE_USER": "",
|
||||||
|
"ORACLE_PASSWORD": ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
{{% /tab %}}
|
||||||
|
{{< /tabpane >}}
|
||||||
|
|
||||||
|
## Use Tools
|
||||||
|
|
||||||
|
Your AI tool is now connected to Oracle using MCP. Try asking your AI
|
||||||
|
assistant to list tables, create a table, or define and execute other SQL
|
||||||
|
statements.
|
||||||
|
|
||||||
|
The following tools are available to the LLM:
|
||||||
|
|
||||||
|
1. **list_tables**: lists tables and descriptions
|
||||||
|
1. **execute_sql**: execute any SQL statement
|
||||||
|
|
||||||
|
{{< notice note >}}
|
||||||
|
Prebuilt tools are pre-1.0, so expect some tool changes between versions. LLMs
|
||||||
|
will adapt to the tools available, so this shouldn't affect most users.
|
||||||
|
{{< /notice >}}
|
||||||
@@ -56,19 +56,19 @@ Omni](https://cloud.google.com/alloydb/omni/docs/overview).
|
|||||||
<!-- {x-release-please-start-version} -->
|
<!-- {x-release-please-start-version} -->
|
||||||
{{< tabpane persist=header >}}
|
{{< tabpane persist=header >}}
|
||||||
{{< tab header="linux/amd64" lang="bash" >}}
|
{{< tab header="linux/amd64" lang="bash" >}}
|
||||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/linux/amd64/toolbox
|
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/linux/amd64/toolbox
|
||||||
{{< /tab >}}
|
{{< /tab >}}
|
||||||
|
|
||||||
{{< tab header="darwin/arm64" lang="bash" >}}
|
{{< tab header="darwin/arm64" lang="bash" >}}
|
||||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/darwin/arm64/toolbox
|
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/darwin/arm64/toolbox
|
||||||
{{< /tab >}}
|
{{< /tab >}}
|
||||||
|
|
||||||
{{< tab header="darwin/amd64" lang="bash" >}}
|
{{< tab header="darwin/amd64" lang="bash" >}}
|
||||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/darwin/amd64/toolbox
|
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/darwin/amd64/toolbox
|
||||||
{{< /tab >}}
|
{{< /tab >}}
|
||||||
|
|
||||||
{{< tab header="windows/amd64" lang="bash" >}}
|
{{< tab header="windows/amd64" lang="bash" >}}
|
||||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/windows/amd64/toolbox.exe
|
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/windows/amd64/toolbox.exe
|
||||||
{{< /tab >}}
|
{{< /tab >}}
|
||||||
{{< /tabpane >}}
|
{{< /tabpane >}}
|
||||||
<!-- {x-release-please-end} -->
|
<!-- {x-release-please-end} -->
|
||||||
|
|||||||
@@ -43,19 +43,19 @@ to expose your developer assistant tools to a SQLite instance:
|
|||||||
<!-- {x-release-please-start-version} -->
|
<!-- {x-release-please-start-version} -->
|
||||||
{{< tabpane persist=header >}}
|
{{< tabpane persist=header >}}
|
||||||
{{< tab header="linux/amd64" lang="bash" >}}
|
{{< tab header="linux/amd64" lang="bash" >}}
|
||||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/linux/amd64/toolbox
|
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/linux/amd64/toolbox
|
||||||
{{< /tab >}}
|
{{< /tab >}}
|
||||||
|
|
||||||
{{< tab header="darwin/arm64" lang="bash" >}}
|
{{< tab header="darwin/arm64" lang="bash" >}}
|
||||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/darwin/arm64/toolbox
|
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/darwin/arm64/toolbox
|
||||||
{{< /tab >}}
|
{{< /tab >}}
|
||||||
|
|
||||||
{{< tab header="darwin/amd64" lang="bash" >}}
|
{{< tab header="darwin/amd64" lang="bash" >}}
|
||||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/darwin/amd64/toolbox
|
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/darwin/amd64/toolbox
|
||||||
{{< /tab >}}
|
{{< /tab >}}
|
||||||
|
|
||||||
{{< tab header="windows/amd64" lang="bash" >}}
|
{{< tab header="windows/amd64" lang="bash" >}}
|
||||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/windows/amd64/toolbox.exe
|
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/windows/amd64/toolbox.exe
|
||||||
{{< /tab >}}
|
{{< /tab >}}
|
||||||
{{< /tabpane >}}
|
{{< /tabpane >}}
|
||||||
<!-- {x-release-please-end} -->
|
<!-- {x-release-please-end} -->
|
||||||
|
|||||||
@@ -692,6 +692,34 @@ See [Usage Examples](../reference/cli.md#examples).
|
|||||||
* `execute_cypher`: Executes a Cypher query.
|
* `execute_cypher`: Executes a Cypher query.
|
||||||
* `get_schema`: Retrieves the schema of the Neo4j database.
|
* `get_schema`: Retrieves the schema of the Neo4j database.
|
||||||
|
|
||||||
|
|
||||||
|
## Oracle
|
||||||
|
|
||||||
|
* `--prebuilt` value: `oracle`
|
||||||
|
* **Environment Variables:**
|
||||||
|
* `ORACLE_HOST`: The hostname or IP address of the Oracle server.
|
||||||
|
* `ORACLE_PORT`: The port number for the Oracle server (Default: 1521).
|
||||||
|
* `ORACLE_CONNECTION_STRING`: The
|
||||||
|
* `ORACLE_USER`: The username for the Oracle DB instance.
|
||||||
|
* `ORACLE_PASSWORD`: The password for the Oracle DB instance
|
||||||
|
* `ORACLE_WALLET`: The path to Oracle DB Wallet file for the databases that support this authentication type
|
||||||
|
* `ORACLE_USE_OCI`: true or false, The flag if the Oracle Database is deployed in cloud deployment. Default is false.
|
||||||
|
* **Permissions:**
|
||||||
|
* Database-level permissions (e.g., `SELECT`, `INSERT`) are required to execute queries.
|
||||||
|
* **Tools:**
|
||||||
|
* `execute_sql`: Executes a SQL query.
|
||||||
|
* `list_tables`: Lists tables in the database.
|
||||||
|
* `list_active_sessions`: Lists active database sessions.
|
||||||
|
* `get_query_plan`: Gets the execution plan for a SQL statement.
|
||||||
|
* `list_top_sql_by_resource`: Lists top SQL statements by resource usage.
|
||||||
|
* `list_tablespace_usage`: Lists tablespace usage.
|
||||||
|
* `list_invalid_objects`: Lists invalid objects.
|
||||||
|
* `list_active_sessions`: Lists active database sessions.
|
||||||
|
* `get_query_plan`: Gets the execution plan for a SQL statement.
|
||||||
|
* `list_top_sql_by_resource`: Lists top SQL statements by resource usage.
|
||||||
|
* `list_tablespace_usage`: Lists tablespace usage.
|
||||||
|
* `list_invalid_objects`: Lists invalid objects.
|
||||||
|
|
||||||
## Google Cloud Healthcare API
|
## Google Cloud Healthcare API
|
||||||
* `--prebuilt` value: `cloud-healthcare`
|
* `--prebuilt` value: `cloud-healthcare`
|
||||||
* **Environment Variables:**
|
* **Environment Variables:**
|
||||||
|
|||||||
242
docs/en/resources/sources/cockroachdb.md
Normal file
242
docs/en/resources/sources/cockroachdb.md
Normal file
@@ -0,0 +1,242 @@
|
|||||||
|
---
|
||||||
|
title: "CockroachDB"
|
||||||
|
type: docs
|
||||||
|
weight: 1
|
||||||
|
description: >
|
||||||
|
CockroachDB is a distributed SQL database built for cloud applications.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## About
|
||||||
|
|
||||||
|
[CockroachDB][crdb-docs] is a distributed SQL database designed for cloud-native applications. It provides strong consistency, horizontal scalability, and built-in resilience with automatic failover and recovery. CockroachDB uses the PostgreSQL wire protocol, making it compatible with many PostgreSQL tools and drivers while providing unique features like multi-region deployments and distributed transactions.
|
||||||
|
|
||||||
|
**Minimum Version:** CockroachDB v25.1 or later is recommended for full tool compatibility.
|
||||||
|
|
||||||
|
[crdb-docs]: https://www.cockroachlabs.com/docs/
|
||||||
|
|
||||||
|
## Available Tools
|
||||||
|
|
||||||
|
- [`cockroachdb-sql`](../tools/cockroachdb/cockroachdb-sql.md)
|
||||||
|
Execute SQL queries as prepared statements in CockroachDB (alias for execute-sql).
|
||||||
|
|
||||||
|
- [`cockroachdb-execute-sql`](../tools/cockroachdb/cockroachdb-execute-sql.md)
|
||||||
|
Run parameterized SQL statements in CockroachDB.
|
||||||
|
|
||||||
|
- [`cockroachdb-list-schemas`](../tools/cockroachdb/cockroachdb-list-schemas.md)
|
||||||
|
List schemas in a CockroachDB database.
|
||||||
|
|
||||||
|
- [`cockroachdb-list-tables`](../tools/cockroachdb/cockroachdb-list-tables.md)
|
||||||
|
List tables in a CockroachDB database.
|
||||||
|
|
||||||
|
## Requirements
|
||||||
|
|
||||||
|
### Database User
|
||||||
|
|
||||||
|
This source uses standard authentication. You will need to [create a CockroachDB user][crdb-users] to login to the database with. For CockroachDB Cloud deployments, SSL/TLS is required.
|
||||||
|
|
||||||
|
[crdb-users]: https://www.cockroachlabs.com/docs/stable/create-user.html
|
||||||
|
|
||||||
|
### SSL/TLS Configuration
|
||||||
|
|
||||||
|
CockroachDB Cloud clusters require SSL/TLS connections. Use the `queryParams` section to configure SSL settings:
|
||||||
|
|
||||||
|
- **For CockroachDB Cloud**: Use `sslmode: require` at minimum
|
||||||
|
- **For self-hosted with certificates**: Use `sslmode: verify-full` with certificate paths
|
||||||
|
- **For local development only**: Use `sslmode: disable` (not recommended for production)
|
||||||
|
|
||||||
|
## Example
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
sources:
|
||||||
|
my_cockroachdb:
|
||||||
|
type: cockroachdb
|
||||||
|
host: your-cluster.cockroachlabs.cloud
|
||||||
|
port: "26257"
|
||||||
|
user: myuser
|
||||||
|
password: mypassword
|
||||||
|
database: defaultdb
|
||||||
|
maxRetries: 5
|
||||||
|
retryBaseDelay: 500ms
|
||||||
|
queryParams:
|
||||||
|
sslmode: require
|
||||||
|
application_name: my-app
|
||||||
|
|
||||||
|
# MCP Security Settings (recommended for production)
|
||||||
|
readOnlyMode: true # Read-only by default (MCP best practice)
|
||||||
|
enableWriteMode: false # Set to true to allow write operations
|
||||||
|
maxRowLimit: 1000 # Limit query results
|
||||||
|
queryTimeoutSec: 30 # Prevent long-running queries
|
||||||
|
enableTelemetry: true # Enable observability
|
||||||
|
telemetryVerbose: false # Set true for detailed logs
|
||||||
|
clusterID: "my-cluster" # Optional identifier
|
||||||
|
|
||||||
|
tools:
|
||||||
|
list_expenses:
|
||||||
|
type: cockroachdb-sql
|
||||||
|
source: my_cockroachdb
|
||||||
|
description: List all expenses
|
||||||
|
statement: SELECT id, description, amount, category FROM expenses WHERE user_id = $1
|
||||||
|
parameters:
|
||||||
|
- name: user_id
|
||||||
|
type: string
|
||||||
|
description: The user's ID
|
||||||
|
|
||||||
|
describe_expenses:
|
||||||
|
type: cockroachdb-describe-table
|
||||||
|
source: my_cockroachdb
|
||||||
|
description: Describe the expenses table schema
|
||||||
|
|
||||||
|
list_expenses_indexes:
|
||||||
|
type: cockroachdb-list-indexes
|
||||||
|
source: my_cockroachdb
|
||||||
|
description: List indexes on the expenses table
|
||||||
|
```
|
||||||
|
|
||||||
|
## Configuration Parameters
|
||||||
|
|
||||||
|
### Required Parameters
|
||||||
|
|
||||||
|
| Parameter | Type | Description |
|
||||||
|
|-----------|------|-------------|
|
||||||
|
| `type` | string | Must be `cockroachdb` |
|
||||||
|
| `host` | string | The hostname or IP address of the CockroachDB cluster |
|
||||||
|
| `port` | string | The port number (typically "26257") |
|
||||||
|
| `user` | string | The database user name |
|
||||||
|
| `database` | string | The database name to connect to |
|
||||||
|
|
||||||
|
### Optional Parameters
|
||||||
|
|
||||||
|
| Parameter | Type | Default | Description |
|
||||||
|
|-----------|------|---------|-------------|
|
||||||
|
| `password` | string | "" | The database password (can be empty for certificate-based auth) |
|
||||||
|
| `maxRetries` | integer | 5 | Maximum number of connection retry attempts |
|
||||||
|
| `retryBaseDelay` | string | "500ms" | Base delay between retry attempts (exponential backoff) |
|
||||||
|
| `queryParams` | map | {} | Additional connection parameters (e.g., SSL configuration) |
|
||||||
|
|
||||||
|
### MCP Security Parameters
|
||||||
|
|
||||||
|
CockroachDB integration includes security features following the [Model Context Protocol (MCP)](https://modelcontextprotocol.io/) specification:
|
||||||
|
|
||||||
|
| Parameter | Type | Default | Description |
|
||||||
|
|-----------|------|---------|-------------|
|
||||||
|
| `readOnlyMode` | boolean | true | Enables read-only mode by default (MCP requirement) |
|
||||||
|
| `enableWriteMode` | boolean | false | Explicitly enable write operations (INSERT/UPDATE/DELETE/CREATE/DROP) |
|
||||||
|
| `maxRowLimit` | integer | 1000 | Maximum rows returned per SELECT query (auto-adds LIMIT clause) |
|
||||||
|
| `queryTimeoutSec` | integer | 30 | Query timeout in seconds to prevent long-running queries |
|
||||||
|
| `enableTelemetry` | boolean | true | Enable structured logging of tool invocations |
|
||||||
|
| `telemetryVerbose` | boolean | false | Enable detailed JSON telemetry output |
|
||||||
|
| `clusterID` | string | "" | Optional cluster identifier for telemetry |
|
||||||
|
|
||||||
|
### Query Parameters
|
||||||
|
|
||||||
|
Common query parameters for CockroachDB connections:
|
||||||
|
|
||||||
|
| Parameter | Values | Description |
|
||||||
|
|-----------|--------|-------------|
|
||||||
|
| `sslmode` | `disable`, `require`, `verify-ca`, `verify-full` | SSL/TLS mode (CockroachDB Cloud requires `require` or higher) |
|
||||||
|
| `sslrootcert` | file path | Path to root certificate for SSL verification |
|
||||||
|
| `sslcert` | file path | Path to client certificate |
|
||||||
|
| `sslkey` | file path | Path to client key |
|
||||||
|
| `application_name` | string | Application name for connection tracking |
|
||||||
|
|
||||||
|
## Best Practices
|
||||||
|
|
||||||
|
### Security and MCP Compliance
|
||||||
|
|
||||||
|
**Read-Only by Default**: The integration follows MCP best practices by defaulting to read-only mode. This prevents accidental data modifications:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
sources:
|
||||||
|
my_cockroachdb:
|
||||||
|
readOnlyMode: true # Default behavior
|
||||||
|
enableWriteMode: false # Explicit write opt-in required
|
||||||
|
```
|
||||||
|
|
||||||
|
To enable write operations:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
sources:
|
||||||
|
my_cockroachdb:
|
||||||
|
readOnlyMode: false # Disable read-only protection
|
||||||
|
enableWriteMode: true # Explicitly allow writes
|
||||||
|
```
|
||||||
|
|
||||||
|
**Query Limits**: Automatic row limits prevent excessive data retrieval:
|
||||||
|
- SELECT queries automatically get `LIMIT 1000` appended (configurable via `maxRowLimit`)
|
||||||
|
- Queries are terminated after 30 seconds (configurable via `queryTimeoutSec`)
|
||||||
|
|
||||||
|
**Observability**: Structured telemetry provides visibility into tool usage:
|
||||||
|
- Tool invocations are logged with status, latency, and row counts
|
||||||
|
- SQL queries are redacted to protect sensitive values
|
||||||
|
- Set `telemetryVerbose: true` for detailed JSON logs
|
||||||
|
|
||||||
|
### Use UUID Primary Keys
|
||||||
|
|
||||||
|
CockroachDB performs best with UUID primary keys rather than sequential integers to avoid transaction hotspots:
|
||||||
|
|
||||||
|
```sql
|
||||||
|
CREATE TABLE expenses (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
description TEXT,
|
||||||
|
amount DECIMAL(10,2)
|
||||||
|
);
|
||||||
|
```
|
||||||
|
|
||||||
|
### Automatic Transaction Retry
|
||||||
|
|
||||||
|
This source uses the official `cockroach-go/v2` library which provides automatic transaction retry for serialization conflicts. For write operations requiring explicit transaction control, tools can use the `ExecuteTxWithRetry` method.
|
||||||
|
|
||||||
|
### Multi-Region Deployments
|
||||||
|
|
||||||
|
CockroachDB supports multi-region deployments with automatic data distribution. Configure your cluster's regions and survival goals separately from the Toolbox configuration. The source will connect to any node in the cluster.
|
||||||
|
|
||||||
|
### Connection Pooling
|
||||||
|
|
||||||
|
The source maintains a connection pool to the CockroachDB cluster. The pool automatically handles:
|
||||||
|
- Load balancing across cluster nodes
|
||||||
|
- Connection retry with exponential backoff
|
||||||
|
- Health checking of connections
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### SSL/TLS Errors
|
||||||
|
|
||||||
|
If you encounter "server requires encryption" errors:
|
||||||
|
|
||||||
|
1. For CockroachDB Cloud, ensure `sslmode` is set to `require` or higher:
|
||||||
|
```yaml
|
||||||
|
queryParams:
|
||||||
|
sslmode: require
|
||||||
|
```
|
||||||
|
|
||||||
|
2. For certificate verification, download your cluster's root certificate and configure:
|
||||||
|
```yaml
|
||||||
|
queryParams:
|
||||||
|
sslmode: verify-full
|
||||||
|
sslrootcert: /path/to/ca.crt
|
||||||
|
```
|
||||||
|
|
||||||
|
### Connection Timeouts
|
||||||
|
|
||||||
|
If experiencing connection timeouts:
|
||||||
|
|
||||||
|
1. Check network connectivity to the CockroachDB cluster
|
||||||
|
2. Verify firewall rules allow connections on port 26257
|
||||||
|
3. For CockroachDB Cloud, ensure IP allowlisting is configured
|
||||||
|
4. Increase `maxRetries` or `retryBaseDelay` if needed
|
||||||
|
|
||||||
|
### Transaction Retry Errors
|
||||||
|
|
||||||
|
CockroachDB may encounter serializable transaction conflicts. The integration automatically handles these retries using the cockroach-go library. If you see retry-related errors, check:
|
||||||
|
|
||||||
|
1. Database load and contention
|
||||||
|
2. Query patterns that might cause conflicts
|
||||||
|
3. Consider using `SELECT FOR UPDATE` for explicit locking
|
||||||
|
|
||||||
|
## Additional Resources
|
||||||
|
|
||||||
|
- [CockroachDB Documentation](https://www.cockroachlabs.com/docs/)
|
||||||
|
- [CockroachDB Best Practices](https://www.cockroachlabs.com/docs/stable/performance-best-practices-overview.html)
|
||||||
|
- [Multi-Region Capabilities](https://www.cockroachlabs.com/docs/stable/multiregion-overview.html)
|
||||||
|
- [Connection Parameters](https://www.cockroachlabs.com/docs/stable/connection-parameters.html)
|
||||||
@@ -87,28 +87,41 @@ using a TNS (Transparent Network Substrate) alias.
|
|||||||
|
|
||||||
## Examples
|
## Examples
|
||||||
|
|
||||||
This example demonstrates the four connection methods you could choose from:
|
### 1. Basic Connection (Host, Port, and Service Name)
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
kind: sources
|
sources:
|
||||||
name: my-oracle-source
|
my-oracle-source:
|
||||||
type: oracle
|
kind: oracle
|
||||||
|
|
||||||
# --- Choose one connection method ---
|
|
||||||
# 1. Host, Port, and Service Name
|
|
||||||
host: 127.0.0.1
|
host: 127.0.0.1
|
||||||
port: 1521
|
port: 1521
|
||||||
serviceName: XEPDB1
|
serviceName: XEPDB1
|
||||||
|
|
||||||
# 2. Direct Connection String
|
|
||||||
connectionString: "127.0.0.1:1521/XEPDB1"
|
|
||||||
|
|
||||||
# 3. TNS Alias (requires tnsnames.ora)
|
|
||||||
tnsAlias: "MY_DB_ALIAS"
|
|
||||||
tnsAdmin: "/opt/oracle/network/admin" # Optional: overrides TNS_ADMIN env var
|
|
||||||
|
|
||||||
user: ${USER_NAME}
|
user: ${USER_NAME}
|
||||||
password: ${PASSWORD}
|
password: ${PASSWORD}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Direct Connection String
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
sources:
|
||||||
|
my-oracle-source:
|
||||||
|
kind: oracle
|
||||||
|
connectionString: "127.0.0.1:1521/XEPDB1"
|
||||||
|
user: ${USER_NAME}
|
||||||
|
password: ${PASSWORD}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. TNS Alias (requires tnsnames.ora)
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
sources:
|
||||||
|
my-oracle-source:
|
||||||
|
kind: oracle
|
||||||
|
tnsAlias: "MY_DB_ALIAS"
|
||||||
|
tnsAdmin: "/opt/oracle/network/admin" # Optional: overrides TNS_ADMIN env var
|
||||||
|
user: ${USER_NAME}
|
||||||
|
password: ${PASSWORD}
|
||||||
|
useOCI: true # tnsAlias requires useOCI to be true
|
||||||
|
|
||||||
# Optional: Set to true to use the OCI-based driver for advanced features (Requires Oracle Instant Client)
|
# Optional: Set to true to use the OCI-based driver for advanced features (Requires Oracle Instant Client)
|
||||||
```
|
```
|
||||||
@@ -168,3 +181,4 @@ instead of hardcoding your secrets into the configuration file.
|
|||||||
| tnsAlias | string | false | A TNS alias from a `tnsnames.ora` file. Use as an alternative to `host`/`port` or `connectionString`. |
|
| tnsAlias | string | false | A TNS alias from a `tnsnames.ora` file. Use as an alternative to `host`/`port` or `connectionString`. |
|
||||||
| tnsAdmin | string | false | Path to the directory containing the `tnsnames.ora` file. This overrides the `TNS_ADMIN` environment variable if it is set. |
|
| tnsAdmin | string | false | Path to the directory containing the `tnsnames.ora` file. This overrides the `TNS_ADMIN` environment variable if it is set. |
|
||||||
| useOCI | bool | false | If true, uses the OCI-based driver (godror) which supports Oracle Wallet/Kerberos but requires the Oracle Instant Client libraries to be installed. Defaults to false (pure Go driver). |
|
| useOCI | bool | false | If true, uses the OCI-based driver (godror) which supports Oracle Wallet/Kerberos but requires the Oracle Instant Client libraries to be installed. Defaults to false (pure Go driver). |
|
||||||
|
| walletLocation | string | false | Path to the directory containing the wallet files for the pure Go driver (`useOCI: false`). |
|
||||||
|
|||||||
273
docs/en/resources/tools/cockroachdb/cockroachdb-execute-sql.md
Normal file
273
docs/en/resources/tools/cockroachdb/cockroachdb-execute-sql.md
Normal file
@@ -0,0 +1,273 @@
|
|||||||
|
---
|
||||||
|
title: "cockroachdb-execute-sql"
|
||||||
|
type: docs
|
||||||
|
weight: 1
|
||||||
|
description: >
|
||||||
|
Execute ad-hoc SQL statements against a CockroachDB database.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## About
|
||||||
|
|
||||||
|
A `cockroachdb-execute-sql` tool executes ad-hoc SQL statements against a CockroachDB database. This tool is designed for interactive workflows where the SQL query is provided dynamically at runtime, making it ideal for developer assistance and exploratory data analysis.
|
||||||
|
|
||||||
|
The tool takes a single `sql` parameter containing the SQL statement to execute and returns the query results.
|
||||||
|
|
||||||
|
> **Note:** This tool is intended for developer assistant workflows with human-in-the-loop and shouldn't be used for production agents. For production use cases with predefined queries, use [cockroachdb-sql](./cockroachdb-sql.md) instead.
|
||||||
|
|
||||||
|
## Example
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
sources:
|
||||||
|
my_cockroachdb:
|
||||||
|
type: cockroachdb
|
||||||
|
host: your-cluster.cockroachlabs.cloud
|
||||||
|
port: "26257"
|
||||||
|
user: myuser
|
||||||
|
password: mypassword
|
||||||
|
database: defaultdb
|
||||||
|
queryParams:
|
||||||
|
sslmode: require
|
||||||
|
|
||||||
|
tools:
|
||||||
|
execute_sql:
|
||||||
|
type: cockroachdb-execute-sql
|
||||||
|
source: my_cockroachdb
|
||||||
|
description: Execute any SQL statement against the CockroachDB database
|
||||||
|
```
|
||||||
|
|
||||||
|
## Usage Examples
|
||||||
|
|
||||||
|
### Simple SELECT Query
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"sql": "SELECT * FROM users LIMIT 10"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Query with Aggregations
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"sql": "SELECT category, COUNT(*) as count, SUM(amount) as total FROM expenses GROUP BY category ORDER BY total DESC"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Database Introspection
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"sql": "SHOW TABLES"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"sql": "SHOW COLUMNS FROM expenses"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Multi-Region Information
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"sql": "SHOW REGIONS FROM DATABASE defaultdb"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"sql": "SHOW ZONE CONFIGURATIONS"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## CockroachDB-Specific Features
|
||||||
|
|
||||||
|
### Check Cluster Version
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"sql": "SELECT version()"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### View Node Status
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"sql": "SELECT node_id, address, locality, is_live FROM crdb_internal.gossip_nodes"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Check Replication Status
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"sql": "SELECT range_id, start_key, end_key, replicas, lease_holder FROM crdb_internal.ranges LIMIT 10"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### View Table Regions
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"sql": "SHOW REGIONS FROM TABLE expenses"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
### Required Fields
|
||||||
|
|
||||||
|
| Field | Type | Description |
|
||||||
|
|-------|------|-------------|
|
||||||
|
| `type` | string | Must be `cockroachdb-execute-sql` |
|
||||||
|
| `source` | string | Name of the CockroachDB source to use |
|
||||||
|
| `description` | string | Human-readable description for the LLM |
|
||||||
|
|
||||||
|
### Optional Fields
|
||||||
|
|
||||||
|
| Field | Type | Description |
|
||||||
|
|-------|------|-------------|
|
||||||
|
| `authRequired` | array | List of authentication services required |
|
||||||
|
|
||||||
|
## Parameters
|
||||||
|
|
||||||
|
The tool accepts a single runtime parameter:
|
||||||
|
|
||||||
|
| Parameter | Type | Description |
|
||||||
|
|-----------|------|-------------|
|
||||||
|
| `sql` | string | The SQL statement to execute |
|
||||||
|
|
||||||
|
## Best Practices
|
||||||
|
|
||||||
|
### Use for Exploration, Not Production
|
||||||
|
|
||||||
|
This tool is ideal for:
|
||||||
|
- Interactive database exploration
|
||||||
|
- Ad-hoc analysis and reporting
|
||||||
|
- Debugging and troubleshooting
|
||||||
|
- Schema inspection
|
||||||
|
|
||||||
|
For production use cases, use [cockroachdb-sql](./cockroachdb-sql.md) with parameterized queries.
|
||||||
|
|
||||||
|
### Be Cautious with Data Modification
|
||||||
|
|
||||||
|
While this tool can execute any SQL statement, be careful with:
|
||||||
|
- `INSERT`, `UPDATE`, `DELETE` statements
|
||||||
|
- `DROP` or `ALTER` statements
|
||||||
|
- Schema changes in production
|
||||||
|
|
||||||
|
### Use LIMIT for Large Results
|
||||||
|
|
||||||
|
Always use `LIMIT` clauses when exploring data:
|
||||||
|
|
||||||
|
```sql
|
||||||
|
SELECT * FROM large_table LIMIT 100
|
||||||
|
```
|
||||||
|
|
||||||
|
### Leverage CockroachDB's SQL Extensions
|
||||||
|
|
||||||
|
CockroachDB supports PostgreSQL syntax plus extensions:
|
||||||
|
|
||||||
|
```sql
|
||||||
|
-- Show database survival goal
|
||||||
|
SHOW SURVIVAL GOAL FROM DATABASE defaultdb;
|
||||||
|
|
||||||
|
-- View zone configurations
|
||||||
|
SHOW ZONE CONFIGURATION FOR TABLE expenses;
|
||||||
|
|
||||||
|
-- Check table localities
|
||||||
|
SHOW CREATE TABLE expenses;
|
||||||
|
```
|
||||||
|
|
||||||
|
## Error Handling
|
||||||
|
|
||||||
|
The tool will return descriptive errors for:
|
||||||
|
- **Syntax errors**: Invalid SQL syntax
|
||||||
|
- **Permission errors**: Insufficient user privileges
|
||||||
|
- **Connection errors**: Network or authentication issues
|
||||||
|
- **Runtime errors**: Constraint violations, type mismatches, etc.
|
||||||
|
|
||||||
|
## Security Considerations
|
||||||
|
|
||||||
|
### SQL Injection Risk
|
||||||
|
|
||||||
|
Since this tool executes arbitrary SQL, it should only be used with:
|
||||||
|
- Trusted users in interactive sessions
|
||||||
|
- Human-in-the-loop workflows
|
||||||
|
- Development and testing environments
|
||||||
|
|
||||||
|
Never expose this tool directly to end users without proper authorization controls.
|
||||||
|
|
||||||
|
### Use Authentication
|
||||||
|
|
||||||
|
Configure the `authRequired` field to restrict access:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
tools:
|
||||||
|
execute_sql:
|
||||||
|
type: cockroachdb-execute-sql
|
||||||
|
source: my_cockroachdb
|
||||||
|
description: Execute SQL statements
|
||||||
|
authRequired:
|
||||||
|
- my-auth-service
|
||||||
|
```
|
||||||
|
|
||||||
|
### Read-Only Users
|
||||||
|
|
||||||
|
For safer exploration, create read-only database users:
|
||||||
|
|
||||||
|
```sql
|
||||||
|
CREATE USER readonly_user;
|
||||||
|
GRANT SELECT ON DATABASE defaultdb TO readonly_user;
|
||||||
|
```
|
||||||
|
|
||||||
|
## Common Use Cases
|
||||||
|
|
||||||
|
### Database Administration
|
||||||
|
|
||||||
|
```sql
|
||||||
|
-- View database size
|
||||||
|
SELECT
|
||||||
|
table_name,
|
||||||
|
pg_size_pretty(pg_total_relation_size(table_name::regclass)) AS size
|
||||||
|
FROM information_schema.tables
|
||||||
|
WHERE table_schema = 'public'
|
||||||
|
ORDER BY pg_total_relation_size(table_name::regclass) DESC;
|
||||||
|
```
|
||||||
|
|
||||||
|
### Performance Analysis
|
||||||
|
|
||||||
|
```sql
|
||||||
|
-- Find slow queries
|
||||||
|
SELECT query, count, mean_latency
|
||||||
|
FROM crdb_internal.statement_statistics
|
||||||
|
WHERE mean_latency > INTERVAL '1 second'
|
||||||
|
ORDER BY mean_latency DESC
|
||||||
|
LIMIT 10;
|
||||||
|
```
|
||||||
|
|
||||||
|
### Data Quality Checks
|
||||||
|
|
||||||
|
```sql
|
||||||
|
-- Find NULL values
|
||||||
|
SELECT COUNT(*) as null_count
|
||||||
|
FROM expenses
|
||||||
|
WHERE description IS NULL OR amount IS NULL;
|
||||||
|
|
||||||
|
-- Find duplicates
|
||||||
|
SELECT user_id, email, COUNT(*) as count
|
||||||
|
FROM users
|
||||||
|
GROUP BY user_id, email
|
||||||
|
HAVING COUNT(*) > 1;
|
||||||
|
```
|
||||||
|
|
||||||
|
## See Also
|
||||||
|
|
||||||
|
- [cockroachdb-sql](./cockroachdb-sql.md) - For parameterized, production-ready queries
|
||||||
|
- [cockroachdb-list-tables](./cockroachdb-list-tables.md) - List tables in the database
|
||||||
|
- [cockroachdb-list-schemas](./cockroachdb-list-schemas.md) - List database schemas
|
||||||
|
- [CockroachDB Source](../../sources/cockroachdb.md) - Source configuration reference
|
||||||
|
- [CockroachDB SQL Reference](https://www.cockroachlabs.com/docs/stable/sql-statements.html) - Official SQL documentation
|
||||||
305
docs/en/resources/tools/cockroachdb/cockroachdb-list-schemas.md
Normal file
305
docs/en/resources/tools/cockroachdb/cockroachdb-list-schemas.md
Normal file
@@ -0,0 +1,305 @@
|
|||||||
|
---
|
||||||
|
title: "cockroachdb-list-schemas"
|
||||||
|
type: docs
|
||||||
|
weight: 1
|
||||||
|
description: >
|
||||||
|
List schemas in a CockroachDB database.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## About
|
||||||
|
|
||||||
|
The `cockroachdb-list-schemas` tool retrieves a list of schemas (namespaces) in a CockroachDB database. Schemas are used to organize database objects such as tables, views, and functions into logical groups.
|
||||||
|
|
||||||
|
This tool is useful for:
|
||||||
|
- Understanding database organization
|
||||||
|
- Discovering available schemas
|
||||||
|
- Multi-tenant application analysis
|
||||||
|
- Schema-level access control planning
|
||||||
|
|
||||||
|
## Example
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
sources:
|
||||||
|
my_cockroachdb:
|
||||||
|
type: cockroachdb
|
||||||
|
host: your-cluster.cockroachlabs.cloud
|
||||||
|
port: "26257"
|
||||||
|
user: myuser
|
||||||
|
password: mypassword
|
||||||
|
database: defaultdb
|
||||||
|
queryParams:
|
||||||
|
sslmode: require
|
||||||
|
|
||||||
|
tools:
|
||||||
|
list_schemas:
|
||||||
|
type: cockroachdb-list-schemas
|
||||||
|
source: my_cockroachdb
|
||||||
|
description: List all schemas in the database
|
||||||
|
```
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
### Required Fields
|
||||||
|
|
||||||
|
| Field | Type | Description |
|
||||||
|
|-------|------|-------------|
|
||||||
|
| `type` | string | Must be `cockroachdb-list-schemas` |
|
||||||
|
| `source` | string | Name of the CockroachDB source to use |
|
||||||
|
| `description` | string | Human-readable description for the LLM |
|
||||||
|
|
||||||
|
### Optional Fields
|
||||||
|
|
||||||
|
| Field | Type | Description |
|
||||||
|
|-------|------|-------------|
|
||||||
|
| `authRequired` | array | List of authentication services required |
|
||||||
|
|
||||||
|
## Output Structure
|
||||||
|
|
||||||
|
The tool returns a list of schemas with the following information:
|
||||||
|
|
||||||
|
```json
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"catalog_name": "defaultdb",
|
||||||
|
"schema_name": "public",
|
||||||
|
"is_user_defined": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"catalog_name": "defaultdb",
|
||||||
|
"schema_name": "analytics",
|
||||||
|
"is_user_defined": true
|
||||||
|
}
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
### Fields
|
||||||
|
|
||||||
|
| Field | Type | Description |
|
||||||
|
|-------|------|-------------|
|
||||||
|
| `catalog_name` | string | The database (catalog) name |
|
||||||
|
| `schema_name` | string | The schema name |
|
||||||
|
| `is_user_defined` | boolean | Whether this is a user-created schema (excludes system schemas) |
|
||||||
|
|
||||||
|
## Usage Example
|
||||||
|
|
||||||
|
```json
|
||||||
|
{}
|
||||||
|
```
|
||||||
|
|
||||||
|
No parameters are required. The tool automatically lists all user-defined schemas.
|
||||||
|
|
||||||
|
## Default Schemas
|
||||||
|
|
||||||
|
CockroachDB includes several standard schemas:
|
||||||
|
|
||||||
|
- **`public`**: The default schema for user objects
|
||||||
|
- **`pg_catalog`**: PostgreSQL system catalog (excluded from results)
|
||||||
|
- **`information_schema`**: SQL standard metadata views (excluded from results)
|
||||||
|
- **`crdb_internal`**: CockroachDB internal metadata (excluded from results)
|
||||||
|
- **`pg_extension`**: PostgreSQL extension objects (excluded from results)
|
||||||
|
|
||||||
|
The tool filters out system schemas and only returns user-defined schemas.
|
||||||
|
|
||||||
|
## Schema Management in CockroachDB
|
||||||
|
|
||||||
|
### Creating Schemas
|
||||||
|
|
||||||
|
```sql
|
||||||
|
CREATE SCHEMA analytics;
|
||||||
|
```
|
||||||
|
|
||||||
|
### Using Schemas
|
||||||
|
|
||||||
|
```sql
|
||||||
|
-- Create table in specific schema
|
||||||
|
CREATE TABLE analytics.revenue (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
amount DECIMAL(10,2),
|
||||||
|
date DATE
|
||||||
|
);
|
||||||
|
|
||||||
|
-- Query from specific schema
|
||||||
|
SELECT * FROM analytics.revenue;
|
||||||
|
```
|
||||||
|
|
||||||
|
### Schema Search Path
|
||||||
|
|
||||||
|
The search path determines which schemas are searched for unqualified object names:
|
||||||
|
|
||||||
|
```sql
|
||||||
|
-- Show current search path
|
||||||
|
SHOW search_path;
|
||||||
|
|
||||||
|
-- Set search path
|
||||||
|
SET search_path = analytics, public;
|
||||||
|
```
|
||||||
|
|
||||||
|
## Multi-Tenant Applications
|
||||||
|
|
||||||
|
Schemas are commonly used for multi-tenant applications:
|
||||||
|
|
||||||
|
```sql
|
||||||
|
-- Create schema per tenant
|
||||||
|
CREATE SCHEMA tenant_acme;
|
||||||
|
CREATE SCHEMA tenant_globex;
|
||||||
|
|
||||||
|
-- Create same table structure in each schema
|
||||||
|
CREATE TABLE tenant_acme.orders (...);
|
||||||
|
CREATE TABLE tenant_globex.orders (...);
|
||||||
|
```
|
||||||
|
|
||||||
|
The `cockroachdb-list-schemas` tool helps discover all tenant schemas:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
tools:
|
||||||
|
list_tenants:
|
||||||
|
type: cockroachdb-list-schemas
|
||||||
|
source: my_cockroachdb
|
||||||
|
description: |
|
||||||
|
List all tenant schemas in the database.
|
||||||
|
Each schema represents a separate tenant's data namespace.
|
||||||
|
```
|
||||||
|
|
||||||
|
## Best Practices
|
||||||
|
|
||||||
|
### Use Schemas for Organization
|
||||||
|
|
||||||
|
Group related tables into schemas:
|
||||||
|
|
||||||
|
```sql
|
||||||
|
CREATE SCHEMA sales;
|
||||||
|
CREATE SCHEMA inventory;
|
||||||
|
CREATE SCHEMA hr;
|
||||||
|
|
||||||
|
CREATE TABLE sales.orders (...);
|
||||||
|
CREATE TABLE inventory.products (...);
|
||||||
|
CREATE TABLE hr.employees (...);
|
||||||
|
```
|
||||||
|
|
||||||
|
### Schema Naming Conventions
|
||||||
|
|
||||||
|
Use clear, descriptive schema names:
|
||||||
|
- Lowercase names
|
||||||
|
- Use underscores for multi-word names
|
||||||
|
- Avoid reserved keywords
|
||||||
|
- Use prefixes for grouped schemas (e.g., `tenant_`, `app_`)
|
||||||
|
|
||||||
|
### Schema-Level Permissions
|
||||||
|
|
||||||
|
Schemas enable fine-grained access control:
|
||||||
|
|
||||||
|
```sql
|
||||||
|
-- Grant access to specific schema
|
||||||
|
GRANT USAGE ON SCHEMA analytics TO analyst_role;
|
||||||
|
GRANT SELECT ON ALL TABLES IN SCHEMA analytics TO analyst_role;
|
||||||
|
|
||||||
|
-- Revoke access
|
||||||
|
REVOKE ALL ON SCHEMA hr FROM public;
|
||||||
|
```
|
||||||
|
|
||||||
|
## Integration with Other Tools
|
||||||
|
|
||||||
|
### Combined with List Tables
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
tools:
|
||||||
|
list_schemas:
|
||||||
|
type: cockroachdb-list-schemas
|
||||||
|
source: my_cockroachdb
|
||||||
|
description: List all schemas first
|
||||||
|
|
||||||
|
list_tables:
|
||||||
|
type: cockroachdb-list-tables
|
||||||
|
source: my_cockroachdb
|
||||||
|
description: |
|
||||||
|
List tables in the database.
|
||||||
|
Use list_schemas first to understand schema organization.
|
||||||
|
```
|
||||||
|
|
||||||
|
### Schema Discovery Workflow
|
||||||
|
|
||||||
|
1. Call `cockroachdb-list-schemas` to discover schemas
|
||||||
|
2. Call `cockroachdb-list-tables` to see tables in each schema
|
||||||
|
3. Generate queries using fully qualified names: `schema.table`
|
||||||
|
|
||||||
|
## Common Use Cases
|
||||||
|
|
||||||
|
### Discover Database Structure
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
tools:
|
||||||
|
discover_schemas:
|
||||||
|
type: cockroachdb-list-schemas
|
||||||
|
source: my_cockroachdb
|
||||||
|
description: |
|
||||||
|
Discover how the database is organized into schemas.
|
||||||
|
Use this to understand the logical grouping of tables.
|
||||||
|
```
|
||||||
|
|
||||||
|
### Multi-Tenant Analysis
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
tools:
|
||||||
|
list_tenant_schemas:
|
||||||
|
type: cockroachdb-list-schemas
|
||||||
|
source: my_cockroachdb
|
||||||
|
description: |
|
||||||
|
List all tenant schemas (each tenant has their own schema).
|
||||||
|
Schema names follow the pattern: tenant_<company_name>
|
||||||
|
```
|
||||||
|
|
||||||
|
### Schema Migration Planning
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
tools:
|
||||||
|
audit_schemas:
|
||||||
|
type: cockroachdb-list-schemas
|
||||||
|
source: my_cockroachdb
|
||||||
|
description: |
|
||||||
|
Audit existing schemas before migration.
|
||||||
|
Identifies all schemas that need to be migrated.
|
||||||
|
```
|
||||||
|
|
||||||
|
## Error Handling
|
||||||
|
|
||||||
|
The tool handles common errors:
|
||||||
|
- **Connection errors**: Returns connection failure details
|
||||||
|
- **Permission errors**: Returns error if user lacks USAGE privilege
|
||||||
|
- **Empty results**: Returns empty array if no user schemas exist
|
||||||
|
|
||||||
|
## Permissions Required
|
||||||
|
|
||||||
|
To list schemas, the user needs:
|
||||||
|
- `CONNECT` privilege on the database
|
||||||
|
- No specific schema privileges required for listing
|
||||||
|
|
||||||
|
To query objects within schemas, the user needs:
|
||||||
|
- `USAGE` privilege on the schema
|
||||||
|
- Appropriate object privileges (SELECT, INSERT, etc.)
|
||||||
|
|
||||||
|
## CockroachDB-Specific Features
|
||||||
|
|
||||||
|
### System Schemas
|
||||||
|
|
||||||
|
CockroachDB includes PostgreSQL-compatible system schemas plus CockroachDB-specific ones:
|
||||||
|
|
||||||
|
- `crdb_internal.*`: CockroachDB internal metadata and statistics
|
||||||
|
- `pg_catalog.*`: PostgreSQL system catalog
|
||||||
|
- `information_schema.*`: SQL standard information schema
|
||||||
|
|
||||||
|
These are automatically filtered from the results.
|
||||||
|
|
||||||
|
### User-Defined Flag
|
||||||
|
|
||||||
|
The `is_user_defined` field helps distinguish:
|
||||||
|
- `true`: User-created schemas
|
||||||
|
- `false`: System schemas (already filtered out)
|
||||||
|
|
||||||
|
## See Also
|
||||||
|
|
||||||
|
- [cockroachdb-sql](./cockroachdb-sql.md) - Execute parameterized queries
|
||||||
|
- [cockroachdb-execute-sql](./cockroachdb-execute-sql.md) - Execute ad-hoc SQL
|
||||||
|
- [cockroachdb-list-tables](./cockroachdb-list-tables.md) - List tables in the database
|
||||||
|
- [CockroachDB Source](../../sources/cockroachdb.md) - Source configuration reference
|
||||||
|
- [CockroachDB Schema Design](https://www.cockroachlabs.com/docs/stable/schema-design-overview.html) - Official documentation
|
||||||
344
docs/en/resources/tools/cockroachdb/cockroachdb-list-tables.md
Normal file
344
docs/en/resources/tools/cockroachdb/cockroachdb-list-tables.md
Normal file
@@ -0,0 +1,344 @@
|
|||||||
|
---
|
||||||
|
title: "cockroachdb-list-tables"
|
||||||
|
type: docs
|
||||||
|
weight: 1
|
||||||
|
description: >
|
||||||
|
List tables in a CockroachDB database with schema details.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## About
|
||||||
|
|
||||||
|
The `cockroachdb-list-tables` tool retrieves a list of tables from a CockroachDB database. It provides detailed information about table structure, including columns, constraints, indexes, and foreign key relationships.
|
||||||
|
|
||||||
|
This tool is useful for:
|
||||||
|
- Database schema discovery
|
||||||
|
- Understanding table relationships
|
||||||
|
- Generating context for AI-powered database queries
|
||||||
|
- Documentation and analysis
|
||||||
|
|
||||||
|
## Example
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
sources:
|
||||||
|
my_cockroachdb:
|
||||||
|
type: cockroachdb
|
||||||
|
host: your-cluster.cockroachlabs.cloud
|
||||||
|
port: "26257"
|
||||||
|
user: myuser
|
||||||
|
password: mypassword
|
||||||
|
database: defaultdb
|
||||||
|
queryParams:
|
||||||
|
sslmode: require
|
||||||
|
|
||||||
|
tools:
|
||||||
|
list_all_tables:
|
||||||
|
type: cockroachdb-list-tables
|
||||||
|
source: my_cockroachdb
|
||||||
|
description: List all user tables in the database with their structure
|
||||||
|
```
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
### Required Fields
|
||||||
|
|
||||||
|
| Field | Type | Description |
|
||||||
|
|-------|------|-------------|
|
||||||
|
| `type` | string | Must be `cockroachdb-list-tables` |
|
||||||
|
| `source` | string | Name of the CockroachDB source to use |
|
||||||
|
| `description` | string | Human-readable description for the LLM |
|
||||||
|
|
||||||
|
### Optional Fields
|
||||||
|
|
||||||
|
| Field | Type | Description |
|
||||||
|
|-------|------|-------------|
|
||||||
|
| `authRequired` | array | List of authentication services required |
|
||||||
|
|
||||||
|
## Parameters
|
||||||
|
|
||||||
|
The tool accepts optional runtime parameters:
|
||||||
|
|
||||||
|
| Parameter | Type | Default | Description |
|
||||||
|
|-----------|------|---------|-------------|
|
||||||
|
| `table_names` | array | all tables | List of specific table names to retrieve |
|
||||||
|
| `output_format` | string | "detailed" | Output format: "simple" or "detailed" |
|
||||||
|
|
||||||
|
## Output Formats
|
||||||
|
|
||||||
|
### Simple Format
|
||||||
|
|
||||||
|
Returns basic table information:
|
||||||
|
- Table name
|
||||||
|
- Row count estimate
|
||||||
|
- Size information
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"table_names": ["users"],
|
||||||
|
"output_format": "simple"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Detailed Format (Default)
|
||||||
|
|
||||||
|
Returns comprehensive table information:
|
||||||
|
- Table name and schema
|
||||||
|
- All columns with types and constraints
|
||||||
|
- Primary keys
|
||||||
|
- Foreign keys and relationships
|
||||||
|
- Indexes
|
||||||
|
- Check constraints
|
||||||
|
- Table size and row counts
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"table_names": ["users", "orders"],
|
||||||
|
"output_format": "detailed"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Usage Examples
|
||||||
|
|
||||||
|
### List All Tables
|
||||||
|
|
||||||
|
```json
|
||||||
|
{}
|
||||||
|
```
|
||||||
|
|
||||||
|
### List Specific Tables
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"table_names": ["users", "orders", "expenses"]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Simple Output
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"output_format": "simple"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Output Structure
|
||||||
|
|
||||||
|
### Simple Format Output
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"table_name": "users",
|
||||||
|
"estimated_rows": 1000,
|
||||||
|
"size": "128 KB"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Detailed Format Output
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"table_name": "users",
|
||||||
|
"schema": "public",
|
||||||
|
"columns": [
|
||||||
|
{
|
||||||
|
"name": "id",
|
||||||
|
"type": "UUID",
|
||||||
|
"nullable": false,
|
||||||
|
"default": "gen_random_uuid()"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "email",
|
||||||
|
"type": "STRING",
|
||||||
|
"nullable": false,
|
||||||
|
"default": null
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "created_at",
|
||||||
|
"type": "TIMESTAMP",
|
||||||
|
"nullable": false,
|
||||||
|
"default": "now()"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"primary_key": ["id"],
|
||||||
|
"indexes": [
|
||||||
|
{
|
||||||
|
"name": "users_pkey",
|
||||||
|
"columns": ["id"],
|
||||||
|
"unique": true,
|
||||||
|
"primary": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "users_email_idx",
|
||||||
|
"columns": ["email"],
|
||||||
|
"unique": true,
|
||||||
|
"primary": false
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"foreign_keys": [],
|
||||||
|
"constraints": [
|
||||||
|
{
|
||||||
|
"name": "users_email_check",
|
||||||
|
"type": "CHECK",
|
||||||
|
"definition": "email ~* '^[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\\.[A-Z|a-z]{2,}$'"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## CockroachDB-Specific Information
|
||||||
|
|
||||||
|
### UUID Primary Keys
|
||||||
|
|
||||||
|
The tool recognizes CockroachDB's recommended UUID primary key pattern:
|
||||||
|
|
||||||
|
```sql
|
||||||
|
CREATE TABLE users (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
...
|
||||||
|
);
|
||||||
|
```
|
||||||
|
|
||||||
|
### Multi-Region Tables
|
||||||
|
|
||||||
|
For multi-region tables, the output includes locality information:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"table_name": "users",
|
||||||
|
"locality": "REGIONAL BY ROW",
|
||||||
|
"regions": ["us-east-1", "us-west-2", "eu-west-1"]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Interleaved Tables
|
||||||
|
|
||||||
|
The tool shows parent-child relationships for interleaved tables (legacy feature):
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"table_name": "order_items",
|
||||||
|
"interleaved_in": "orders"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Best Practices
|
||||||
|
|
||||||
|
### Use for Schema Discovery
|
||||||
|
|
||||||
|
The tool is ideal for helping AI assistants understand your database structure:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
tools:
|
||||||
|
discover_schema:
|
||||||
|
type: cockroachdb-list-tables
|
||||||
|
source: my_cockroachdb
|
||||||
|
description: |
|
||||||
|
Use this tool first to understand the database schema before generating queries.
|
||||||
|
It shows all tables, their columns, data types, and relationships.
|
||||||
|
```
|
||||||
|
|
||||||
|
### Filter Large Schemas
|
||||||
|
|
||||||
|
For databases with many tables, specify relevant tables:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"table_names": ["users", "orders", "products"],
|
||||||
|
"output_format": "detailed"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Use Simple Format for Overviews
|
||||||
|
|
||||||
|
When you need just table names and sizes:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"output_format": "simple"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Excluded Tables
|
||||||
|
|
||||||
|
The tool automatically excludes system tables and schemas:
|
||||||
|
- `pg_catalog.*` - PostgreSQL system catalog
|
||||||
|
- `information_schema.*` - SQL standard information schema
|
||||||
|
- `crdb_internal.*` - CockroachDB internal tables
|
||||||
|
- `pg_extension.*` - PostgreSQL extension tables
|
||||||
|
|
||||||
|
Only user-created tables in the public schema (and other user schemas) are returned.
|
||||||
|
|
||||||
|
## Error Handling
|
||||||
|
|
||||||
|
The tool handles common errors:
|
||||||
|
- **Table not found**: Returns empty result for non-existent tables
|
||||||
|
- **Permission errors**: Returns error if user lacks SELECT privileges
|
||||||
|
- **Connection errors**: Returns connection failure details
|
||||||
|
|
||||||
|
## Integration with AI Assistants
|
||||||
|
|
||||||
|
### Prompt Example
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
tools:
|
||||||
|
list_tables:
|
||||||
|
type: cockroachdb-list-tables
|
||||||
|
source: my_cockroachdb
|
||||||
|
description: |
|
||||||
|
Lists all tables in the database with detailed schema information.
|
||||||
|
Use this tool to understand:
|
||||||
|
- What tables exist
|
||||||
|
- What columns each table has
|
||||||
|
- Data types and constraints
|
||||||
|
- Relationships between tables (foreign keys)
|
||||||
|
- Available indexes
|
||||||
|
|
||||||
|
Always call this tool before generating SQL queries to ensure
|
||||||
|
you use correct table and column names.
|
||||||
|
```
|
||||||
|
|
||||||
|
## Common Use Cases
|
||||||
|
|
||||||
|
### Generate Context for Queries
|
||||||
|
|
||||||
|
```json
|
||||||
|
{}
|
||||||
|
```
|
||||||
|
|
||||||
|
This provides comprehensive schema information that helps AI assistants generate accurate SQL queries.
|
||||||
|
|
||||||
|
### Analyze Table Structure
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"table_names": ["users"],
|
||||||
|
"output_format": "detailed"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Perfect for understanding a specific table's structure, constraints, and relationships.
|
||||||
|
|
||||||
|
### Quick Schema Overview
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"output_format": "simple"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Gets a quick list of tables with basic statistics.
|
||||||
|
|
||||||
|
## Performance Considerations
|
||||||
|
|
||||||
|
- **Simple format** is faster for large databases
|
||||||
|
- **Detailed format** queries system tables extensively
|
||||||
|
- Specifying `table_names` reduces query time
|
||||||
|
- Results are fetched in a single query for efficiency
|
||||||
|
|
||||||
|
## See Also
|
||||||
|
|
||||||
|
- [cockroachdb-sql](./cockroachdb-sql.md) - Execute parameterized queries
|
||||||
|
- [cockroachdb-execute-sql](./cockroachdb-execute-sql.md) - Execute ad-hoc SQL
|
||||||
|
- [cockroachdb-list-schemas](./cockroachdb-list-schemas.md) - List database schemas
|
||||||
|
- [CockroachDB Source](../../sources/cockroachdb.md) - Source configuration reference
|
||||||
|
- [CockroachDB Schema Design](https://www.cockroachlabs.com/docs/stable/schema-design-overview.html) - Best practices
|
||||||
291
docs/en/resources/tools/cockroachdb/cockroachdb-sql.md
Normal file
291
docs/en/resources/tools/cockroachdb/cockroachdb-sql.md
Normal file
@@ -0,0 +1,291 @@
|
|||||||
|
---
|
||||||
|
title: "cockroachdb-sql"
|
||||||
|
type: docs
|
||||||
|
weight: 1
|
||||||
|
description: >
|
||||||
|
Execute parameterized SQL queries in CockroachDB.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## About
|
||||||
|
|
||||||
|
The `cockroachdb-sql` tool allows you to execute parameterized SQL queries against a CockroachDB database. This tool supports prepared statements with parameter binding, template parameters for dynamic query construction, and automatic transaction retry for resilience against serialization conflicts.
|
||||||
|
|
||||||
|
## Example
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
sources:
|
||||||
|
my_cockroachdb:
|
||||||
|
type: cockroachdb
|
||||||
|
host: your-cluster.cockroachlabs.cloud
|
||||||
|
port: "26257"
|
||||||
|
user: myuser
|
||||||
|
password: mypassword
|
||||||
|
database: defaultdb
|
||||||
|
queryParams:
|
||||||
|
sslmode: require
|
||||||
|
|
||||||
|
tools:
|
||||||
|
get_user_orders:
|
||||||
|
type: cockroachdb-sql
|
||||||
|
source: my_cockroachdb
|
||||||
|
description: Get all orders for a specific user
|
||||||
|
statement: |
|
||||||
|
SELECT o.id, o.order_date, o.total_amount, o.status
|
||||||
|
FROM orders o
|
||||||
|
WHERE o.user_id = $1
|
||||||
|
ORDER BY o.order_date DESC
|
||||||
|
parameters:
|
||||||
|
- name: user_id
|
||||||
|
type: string
|
||||||
|
description: The UUID of the user
|
||||||
|
```
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
### Required Fields
|
||||||
|
|
||||||
|
| Field | Type | Description |
|
||||||
|
|-------|------|-------------|
|
||||||
|
| `type` | string | Must be `cockroachdb-sql` |
|
||||||
|
| `source` | string | Name of the CockroachDB source to use |
|
||||||
|
| `description` | string | Human-readable description of what the tool does |
|
||||||
|
| `statement` | string | The SQL query to execute |
|
||||||
|
|
||||||
|
### Optional Fields
|
||||||
|
|
||||||
|
| Field | Type | Description |
|
||||||
|
|-------|------|-------------|
|
||||||
|
| `parameters` | array | List of parameter definitions for the query |
|
||||||
|
| `templateParameters` | array | List of template parameters for dynamic query construction |
|
||||||
|
| `authRequired` | array | List of authentication services required |
|
||||||
|
|
||||||
|
## Parameters
|
||||||
|
|
||||||
|
Parameters allow you to safely pass values into your SQL queries using prepared statements. CockroachDB uses PostgreSQL-style parameter placeholders: `$1`, `$2`, etc.
|
||||||
|
|
||||||
|
### Parameter Types
|
||||||
|
|
||||||
|
- `string`: Text values
|
||||||
|
- `number`: Numeric values (integers or decimals)
|
||||||
|
- `boolean`: True/false values
|
||||||
|
- `array`: Array of values
|
||||||
|
|
||||||
|
### Example with Multiple Parameters
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
tools:
|
||||||
|
filter_expenses:
|
||||||
|
type: cockroachdb-sql
|
||||||
|
source: my_cockroachdb
|
||||||
|
description: Filter expenses by category and date range
|
||||||
|
statement: |
|
||||||
|
SELECT id, description, amount, category, expense_date
|
||||||
|
FROM expenses
|
||||||
|
WHERE user_id = $1
|
||||||
|
AND category = $2
|
||||||
|
AND expense_date >= $3
|
||||||
|
AND expense_date <= $4
|
||||||
|
ORDER BY expense_date DESC
|
||||||
|
parameters:
|
||||||
|
- name: user_id
|
||||||
|
type: string
|
||||||
|
description: The user's UUID
|
||||||
|
- name: category
|
||||||
|
type: string
|
||||||
|
description: Expense category (e.g., "Food", "Transport")
|
||||||
|
- name: start_date
|
||||||
|
type: string
|
||||||
|
description: Start date in YYYY-MM-DD format
|
||||||
|
- name: end_date
|
||||||
|
type: string
|
||||||
|
description: End date in YYYY-MM-DD format
|
||||||
|
```
|
||||||
|
|
||||||
|
## Template Parameters
|
||||||
|
|
||||||
|
Template parameters enable dynamic query construction by replacing placeholders in the SQL statement before parameter binding. This is useful for dynamic table names, column names, or query structure.
|
||||||
|
|
||||||
|
### Example with Template Parameters
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
tools:
|
||||||
|
get_column_data:
|
||||||
|
type: cockroachdb-sql
|
||||||
|
source: my_cockroachdb
|
||||||
|
description: Get data from a specific column
|
||||||
|
statement: |
|
||||||
|
SELECT {{column_name}}
|
||||||
|
FROM {{table_name}}
|
||||||
|
WHERE user_id = $1
|
||||||
|
LIMIT 100
|
||||||
|
templateParameters:
|
||||||
|
- name: table_name
|
||||||
|
type: string
|
||||||
|
description: The table to query
|
||||||
|
- name: column_name
|
||||||
|
type: string
|
||||||
|
description: The column to retrieve
|
||||||
|
parameters:
|
||||||
|
- name: user_id
|
||||||
|
type: string
|
||||||
|
description: The user's UUID
|
||||||
|
```
|
||||||
|
|
||||||
|
## Best Practices
|
||||||
|
|
||||||
|
### Use UUID Primary Keys
|
||||||
|
|
||||||
|
CockroachDB performs best with UUID primary keys to avoid transaction hotspots:
|
||||||
|
|
||||||
|
```sql
|
||||||
|
CREATE TABLE orders (
|
||||||
|
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||||
|
user_id UUID NOT NULL,
|
||||||
|
order_date TIMESTAMP DEFAULT now(),
|
||||||
|
total_amount DECIMAL(10,2)
|
||||||
|
);
|
||||||
|
```
|
||||||
|
|
||||||
|
### Use Indexes for Performance
|
||||||
|
|
||||||
|
Create indexes on frequently queried columns:
|
||||||
|
|
||||||
|
```sql
|
||||||
|
CREATE INDEX idx_orders_user_id ON orders(user_id);
|
||||||
|
CREATE INDEX idx_orders_date ON orders(order_date DESC);
|
||||||
|
```
|
||||||
|
|
||||||
|
### Use JOINs Efficiently
|
||||||
|
|
||||||
|
CockroachDB supports standard SQL JOINs. Keep joins efficient by:
|
||||||
|
- Adding appropriate indexes
|
||||||
|
- Using UUIDs for foreign keys
|
||||||
|
- Limiting result sets with WHERE clauses
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
tools:
|
||||||
|
get_user_with_orders:
|
||||||
|
type: cockroachdb-sql
|
||||||
|
source: my_cockroachdb
|
||||||
|
description: Get user details with their recent orders
|
||||||
|
statement: |
|
||||||
|
SELECT u.name, u.email, o.id as order_id, o.order_date, o.total_amount
|
||||||
|
FROM users u
|
||||||
|
LEFT JOIN orders o ON u.id = o.user_id
|
||||||
|
WHERE u.id = $1
|
||||||
|
ORDER BY o.order_date DESC
|
||||||
|
LIMIT 10
|
||||||
|
parameters:
|
||||||
|
- name: user_id
|
||||||
|
type: string
|
||||||
|
description: The user's UUID
|
||||||
|
```
|
||||||
|
|
||||||
|
### Handle NULL Values
|
||||||
|
|
||||||
|
Use COALESCE or NULL checks when dealing with nullable columns:
|
||||||
|
|
||||||
|
```sql
|
||||||
|
SELECT id, description, COALESCE(notes, 'No notes') as notes
|
||||||
|
FROM expenses
|
||||||
|
WHERE user_id = $1
|
||||||
|
```
|
||||||
|
|
||||||
|
## Error Handling
|
||||||
|
|
||||||
|
The tool automatically handles:
|
||||||
|
- **Connection errors**: Retried with exponential backoff
|
||||||
|
- **Serialization conflicts**: Automatically retried using cockroach-go library
|
||||||
|
- **Invalid parameters**: Returns descriptive error messages
|
||||||
|
- **SQL syntax errors**: Returns database error details
|
||||||
|
|
||||||
|
## Advanced Usage
|
||||||
|
|
||||||
|
### Aggregations
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
tools:
|
||||||
|
expense_summary:
|
||||||
|
type: cockroachdb-sql
|
||||||
|
source: my_cockroachdb
|
||||||
|
description: Get expense summary by category for a user
|
||||||
|
statement: |
|
||||||
|
SELECT
|
||||||
|
category,
|
||||||
|
COUNT(*) as count,
|
||||||
|
SUM(amount) as total_amount,
|
||||||
|
AVG(amount) as avg_amount
|
||||||
|
FROM expenses
|
||||||
|
WHERE user_id = $1
|
||||||
|
AND expense_date >= $2
|
||||||
|
GROUP BY category
|
||||||
|
ORDER BY total_amount DESC
|
||||||
|
parameters:
|
||||||
|
- name: user_id
|
||||||
|
type: string
|
||||||
|
description: The user's UUID
|
||||||
|
- name: start_date
|
||||||
|
type: string
|
||||||
|
description: Start date in YYYY-MM-DD format
|
||||||
|
```
|
||||||
|
|
||||||
|
### Window Functions
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
tools:
|
||||||
|
running_total:
|
||||||
|
type: cockroachdb-sql
|
||||||
|
source: my_cockroachdb
|
||||||
|
description: Get running total of expenses
|
||||||
|
statement: |
|
||||||
|
SELECT
|
||||||
|
expense_date,
|
||||||
|
amount,
|
||||||
|
SUM(amount) OVER (ORDER BY expense_date) as running_total
|
||||||
|
FROM expenses
|
||||||
|
WHERE user_id = $1
|
||||||
|
ORDER BY expense_date
|
||||||
|
parameters:
|
||||||
|
- name: user_id
|
||||||
|
type: string
|
||||||
|
description: The user's UUID
|
||||||
|
```
|
||||||
|
|
||||||
|
### Common Table Expressions (CTEs)
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
tools:
|
||||||
|
top_spenders:
|
||||||
|
type: cockroachdb-sql
|
||||||
|
source: my_cockroachdb
|
||||||
|
description: Find top spending users
|
||||||
|
statement: |
|
||||||
|
WITH user_totals AS (
|
||||||
|
SELECT
|
||||||
|
user_id,
|
||||||
|
SUM(amount) as total_spent
|
||||||
|
FROM expenses
|
||||||
|
WHERE expense_date >= $1
|
||||||
|
GROUP BY user_id
|
||||||
|
)
|
||||||
|
SELECT
|
||||||
|
u.name,
|
||||||
|
u.email,
|
||||||
|
ut.total_spent
|
||||||
|
FROM user_totals ut
|
||||||
|
JOIN users u ON ut.user_id = u.id
|
||||||
|
ORDER BY ut.total_spent DESC
|
||||||
|
LIMIT 10
|
||||||
|
parameters:
|
||||||
|
- name: start_date
|
||||||
|
type: string
|
||||||
|
description: Start date in YYYY-MM-DD format
|
||||||
|
```
|
||||||
|
|
||||||
|
## See Also
|
||||||
|
|
||||||
|
- [cockroachdb-execute-sql](./cockroachdb-execute-sql.md) - For ad-hoc SQL execution
|
||||||
|
- [cockroachdb-list-tables](./cockroachdb-list-tables.md) - List tables in the database
|
||||||
|
- [cockroachdb-list-schemas](./cockroachdb-list-schemas.md) - List database schemas
|
||||||
|
- [CockroachDB Source](../../sources/cockroachdb.md) - Source configuration reference
|
||||||
@@ -771,7 +771,7 @@
|
|||||||
},
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"version = \"0.26.0\" # x-release-please-version\n",
|
"version = \"0.27.0\" # x-release-please-version\n",
|
||||||
"! curl -L -o /content/toolbox https://storage.googleapis.com/genai-toolbox/v{version}/linux/amd64/toolbox\n",
|
"! curl -L -o /content/toolbox https://storage.googleapis.com/genai-toolbox/v{version}/linux/amd64/toolbox\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Make the binary executable\n",
|
"# Make the binary executable\n",
|
||||||
|
|||||||
@@ -123,7 +123,7 @@ In this section, we will download and install the Toolbox binary.
|
|||||||
<!-- {x-release-please-start-version} -->
|
<!-- {x-release-please-start-version} -->
|
||||||
```bash
|
```bash
|
||||||
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
||||||
export VERSION="0.26.0"
|
export VERSION="0.27.0"
|
||||||
curl -O https://storage.googleapis.com/genai-toolbox/v$VERSION/$OS/toolbox
|
curl -O https://storage.googleapis.com/genai-toolbox/v$VERSION/$OS/toolbox
|
||||||
```
|
```
|
||||||
<!-- {x-release-please-end} -->
|
<!-- {x-release-please-end} -->
|
||||||
|
|||||||
@@ -220,7 +220,7 @@
|
|||||||
},
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"version = \"0.26.0\" # x-release-please-version\n",
|
"version = \"0.27.0\" # x-release-please-version\n",
|
||||||
"! curl -O https://storage.googleapis.com/genai-toolbox/v{version}/linux/amd64/toolbox\n",
|
"! curl -O https://storage.googleapis.com/genai-toolbox/v{version}/linux/amd64/toolbox\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Make the binary executable\n",
|
"# Make the binary executable\n",
|
||||||
|
|||||||
@@ -179,7 +179,7 @@ to use BigQuery, and then run the Toolbox server.
|
|||||||
<!-- {x-release-please-start-version} -->
|
<!-- {x-release-please-start-version} -->
|
||||||
```bash
|
```bash
|
||||||
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
||||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/$OS/toolbox
|
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/$OS/toolbox
|
||||||
```
|
```
|
||||||
<!-- {x-release-please-end} -->
|
<!-- {x-release-please-end} -->
|
||||||
|
|
||||||
|
|||||||
@@ -98,7 +98,7 @@ In this section, we will download Toolbox, configure our tools in a
|
|||||||
<!-- {x-release-please-start-version} -->
|
<!-- {x-release-please-start-version} -->
|
||||||
```bash
|
```bash
|
||||||
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
||||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/$OS/toolbox
|
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/$OS/toolbox
|
||||||
```
|
```
|
||||||
<!-- {x-release-please-end} -->
|
<!-- {x-release-please-end} -->
|
||||||
|
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ In this section, we will download Toolbox and run the Toolbox server.
|
|||||||
<!-- {x-release-please-start-version} -->
|
<!-- {x-release-please-start-version} -->
|
||||||
```bash
|
```bash
|
||||||
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
||||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/$OS/toolbox
|
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/$OS/toolbox
|
||||||
```
|
```
|
||||||
<!-- {x-release-please-end} -->
|
<!-- {x-release-please-end} -->
|
||||||
|
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ In this section, we will download Toolbox and run the Toolbox server.
|
|||||||
<!-- {x-release-please-start-version} -->
|
<!-- {x-release-please-start-version} -->
|
||||||
```bash
|
```bash
|
||||||
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
||||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/$OS/toolbox
|
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/$OS/toolbox
|
||||||
```
|
```
|
||||||
<!-- {x-release-please-end} -->
|
<!-- {x-release-please-end} -->
|
||||||
|
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ In this section, we will download Toolbox and run the Toolbox server.
|
|||||||
<!-- {x-release-please-start-version} -->
|
<!-- {x-release-please-start-version} -->
|
||||||
```bash
|
```bash
|
||||||
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
||||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/$OS/toolbox
|
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/$OS/toolbox
|
||||||
```
|
```
|
||||||
<!-- {x-release-please-end} -->
|
<!-- {x-release-please-end} -->
|
||||||
|
|
||||||
|
|||||||
7
docs/en/samples/oracle/_index.md
Normal file
7
docs/en/samples/oracle/_index.md
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
---
|
||||||
|
title: "OracleDB"
|
||||||
|
type: docs
|
||||||
|
weight: 1
|
||||||
|
description: >
|
||||||
|
How to get started with Toolbox using Oracle Database.
|
||||||
|
---
|
||||||
@@ -16,14 +16,7 @@ This guide demonstrates how to implement these patterns in your Toolbox applicat
|
|||||||
|
|
||||||
{{< tabpane persist=header >}}
|
{{< tabpane persist=header >}}
|
||||||
{{% tab header="ADK" text=true %}}
|
{{% tab header="ADK" text=true %}}
|
||||||
The following example demonstrates how to use `ToolboxToolset` with ADK's pre and post processing hooks to implement pre and post processing for tool calls.
|
Coming soon.
|
||||||
|
|
||||||
```py
|
|
||||||
{{< include "python/adk/agent.py" >}}
|
|
||||||
```
|
|
||||||
You can also add model-level (`before_model_callback`, `after_model_callback`) and agent-level (`before_agent_callback`, `after_agent_callback`) hooks to intercept messages at different stages of the execution loop.
|
|
||||||
|
|
||||||
For more information, see the [ADK Callbacks documentation](https://google.github.io/adk-docs/callbacks/types-of-callbacks/).
|
|
||||||
{{% /tab %}}
|
{{% /tab %}}
|
||||||
{{% tab header="Langchain" text=true %}}
|
{{% tab header="Langchain" text=true %}}
|
||||||
The following example demonstrates how to use `ToolboxClient` with LangChain's middleware to implement pre- and post- processing for tool calls.
|
The following example demonstrates how to use `ToolboxClient` with LangChain's middleware to implement pre- and post- processing for tool calls.
|
||||||
|
|||||||
@@ -1,137 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import Any, Dict, Optional
|
|
||||||
from copy import deepcopy
|
|
||||||
|
|
||||||
from google.adk import Agent
|
|
||||||
from google.adk.apps import App
|
|
||||||
from google.adk.runners import Runner
|
|
||||||
from google.adk.sessions.in_memory_session_service import InMemorySessionService
|
|
||||||
from google.adk.tools.tool_context import ToolContext
|
|
||||||
from google.genai import types
|
|
||||||
from toolbox_adk import CredentialStrategy, ToolboxToolset, ToolboxTool
|
|
||||||
|
|
||||||
SYSTEM_PROMPT = """
|
|
||||||
You're a helpful hotel assistant. You handle hotel searching, booking and
|
|
||||||
cancellations. When the user searches for a hotel, mention it's name, id,
|
|
||||||
location and price tier. Always mention hotel ids while performing any
|
|
||||||
searches. This is very important for any operations. For any bookings or
|
|
||||||
cancellations, please provide the appropriate confirmation. Be sure to
|
|
||||||
update checkin or checkout dates if mentioned by the user.
|
|
||||||
Don't ask for confirmations from the user.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
# Pre processing
|
|
||||||
async def before_tool_callback(
|
|
||||||
tool: ToolboxTool, args: Dict[str, Any], tool_context: ToolContext
|
|
||||||
) -> Optional[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
Callback fired before a tool is executed.
|
|
||||||
Enforces business logic: Max stay duration is 14 days.
|
|
||||||
"""
|
|
||||||
tool_name = tool.name
|
|
||||||
print(f"POLICY CHECK: Intercepting '{tool_name}'")
|
|
||||||
|
|
||||||
if tool_name == "update-hotel" and "checkin_date" in args and "checkout_date" in args:
|
|
||||||
start = datetime.fromisoformat(args["checkin_date"])
|
|
||||||
end = datetime.fromisoformat(args["checkout_date"])
|
|
||||||
duration = (end - start).days
|
|
||||||
|
|
||||||
if duration > 14:
|
|
||||||
print("BLOCKED: Stay too long")
|
|
||||||
return {"result": "Error: Maximum stay duration is 14 days."}
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
# Post processing
|
|
||||||
async def after_tool_callback(
|
|
||||||
tool: ToolboxTool,
|
|
||||||
args: Dict[str, Any],
|
|
||||||
tool_context: ToolContext,
|
|
||||||
tool_response: Any,
|
|
||||||
) -> Optional[Any]:
|
|
||||||
"""
|
|
||||||
Callback fired after a tool execution.
|
|
||||||
Enriches response for successful bookings.
|
|
||||||
"""
|
|
||||||
if isinstance(tool_response, dict):
|
|
||||||
result = tool_response.get("result", "")
|
|
||||||
elif isinstance(tool_response, str):
|
|
||||||
result = tool_response
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
tool_name = tool.name
|
|
||||||
if isinstance(result, str) and "Error" not in result:
|
|
||||||
if tool_name == "book-hotel":
|
|
||||||
loyalty_bonus = 500
|
|
||||||
enriched_result = f"Booking Confirmed!\n You earned {loyalty_bonus} Loyalty Points with this stay.\n\nSystem Details: {result}"
|
|
||||||
|
|
||||||
if isinstance(tool_response, dict):
|
|
||||||
modified_response = deepcopy(tool_response)
|
|
||||||
modified_response["result"] = enriched_result
|
|
||||||
return modified_response
|
|
||||||
else:
|
|
||||||
return enriched_result
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def run_chat_turn(
|
|
||||||
runner: Runner, session_id: str, user_id: str, message_text: str
|
|
||||||
):
|
|
||||||
"""Executes a single chat turn and prints the interaction."""
|
|
||||||
print(f"\nUSER: '{message_text}'")
|
|
||||||
response_text = ""
|
|
||||||
async for event in runner.run_async(
|
|
||||||
user_id=user_id,
|
|
||||||
session_id=session_id,
|
|
||||||
new_message=types.Content(role="user", parts=[types.Part(text=message_text)]),
|
|
||||||
):
|
|
||||||
if event.content and event.content.parts:
|
|
||||||
for part in event.content.parts:
|
|
||||||
if part.text:
|
|
||||||
response_text += part.text
|
|
||||||
|
|
||||||
print(f"AI: {response_text}")
|
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
|
||||||
toolset = ToolboxToolset(
|
|
||||||
server_url="http://127.0.0.1:5000",
|
|
||||||
toolset_name="my-toolset",
|
|
||||||
credentials=CredentialStrategy.toolbox_identity(),
|
|
||||||
)
|
|
||||||
tools = await toolset.get_tools()
|
|
||||||
root_agent = Agent(
|
|
||||||
name="root_agent",
|
|
||||||
model="gemini-2.5-flash",
|
|
||||||
instruction=SYSTEM_PROMPT,
|
|
||||||
tools=tools,
|
|
||||||
# add any pre and post processing callbacks
|
|
||||||
before_tool_callback=before_tool_callback,
|
|
||||||
after_tool_callback=after_tool_callback,
|
|
||||||
)
|
|
||||||
app = App(root_agent=root_agent, name="my_agent")
|
|
||||||
runner = Runner(app=app, session_service=InMemorySessionService())
|
|
||||||
session_id = "test-session"
|
|
||||||
user_id = "test-user"
|
|
||||||
await runner.session_service.create_session(
|
|
||||||
app_name=app.name, user_id=user_id, session_id=session_id
|
|
||||||
)
|
|
||||||
|
|
||||||
# First turn: Successful booking
|
|
||||||
await run_chat_turn(runner, session_id, user_id, "Book hotel with id 3.")
|
|
||||||
print("-" * 50)
|
|
||||||
# Second turn: Policy violation (stay > 14 days)
|
|
||||||
await run_chat_turn(
|
|
||||||
runner,
|
|
||||||
session_id,
|
|
||||||
user_id,
|
|
||||||
"Book a hotel with id 5 with checkin date 2025-01-18 and checkout date 2025-02-10",
|
|
||||||
)
|
|
||||||
await toolset.close()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
asyncio.run(main())
|
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
google-adk[toolbox]==1.23.0
|
|
||||||
toolbox-adk==0.5.8
|
|
||||||
google-genai==1.62.0
|
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "mcp-toolbox-for-databases",
|
"name": "mcp-toolbox-for-databases",
|
||||||
"version": "0.26.0",
|
"version": "0.27.0",
|
||||||
"description": "MCP Toolbox for Databases is an open-source MCP server for more than 30 different datasources.",
|
"description": "MCP Toolbox for Databases is an open-source MCP server for more than 30 different datasources.",
|
||||||
"contextFileName": "MCP-TOOLBOX-EXTENSION.md"
|
"contextFileName": "MCP-TOOLBOX-EXTENSION.md"
|
||||||
}
|
}
|
||||||
1
go.mod
1
go.mod
@@ -21,6 +21,7 @@ require (
|
|||||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/trace v1.30.0
|
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/trace v1.30.0
|
||||||
github.com/apache/cassandra-gocql-driver/v2 v2.0.0
|
github.com/apache/cassandra-gocql-driver/v2 v2.0.0
|
||||||
github.com/cenkalti/backoff/v5 v5.0.3
|
github.com/cenkalti/backoff/v5 v5.0.3
|
||||||
|
github.com/cockroachdb/cockroach-go/v2 v2.4.2
|
||||||
github.com/couchbase/gocb/v2 v2.11.1
|
github.com/couchbase/gocb/v2 v2.11.1
|
||||||
github.com/couchbase/tools-common/http v1.0.9
|
github.com/couchbase/tools-common/http v1.0.9
|
||||||
github.com/elastic/elastic-transport-go/v8 v8.8.0
|
github.com/elastic/elastic-transport-go/v8 v8.8.0
|
||||||
|
|||||||
4
go.sum
4
go.sum
@@ -800,6 +800,8 @@ github.com/cncf/xds/go v0.0.0-20230105202645-06c439db220b/go.mod h1:eXthEFrGJvWH
|
|||||||
github.com/cncf/xds/go v0.0.0-20230607035331-e9ce68804cb4/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
|
github.com/cncf/xds/go v0.0.0-20230607035331-e9ce68804cb4/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
|
||||||
github.com/cncf/xds/go v0.0.0-20250501225837-2ac532fd4443 h1:aQ3y1lwWyqYPiWZThqv1aFbZMiM9vblcSArJRf2Irls=
|
github.com/cncf/xds/go v0.0.0-20250501225837-2ac532fd4443 h1:aQ3y1lwWyqYPiWZThqv1aFbZMiM9vblcSArJRf2Irls=
|
||||||
github.com/cncf/xds/go v0.0.0-20250501225837-2ac532fd4443/go.mod h1:W+zGtBO5Y1IgJhy4+A9GOqVhqLpfZi+vwmdNXUehLA8=
|
github.com/cncf/xds/go v0.0.0-20250501225837-2ac532fd4443/go.mod h1:W+zGtBO5Y1IgJhy4+A9GOqVhqLpfZi+vwmdNXUehLA8=
|
||||||
|
github.com/cockroachdb/cockroach-go/v2 v2.4.2 h1:QB0ozDWQUUJ0GP8Zw63X/qHefPTCpLvtfCs6TLrPgyE=
|
||||||
|
github.com/cockroachdb/cockroach-go/v2 v2.4.2/go.mod h1:9U179XbCx4qFWtNhc7BiWLPfuyMVQ7qdAhfrwLz1vH0=
|
||||||
github.com/containerd/continuity v0.4.5 h1:ZRoN1sXq9u7V6QoHMcVWGhOwDFqZ4B9i5H6un1Wh0x4=
|
github.com/containerd/continuity v0.4.5 h1:ZRoN1sXq9u7V6QoHMcVWGhOwDFqZ4B9i5H6un1Wh0x4=
|
||||||
github.com/containerd/continuity v0.4.5/go.mod h1:/lNJvtJKUQStBzpVQ1+rasXO1LAWtUQssk28EZvJ3nE=
|
github.com/containerd/continuity v0.4.5/go.mod h1:/lNJvtJKUQStBzpVQ1+rasXO1LAWtUQssk28EZvJ3nE=
|
||||||
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
|
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
|
||||||
@@ -960,6 +962,8 @@ github.com/godror/godror v0.49.6 h1:ts4ZGw8uLJ42e1D7aXmVuSrld0/lzUzmIUjuUuQOgGM=
|
|||||||
github.com/godror/godror v0.49.6/go.mod h1:kTMcxZzRw73RT5kn9v3JkBK4kHI6dqowHotqV72ebU8=
|
github.com/godror/godror v0.49.6/go.mod h1:kTMcxZzRw73RT5kn9v3JkBK4kHI6dqowHotqV72ebU8=
|
||||||
github.com/godror/knownpb v0.3.0 h1:+caUdy8hTtl7X05aPl3tdL540TvCcaQA6woZQroLZMw=
|
github.com/godror/knownpb v0.3.0 h1:+caUdy8hTtl7X05aPl3tdL540TvCcaQA6woZQroLZMw=
|
||||||
github.com/godror/knownpb v0.3.0/go.mod h1:PpTyfJwiOEAzQl7NtVCM8kdPCnp3uhxsZYIzZ5PV4zU=
|
github.com/godror/knownpb v0.3.0/go.mod h1:PpTyfJwiOEAzQl7NtVCM8kdPCnp3uhxsZYIzZ5PV4zU=
|
||||||
|
github.com/gofrs/flock v0.12.1 h1:MTLVXXHf8ekldpJk3AKicLij9MdwOWkZ+a/jHHZby9E=
|
||||||
|
github.com/gofrs/flock v0.12.1/go.mod h1:9zxTsyu5xtJ9DK+1tFZyibEV7y3uwDxPPfbxeeHCoD0=
|
||||||
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
||||||
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
|
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
|
||||||
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||||
|
|||||||
@@ -47,6 +47,7 @@ var expectedToolSources = []string{
|
|||||||
"mysql",
|
"mysql",
|
||||||
"neo4j",
|
"neo4j",
|
||||||
"oceanbase",
|
"oceanbase",
|
||||||
|
"oracledb",
|
||||||
"postgres",
|
"postgres",
|
||||||
"serverless-spark",
|
"serverless-spark",
|
||||||
"singlestore",
|
"singlestore",
|
||||||
@@ -131,6 +132,8 @@ func TestGetPrebuiltTool(t *testing.T) {
|
|||||||
neo4jconfig := getOrFatal(t, "neo4j")
|
neo4jconfig := getOrFatal(t, "neo4j")
|
||||||
healthcare_config := getOrFatal(t, "cloud-healthcare")
|
healthcare_config := getOrFatal(t, "cloud-healthcare")
|
||||||
snowflake_config := getOrFatal(t, "snowflake")
|
snowflake_config := getOrFatal(t, "snowflake")
|
||||||
|
oracle_config := getOrFatal(t,"oracledb")
|
||||||
|
|
||||||
if len(alloydb_omni_config) <= 0 {
|
if len(alloydb_omni_config) <= 0 {
|
||||||
t.Fatalf("unexpected error: could not fetch alloydb omni prebuilt tools yaml")
|
t.Fatalf("unexpected error: could not fetch alloydb omni prebuilt tools yaml")
|
||||||
}
|
}
|
||||||
@@ -230,6 +233,10 @@ func TestGetPrebuiltTool(t *testing.T) {
|
|||||||
if len(snowflake_config) <= 0 {
|
if len(snowflake_config) <= 0 {
|
||||||
t.Fatalf("unexpected error: could not fetch snowflake prebuilt tools yaml")
|
t.Fatalf("unexpected error: could not fetch snowflake prebuilt tools yaml")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(oracle_config) <= 0 {
|
||||||
|
t.Fatalf("unexpected error: could not fetch oracle prebuilt tools yaml")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFailGetPrebuiltTool(t *testing.T) {
|
func TestFailGetPrebuiltTool(t *testing.T) {
|
||||||
|
|||||||
121
internal/prebuiltconfigs/tools/oracledb.yaml
Normal file
121
internal/prebuiltconfigs/tools/oracledb.yaml
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
# Copyright 2026 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
sources:
|
||||||
|
oracle-source:
|
||||||
|
kind: "oracle"
|
||||||
|
connectionString: ${ORACLE_CONNECTION_STRING}
|
||||||
|
walletLocation: ${ORACLE_WALLET:}
|
||||||
|
user: ${ORACLE_USER}
|
||||||
|
password: ${ORACLE_PASSWORD}
|
||||||
|
useOCI: ${ORACLE_USE_OCI:false}
|
||||||
|
|
||||||
|
tools:
|
||||||
|
|
||||||
|
list_tables:
|
||||||
|
kind: oracle-sql
|
||||||
|
source: oracle-source
|
||||||
|
description: "Lists all user tables in the connected schema, including segment size, row count, and last analyzed date. Filters by a comma-separated list of names. If names are omitted, lists all tables in the current user's schema."
|
||||||
|
statement: SELECT table_name from user_tables;
|
||||||
|
|
||||||
|
list_active_sessions:
|
||||||
|
kind: oracle-sql
|
||||||
|
source: oracle-source
|
||||||
|
description: "List the top N (default 50) currently running database sessions (STATUS='ACTIVE'), showing SID, OS User, Program, and the current SQL statement text."
|
||||||
|
statement: SELECT
|
||||||
|
s.sid,
|
||||||
|
s.serial#,
|
||||||
|
s.username,
|
||||||
|
s.osuser,
|
||||||
|
s.program,
|
||||||
|
s.status,
|
||||||
|
s.wait_class,
|
||||||
|
s.event,
|
||||||
|
sql.sql_text
|
||||||
|
FROM
|
||||||
|
v$session s,
|
||||||
|
v$sql sql
|
||||||
|
WHERE
|
||||||
|
s.status = 'ACTIVE'
|
||||||
|
AND s.sql_id = sql.sql_id (+)
|
||||||
|
AND s.audsid != userenv('sessionid') -- Exclude current session
|
||||||
|
ORDER BY s.last_call_et DESC
|
||||||
|
FETCH FIRST COALESCE(10) ROWS ONLY;
|
||||||
|
|
||||||
|
get_query_plan:
|
||||||
|
kind: oracle-sql
|
||||||
|
source: oracle-source
|
||||||
|
description: "Generate a full execution plan for a single SQL statement. This can be used to analyze query performance without execution. Requires the SQL statement as input. following is an example EXPLAIN PLAN FOR {{&query}};"
|
||||||
|
statement: SELECT PLAN_TABLE_OUTPUT FROM TABLE(DBMS_XPLAN.DISPLAY());
|
||||||
|
|
||||||
|
list_top_sql_by_resource:
|
||||||
|
kind: oracle-sql
|
||||||
|
source: oracle-source
|
||||||
|
description: "List the top N SQL statements from the library cache based on a chosen resource metric (CPU, I/O, or Elapsed Time), following is an example of the sql"
|
||||||
|
statement: SELECT
|
||||||
|
sql_id,
|
||||||
|
executions,
|
||||||
|
buffer_gets,
|
||||||
|
disk_reads,
|
||||||
|
cpu_time / 1000000 AS cpu_seconds,
|
||||||
|
elapsed_time / 1000000 AS elapsed_seconds
|
||||||
|
FROM
|
||||||
|
v$sql
|
||||||
|
FETCH FIRST 5 ROWS ONLY;
|
||||||
|
|
||||||
|
list_tablespace_usage:
|
||||||
|
kind: oracle-sql
|
||||||
|
source: oracle-source
|
||||||
|
description: "List tablespace names, total size, free space, and used percentage to monitor storage utilization."
|
||||||
|
statement: SELECT
|
||||||
|
t.tablespace_name,
|
||||||
|
TO_CHAR(t.total_bytes / 1024 / 1024, '99,999.00') AS total_mb,
|
||||||
|
TO_CHAR(SUM(d.bytes) / 1024 / 1024, '99,999.00') AS free_mb,
|
||||||
|
TO_CHAR((t.total_bytes - SUM(d.bytes)) / t.total_bytes * 100, '99.00') AS used_pct
|
||||||
|
FROM
|
||||||
|
(SELECT tablespace_name, SUM(bytes) AS total_bytes FROM dba_data_files GROUP BY tablespace_name) t,
|
||||||
|
dba_free_space d
|
||||||
|
WHERE
|
||||||
|
t.tablespace_name = d.tablespace_name (+)
|
||||||
|
GROUP BY
|
||||||
|
t.tablespace_name, t.total_bytes
|
||||||
|
ORDER BY
|
||||||
|
used_pct DESC;
|
||||||
|
|
||||||
|
list_invalid_objects:
|
||||||
|
kind: oracle-sql
|
||||||
|
source: oracle-source
|
||||||
|
description: "Lists all database objects that are in an invalid state, requiring recompilation (e.g., procedures, functions, views)."
|
||||||
|
statement: SELECT
|
||||||
|
owner,
|
||||||
|
object_type,
|
||||||
|
object_name,
|
||||||
|
status
|
||||||
|
FROM
|
||||||
|
dba_objects
|
||||||
|
WHERE
|
||||||
|
status = 'INVALID'
|
||||||
|
AND owner NOT IN ('SYS', 'SYSTEM') -- Exclude system schemas for clarity
|
||||||
|
ORDER BY
|
||||||
|
owner, object_type, object_name;
|
||||||
|
|
||||||
|
toolsets:
|
||||||
|
oracle_database_tools:
|
||||||
|
- execute_sql
|
||||||
|
- list_tables
|
||||||
|
- list_active_sessions
|
||||||
|
- get_query_plan
|
||||||
|
- list_top_sql_by_resource
|
||||||
|
- list_tablespace_usage
|
||||||
|
- list_invalid_objects
|
||||||
@@ -19,7 +19,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
"github.com/go-chi/chi/v5/middleware"
|
"github.com/go-chi/chi/v5/middleware"
|
||||||
@@ -216,7 +215,7 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
|||||||
// Check if any of the specified auth services is verified
|
// Check if any of the specified auth services is verified
|
||||||
isAuthorized := tool.Authorized(verifiedAuthServices)
|
isAuthorized := tool.Authorized(verifiedAuthServices)
|
||||||
if !isAuthorized {
|
if !isAuthorized {
|
||||||
err = fmt.Errorf("tool invocation not authorized. Please make sure your specify correct auth headers")
|
err = fmt.Errorf("tool invocation not authorized. Please make sure you specify correct auth headers")
|
||||||
s.logger.DebugContext(ctx, err.Error())
|
s.logger.DebugContext(ctx, err.Error())
|
||||||
_ = render.Render(w, r, newErrResponse(err, http.StatusUnauthorized))
|
_ = render.Render(w, r, newErrResponse(err, http.StatusUnauthorized))
|
||||||
return
|
return
|
||||||
@@ -234,15 +233,28 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
params, err := parameters.ParseParams(tool.GetParameters(), data, claimsFromAuth)
|
params, err := parameters.ParseParams(tool.GetParameters(), data, claimsFromAuth)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// If auth error, return 401
|
var clientServerErr *util.ClientServerError
|
||||||
if errors.Is(err, util.ErrUnauthorized) {
|
|
||||||
s.logger.DebugContext(ctx, fmt.Sprintf("error parsing authenticated parameters from ID token: %s", err))
|
// Return 401 Authentication errors
|
||||||
|
if errors.As(err, &clientServerErr) && clientServerErr.Code == http.StatusUnauthorized {
|
||||||
|
s.logger.DebugContext(ctx, fmt.Sprintf("auth error: %v", err))
|
||||||
_ = render.Render(w, r, newErrResponse(err, http.StatusUnauthorized))
|
_ = render.Render(w, r, newErrResponse(err, http.StatusUnauthorized))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = fmt.Errorf("provided parameters were invalid: %w", err)
|
|
||||||
s.logger.DebugContext(ctx, err.Error())
|
var agentErr *util.AgentError
|
||||||
_ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest))
|
if errors.As(err, &agentErr) {
|
||||||
|
s.logger.DebugContext(ctx, fmt.Sprintf("agent validation error: %v", err))
|
||||||
|
errMap := map[string]string{"error": err.Error()}
|
||||||
|
errMarshal, _ := json.Marshal(errMap)
|
||||||
|
|
||||||
|
_ = render.Render(w, r, &resultResponse{Result: string(errMarshal)})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return 500 if it's a specific ClientServerError that isn't a 401, or any other unexpected error
|
||||||
|
s.logger.ErrorContext(ctx, fmt.Sprintf("internal server error: %v", err))
|
||||||
|
_ = render.Render(w, r, newErrResponse(err, http.StatusInternalServerError))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
s.logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
|
s.logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
|
||||||
@@ -259,35 +271,51 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
// Determine what error to return to the users.
|
// Determine what error to return to the users.
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errStr := err.Error()
|
var tbErr util.ToolboxError
|
||||||
var statusCode int
|
|
||||||
|
|
||||||
// Upstream API auth error propagation
|
if errors.As(err, &tbErr) {
|
||||||
switch {
|
switch tbErr.Category() {
|
||||||
case strings.Contains(errStr, "Error 401"):
|
case util.CategoryAgent:
|
||||||
statusCode = http.StatusUnauthorized
|
// Agent Errors -> 200 OK
|
||||||
case strings.Contains(errStr, "Error 403"):
|
s.logger.DebugContext(ctx, fmt.Sprintf("Tool invocation agent error: %v", err))
|
||||||
statusCode = http.StatusForbidden
|
res = map[string]string{
|
||||||
|
"error": err.Error(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case util.CategoryServer:
|
||||||
|
// Server Errors -> Check the specific code inside
|
||||||
|
var clientServerErr *util.ClientServerError
|
||||||
|
statusCode := http.StatusInternalServerError // Default to 500
|
||||||
|
|
||||||
|
if errors.As(err, &clientServerErr) {
|
||||||
|
if clientServerErr.Code != 0 {
|
||||||
|
statusCode = clientServerErr.Code
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process auth error
|
||||||
if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden {
|
if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden {
|
||||||
if clientAuth {
|
if clientAuth {
|
||||||
// Propagate the original 401/403 error.
|
// Token error, pass through 401/403
|
||||||
s.logger.DebugContext(ctx, fmt.Sprintf("error invoking tool. Client credentials lack authorization to the source: %v", err))
|
s.logger.DebugContext(ctx, fmt.Sprintf("Client credentials lack authorization: %v", err))
|
||||||
_ = render.Render(w, r, newErrResponse(err, statusCode))
|
_ = render.Render(w, r, newErrResponse(err, statusCode))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// ADC lacking permission or credentials configuration error.
|
// ADC/Config error, return 500
|
||||||
internalErr := fmt.Errorf("unexpected auth error occured during Tool invocation: %w", err)
|
statusCode = http.StatusInternalServerError
|
||||||
s.logger.ErrorContext(ctx, internalErr.Error())
|
}
|
||||||
_ = render.Render(w, r, newErrResponse(internalErr, http.StatusInternalServerError))
|
|
||||||
|
s.logger.ErrorContext(ctx, fmt.Sprintf("Tool invocation server error: %v", err))
|
||||||
|
_ = render.Render(w, r, newErrResponse(err, statusCode))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = fmt.Errorf("error while invoking tool: %w", err)
|
} else {
|
||||||
s.logger.DebugContext(ctx, err.Error())
|
// Unknown error -> 500
|
||||||
_ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest))
|
s.logger.ErrorContext(ctx, fmt.Sprintf("Tool invocation unknown error: %v", err))
|
||||||
|
_ = render.Render(w, r, newErrResponse(err, http.StatusInternalServerError))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
resMarshal, err := json.Marshal(res)
|
resMarshal, err := json.Marshal(res)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -37,9 +36,11 @@ import (
|
|||||||
v20241105 "github.com/googleapis/genai-toolbox/internal/server/mcp/v20241105"
|
v20241105 "github.com/googleapis/genai-toolbox/internal/server/mcp/v20241105"
|
||||||
v20250326 "github.com/googleapis/genai-toolbox/internal/server/mcp/v20250326"
|
v20250326 "github.com/googleapis/genai-toolbox/internal/server/mcp/v20250326"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util"
|
"github.com/googleapis/genai-toolbox/internal/util"
|
||||||
|
"go.opentelemetry.io/otel"
|
||||||
"go.opentelemetry.io/otel/attribute"
|
"go.opentelemetry.io/otel/attribute"
|
||||||
"go.opentelemetry.io/otel/codes"
|
"go.opentelemetry.io/otel/codes"
|
||||||
"go.opentelemetry.io/otel/metric"
|
"go.opentelemetry.io/otel/metric"
|
||||||
|
"go.opentelemetry.io/otel/trace"
|
||||||
)
|
)
|
||||||
|
|
||||||
type sseSession struct {
|
type sseSession struct {
|
||||||
@@ -117,6 +118,55 @@ type stdioSession struct {
|
|||||||
writer io.Writer
|
writer io.Writer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// traceContextCarrier implements propagation.TextMapCarrier for extracting trace context from _meta
|
||||||
|
type traceContextCarrier map[string]string
|
||||||
|
|
||||||
|
func (c traceContextCarrier) Get(key string) string {
|
||||||
|
return c[key]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c traceContextCarrier) Set(key, value string) {
|
||||||
|
c[key] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c traceContextCarrier) Keys() []string {
|
||||||
|
keys := make([]string, 0, len(c))
|
||||||
|
for k := range c {
|
||||||
|
keys = append(keys, k)
|
||||||
|
}
|
||||||
|
return keys
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractTraceContext extracts W3C Trace Context from params._meta
|
||||||
|
func extractTraceContext(ctx context.Context, body []byte) context.Context {
|
||||||
|
// Try to parse the request to extract _meta
|
||||||
|
var req struct {
|
||||||
|
Params struct {
|
||||||
|
Meta struct {
|
||||||
|
Traceparent string `json:"traceparent,omitempty"`
|
||||||
|
Tracestate string `json:"tracestate,omitempty"`
|
||||||
|
} `json:"_meta,omitempty"`
|
||||||
|
} `json:"params,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(body, &req); err != nil {
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
|
||||||
|
// If traceparent is present, extract the context
|
||||||
|
if req.Params.Meta.Traceparent != "" {
|
||||||
|
carrier := traceContextCarrier{
|
||||||
|
"traceparent": req.Params.Meta.Traceparent,
|
||||||
|
}
|
||||||
|
if req.Params.Meta.Tracestate != "" {
|
||||||
|
carrier["tracestate"] = req.Params.Meta.Tracestate
|
||||||
|
}
|
||||||
|
return otel.GetTextMapPropagator().Extract(ctx, carrier)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
|
||||||
func NewStdioSession(s *Server, stdin io.Reader, stdout io.Writer) *stdioSession {
|
func NewStdioSession(s *Server, stdin io.Reader, stdout io.Writer) *stdioSession {
|
||||||
stdioSession := &stdioSession{
|
stdioSession := &stdioSession{
|
||||||
server: s,
|
server: s,
|
||||||
@@ -143,18 +193,29 @@ func (s *stdioSession) readInputStream(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
v, res, err := processMcpMessage(ctx, []byte(line), s.server, s.protocol, "", "", nil)
|
// This ensures the transport span becomes a child of the client span
|
||||||
|
msgCtx := extractTraceContext(ctx, []byte(line))
|
||||||
|
|
||||||
|
// Create span for STDIO transport
|
||||||
|
msgCtx, span := s.server.instrumentation.Tracer.Start(msgCtx, "toolbox/server/mcp/stdio",
|
||||||
|
trace.WithSpanKind(trace.SpanKindServer),
|
||||||
|
)
|
||||||
|
defer span.End()
|
||||||
|
|
||||||
|
v, res, err := processMcpMessage(msgCtx, []byte(line), s.server, s.protocol, "", "", nil, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// errors during the processing of message will generate a valid MCP Error response.
|
// errors during the processing of message will generate a valid MCP Error response.
|
||||||
// server can continue to run.
|
// server can continue to run.
|
||||||
s.server.logger.ErrorContext(ctx, err.Error())
|
s.server.logger.ErrorContext(msgCtx, err.Error())
|
||||||
|
span.SetStatus(codes.Error, err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
if v != "" {
|
if v != "" {
|
||||||
s.protocol = v
|
s.protocol = v
|
||||||
}
|
}
|
||||||
// no responses for notifications
|
// no responses for notifications
|
||||||
if res != nil {
|
if res != nil {
|
||||||
if err = s.write(ctx, res); err != nil {
|
if err = s.write(msgCtx, res); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -240,7 +301,9 @@ func mcpRouter(s *Server) (chi.Router, error) {
|
|||||||
|
|
||||||
// sseHandler handles sse initialization and message.
|
// sseHandler handles sse initialization and message.
|
||||||
func sseHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
func sseHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||||
ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/mcp/sse")
|
ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/mcp/sse",
|
||||||
|
trace.WithSpanKind(trace.SpanKindServer),
|
||||||
|
)
|
||||||
r = r.WithContext(ctx)
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
sessionId := uuid.New().String()
|
sessionId := uuid.New().String()
|
||||||
@@ -336,9 +399,27 @@ func methodNotAllowed(s *Server, w http.ResponseWriter, r *http.Request) {
|
|||||||
func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
|
||||||
ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/mcp")
|
ctx := r.Context()
|
||||||
|
ctx = util.WithLogger(ctx, s.logger)
|
||||||
|
|
||||||
|
// Read body first so we can extract trace context
|
||||||
|
body, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
// Generate a new uuid if unable to decode
|
||||||
|
id := uuid.New().String()
|
||||||
|
s.logger.DebugContext(ctx, err.Error())
|
||||||
|
render.JSON(w, r, jsonrpc.NewError(id, jsonrpc.PARSE_ERROR, err.Error(), nil))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// This ensures the transport span becomes a child of the client span
|
||||||
|
ctx = extractTraceContext(ctx, body)
|
||||||
|
|
||||||
|
// Create span for HTTP transport
|
||||||
|
ctx, span := s.instrumentation.Tracer.Start(ctx, "toolbox/server/mcp/http",
|
||||||
|
trace.WithSpanKind(trace.SpanKindServer),
|
||||||
|
)
|
||||||
r = r.WithContext(ctx)
|
r = r.WithContext(ctx)
|
||||||
ctx = util.WithLogger(r.Context(), s.logger)
|
|
||||||
|
|
||||||
var sessionId, protocolVersion string
|
var sessionId, protocolVersion string
|
||||||
var session *sseSession
|
var session *sseSession
|
||||||
@@ -380,7 +461,6 @@ func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
|||||||
s.logger.DebugContext(ctx, fmt.Sprintf("toolset name: %s", toolsetName))
|
s.logger.DebugContext(ctx, fmt.Sprintf("toolset name: %s", toolsetName))
|
||||||
span.SetAttributes(attribute.String("toolset_name", toolsetName))
|
span.SetAttributes(attribute.String("toolset_name", toolsetName))
|
||||||
|
|
||||||
var err error
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
span.SetStatus(codes.Error, err.Error())
|
span.SetStatus(codes.Error, err.Error())
|
||||||
@@ -399,17 +479,9 @@ func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
|||||||
)
|
)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Read and returns a body from io.Reader
|
networkProtocolVersion := fmt.Sprintf("%d.%d", r.ProtoMajor, r.ProtoMinor)
|
||||||
body, err := io.ReadAll(r.Body)
|
|
||||||
if err != nil {
|
|
||||||
// Generate a new uuid if unable to decode
|
|
||||||
id := uuid.New().String()
|
|
||||||
s.logger.DebugContext(ctx, err.Error())
|
|
||||||
render.JSON(w, r, jsonrpc.NewError(id, jsonrpc.PARSE_ERROR, err.Error(), nil))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
v, res, err := processMcpMessage(ctx, body, s, protocolVersion, toolsetName, promptsetName, r.Header)
|
v, res, err := processMcpMessage(ctx, body, s, protocolVersion, toolsetName, promptsetName, r.Header, networkProtocolVersion)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.DebugContext(ctx, fmt.Errorf("error processing message: %w", err).Error())
|
s.logger.DebugContext(ctx, fmt.Errorf("error processing message: %w", err).Error())
|
||||||
}
|
}
|
||||||
@@ -444,15 +516,12 @@ func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
|||||||
code := rpcResponse.Error.Code
|
code := rpcResponse.Error.Code
|
||||||
switch code {
|
switch code {
|
||||||
case jsonrpc.INTERNAL_ERROR:
|
case jsonrpc.INTERNAL_ERROR:
|
||||||
|
// Map Internal RPC Error (-32603) to HTTP 500
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
case jsonrpc.INVALID_REQUEST:
|
case jsonrpc.INVALID_REQUEST:
|
||||||
errStr := err.Error()
|
var clientServerErr *util.ClientServerError
|
||||||
if errors.Is(err, util.ErrUnauthorized) {
|
if errors.As(err, &clientServerErr) {
|
||||||
w.WriteHeader(http.StatusUnauthorized)
|
w.WriteHeader(clientServerErr.Code)
|
||||||
} else if strings.Contains(errStr, "Error 401") {
|
|
||||||
w.WriteHeader(http.StatusUnauthorized)
|
|
||||||
} else if strings.Contains(errStr, "Error 403") {
|
|
||||||
w.WriteHeader(http.StatusForbidden)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -462,7 +531,7 @@ func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// processMcpMessage process the messages received from clients
|
// processMcpMessage process the messages received from clients
|
||||||
func processMcpMessage(ctx context.Context, body []byte, s *Server, protocolVersion string, toolsetName string, promptsetName string, header http.Header) (string, any, error) {
|
func processMcpMessage(ctx context.Context, body []byte, s *Server, protocolVersion string, toolsetName string, promptsetName string, header http.Header, networkProtocolVersion string) (string, any, error) {
|
||||||
logger, err := util.LoggerFromContext(ctx)
|
logger, err := util.LoggerFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", jsonrpc.NewError("", jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
return "", jsonrpc.NewError("", jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||||
@@ -498,31 +567,95 @@ func processMcpMessage(ctx context.Context, body []byte, s *Server, protocolVers
|
|||||||
return "", jsonrpc.NewError(baseMessage.Id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
return "", jsonrpc.NewError(baseMessage.Id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Create method-specific span with semantic conventions
|
||||||
|
// Note: Trace context is already extracted and set in ctx by the caller
|
||||||
|
ctx, span := s.instrumentation.Tracer.Start(ctx, baseMessage.Method,
|
||||||
|
trace.WithSpanKind(trace.SpanKindServer),
|
||||||
|
)
|
||||||
|
defer span.End()
|
||||||
|
|
||||||
|
// Determine network transport and protocol based on header presence
|
||||||
|
networkTransport := "pipe" // default for stdio
|
||||||
|
networkProtocolName := "stdio"
|
||||||
|
if header != nil {
|
||||||
|
networkTransport = "tcp" // HTTP/SSE transport
|
||||||
|
networkProtocolName = "http"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set required semantic attributes for span according to OTEL MCP semcov
|
||||||
|
// ref: https://opentelemetry.io/docs/specs/semconv/gen-ai/mcp/#server
|
||||||
|
span.SetAttributes(
|
||||||
|
attribute.String("mcp.method.name", baseMessage.Method),
|
||||||
|
attribute.String("network.transport", networkTransport),
|
||||||
|
attribute.String("network.protocol.name", networkProtocolName),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Set network protocol version if available
|
||||||
|
if networkProtocolVersion != "" {
|
||||||
|
span.SetAttributes(attribute.String("network.protocol.version", networkProtocolVersion))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set MCP protocol version if available
|
||||||
|
if protocolVersion != "" {
|
||||||
|
span.SetAttributes(attribute.String("mcp.protocol.version", protocolVersion))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set request ID
|
||||||
|
if baseMessage.Id != nil {
|
||||||
|
span.SetAttributes(attribute.String("jsonrpc.request.id", fmt.Sprintf("%v", baseMessage.Id)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set toolset name
|
||||||
|
span.SetAttributes(attribute.String("toolset.name", toolsetName))
|
||||||
|
|
||||||
// Check if message is a notification
|
// Check if message is a notification
|
||||||
if baseMessage.Id == nil {
|
if baseMessage.Id == nil {
|
||||||
err := mcp.NotificationHandler(ctx, body)
|
err := mcp.NotificationHandler(ctx, body)
|
||||||
|
if err != nil {
|
||||||
|
span.SetStatus(codes.Error, err.Error())
|
||||||
|
}
|
||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Process the method
|
||||||
switch baseMessage.Method {
|
switch baseMessage.Method {
|
||||||
case mcputil.INITIALIZE:
|
case mcputil.INITIALIZE:
|
||||||
res, v, err := mcp.InitializeResponse(ctx, baseMessage.Id, body, s.version)
|
result, version, err := mcp.InitializeResponse(ctx, baseMessage.Id, body, s.version)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", res, err
|
span.SetStatus(codes.Error, err.Error())
|
||||||
|
if rpcErr, ok := result.(jsonrpc.JSONRPCError); ok {
|
||||||
|
span.SetAttributes(attribute.String("error.type", rpcErr.Error.String()))
|
||||||
}
|
}
|
||||||
return v, res, err
|
return "", result, err
|
||||||
|
}
|
||||||
|
span.SetAttributes(attribute.String("mcp.protocol.version", version))
|
||||||
|
return version, result, err
|
||||||
default:
|
default:
|
||||||
toolset, ok := s.ResourceMgr.GetToolset(toolsetName)
|
toolset, ok := s.ResourceMgr.GetToolset(toolsetName)
|
||||||
if !ok {
|
if !ok {
|
||||||
err = fmt.Errorf("toolset does not exist")
|
err := fmt.Errorf("toolset does not exist")
|
||||||
return "", jsonrpc.NewError(baseMessage.Id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
rpcErr := jsonrpc.NewError(baseMessage.Id, jsonrpc.INVALID_REQUEST, err.Error(), nil)
|
||||||
|
span.SetStatus(codes.Error, err.Error())
|
||||||
|
span.SetAttributes(attribute.String("error.type", rpcErr.Error.String()))
|
||||||
|
return "", rpcErr, err
|
||||||
}
|
}
|
||||||
promptset, ok := s.ResourceMgr.GetPromptset(promptsetName)
|
promptset, ok := s.ResourceMgr.GetPromptset(promptsetName)
|
||||||
if !ok {
|
if !ok {
|
||||||
err = fmt.Errorf("promptset does not exist")
|
err := fmt.Errorf("promptset does not exist")
|
||||||
return "", jsonrpc.NewError(baseMessage.Id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
rpcErr := jsonrpc.NewError(baseMessage.Id, jsonrpc.INVALID_REQUEST, err.Error(), nil)
|
||||||
|
span.SetStatus(codes.Error, err.Error())
|
||||||
|
span.SetAttributes(attribute.String("error.type", rpcErr.Error.String()))
|
||||||
|
return "", rpcErr, err
|
||||||
}
|
}
|
||||||
res, err := mcp.ProcessMethod(ctx, protocolVersion, baseMessage.Id, baseMessage.Method, toolset, promptset, s.ResourceMgr, body, header)
|
result, err := mcp.ProcessMethod(ctx, protocolVersion, baseMessage.Id, baseMessage.Method, toolset, promptset, s.ResourceMgr, body, header)
|
||||||
return "", res, err
|
if err != nil {
|
||||||
|
span.SetStatus(codes.Error, err.Error())
|
||||||
|
// Set error.type based on JSON-RPC error code
|
||||||
|
if rpcErr, ok := result.(jsonrpc.JSONRPCError); ok {
|
||||||
|
span.SetAttributes(attribute.Int("jsonrpc.error.code", rpcErr.Error.Code))
|
||||||
|
span.SetAttributes(attribute.String("error.type", rpcErr.Error.String()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", result, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -45,6 +45,9 @@ type Request struct {
|
|||||||
// notifications. The receiver is not obligated to provide these
|
// notifications. The receiver is not obligated to provide these
|
||||||
// notifications.
|
// notifications.
|
||||||
ProgressToken ProgressToken `json:"progressToken,omitempty"`
|
ProgressToken ProgressToken `json:"progressToken,omitempty"`
|
||||||
|
// W3C Trace Context fields for distributed tracing
|
||||||
|
Traceparent string `json:"traceparent,omitempty"`
|
||||||
|
Tracestate string `json:"tracestate,omitempty"`
|
||||||
} `json:"_meta,omitempty"`
|
} `json:"_meta,omitempty"`
|
||||||
} `json:"params,omitempty"`
|
} `json:"params,omitempty"`
|
||||||
}
|
}
|
||||||
@@ -97,6 +100,24 @@ type Error struct {
|
|||||||
Data interface{} `json:"data,omitempty"`
|
Data interface{} `json:"data,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// String returns the error type as a string based on the error code.
|
||||||
|
func (e Error) String() string {
|
||||||
|
switch e.Code {
|
||||||
|
case METHOD_NOT_FOUND:
|
||||||
|
return "method_not_found"
|
||||||
|
case INVALID_PARAMS:
|
||||||
|
return "invalid_params"
|
||||||
|
case INTERNAL_ERROR:
|
||||||
|
return "internal_error"
|
||||||
|
case PARSE_ERROR:
|
||||||
|
return "parse_error"
|
||||||
|
case INVALID_REQUEST:
|
||||||
|
return "invalid_request"
|
||||||
|
default:
|
||||||
|
return "jsonrpc_error"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// JSONRPCError represents a non-successful (error) response to a request.
|
// JSONRPCError represents a non-successful (error) response to a request.
|
||||||
type JSONRPCError struct {
|
type JSONRPCError struct {
|
||||||
Jsonrpc string `json:"jsonrpc"`
|
Jsonrpc string `json:"jsonrpc"`
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||||
@@ -29,6 +28,8 @@ import (
|
|||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util"
|
"github.com/googleapis/genai-toolbox/internal/util"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
|
"go.opentelemetry.io/otel/attribute"
|
||||||
|
"go.opentelemetry.io/otel/trace"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ProcessMethod returns a response for the request.
|
// ProcessMethod returns a response for the request.
|
||||||
@@ -102,6 +103,14 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
toolName := req.Params.Name
|
toolName := req.Params.Name
|
||||||
toolArgument := req.Params.Arguments
|
toolArgument := req.Params.Arguments
|
||||||
logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName))
|
logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName))
|
||||||
|
|
||||||
|
// Update span name and set gen_ai attributes
|
||||||
|
span := trace.SpanFromContext(ctx)
|
||||||
|
span.SetName(fmt.Sprintf("%s %s", TOOLS_CALL, toolName))
|
||||||
|
span.SetAttributes(
|
||||||
|
attribute.String("gen_ai.tool.name", toolName),
|
||||||
|
attribute.String("gen_ai.operation.name", "execute_tool"),
|
||||||
|
)
|
||||||
tool, ok := resourceMgr.GetTool(toolName)
|
tool, ok := resourceMgr.GetTool(toolName)
|
||||||
if !ok {
|
if !ok {
|
||||||
err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
|
err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
|
||||||
@@ -124,7 +133,12 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
}
|
}
|
||||||
if clientAuth {
|
if clientAuth {
|
||||||
if accessToken == "" {
|
if accessToken == "" {
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
|
err := util.NewClientServerError(
|
||||||
|
"missing access token in the 'Authorization' header",
|
||||||
|
http.StatusUnauthorized,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -172,7 +186,11 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
// Check if any of the specified auth services is verified
|
// Check if any of the specified auth services is verified
|
||||||
isAuthorized := tool.Authorized(verifiedAuthServices)
|
isAuthorized := tool.Authorized(verifiedAuthServices)
|
||||||
if !isAuthorized {
|
if !isAuthorized {
|
||||||
err = fmt.Errorf("unauthorized Tool call: Please make sure your specify correct auth headers: %w", util.ErrUnauthorized)
|
err = util.NewClientServerError(
|
||||||
|
"unauthorized Tool call: Please make sure you specify correct auth headers",
|
||||||
|
http.StatusUnauthorized,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
}
|
}
|
||||||
logger.DebugContext(ctx, "tool invocation authorized")
|
logger.DebugContext(ctx, "tool invocation authorized")
|
||||||
@@ -194,21 +212,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
// run tool invocation and generate response.
|
// run tool invocation and generate response.
|
||||||
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errStr := err.Error()
|
var tbErr util.ToolboxError
|
||||||
// Missing authService tokens.
|
|
||||||
if errors.Is(err, util.ErrUnauthorized) {
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
|
||||||
}
|
|
||||||
// Upstream auth error
|
|
||||||
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
|
|
||||||
if clientAuth {
|
|
||||||
// Error with client credentials should pass down to the client
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
|
||||||
}
|
|
||||||
// Auth error with ADC should raise internal 500 error
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
|
||||||
}
|
|
||||||
|
|
||||||
|
if errors.As(err, &tbErr) {
|
||||||
|
switch tbErr.Category() {
|
||||||
|
case util.CategoryAgent:
|
||||||
|
// MCP - Tool execution error
|
||||||
|
// Return SUCCESS but with IsError: true
|
||||||
text := TextContent{
|
text := TextContent{
|
||||||
Type: "text",
|
Type: "text",
|
||||||
Text: err.Error(),
|
Text: err.Error(),
|
||||||
@@ -218,6 +228,28 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
Id: id,
|
Id: id,
|
||||||
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
||||||
}, nil
|
}, nil
|
||||||
|
|
||||||
|
case util.CategoryServer:
|
||||||
|
// MCP Spec - Protocol error
|
||||||
|
// Return JSON-RPC ERROR
|
||||||
|
var clientServerErr *util.ClientServerError
|
||||||
|
rpcCode := jsonrpc.INTERNAL_ERROR // Default to Internal Error (-32603)
|
||||||
|
|
||||||
|
if errors.As(err, &clientServerErr) {
|
||||||
|
if clientServerErr.Code == http.StatusUnauthorized || clientServerErr.Code == http.StatusForbidden {
|
||||||
|
if clientAuth {
|
||||||
|
rpcCode = jsonrpc.INVALID_REQUEST
|
||||||
|
} else {
|
||||||
|
rpcCode = jsonrpc.INTERNAL_ERROR
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return jsonrpc.NewError(id, rpcCode, err.Error(), nil), err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Unknown error -> 500
|
||||||
|
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
content := make([]TextContent, 0)
|
content := make([]TextContent, 0)
|
||||||
@@ -288,6 +320,11 @@ func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *r
|
|||||||
|
|
||||||
promptName := req.Params.Name
|
promptName := req.Params.Name
|
||||||
logger.DebugContext(ctx, fmt.Sprintf("prompt name: %s", promptName))
|
logger.DebugContext(ctx, fmt.Sprintf("prompt name: %s", promptName))
|
||||||
|
|
||||||
|
// Update span name and set gen_ai attributes
|
||||||
|
span := trace.SpanFromContext(ctx)
|
||||||
|
span.SetName(fmt.Sprintf("%s %s", PROMPTS_GET, promptName))
|
||||||
|
span.SetAttributes(attribute.String("gen_ai.prompt.name", promptName))
|
||||||
prompt, ok := resourceMgr.GetPrompt(promptName)
|
prompt, ok := resourceMgr.GetPrompt(promptName)
|
||||||
if !ok {
|
if !ok {
|
||||||
err := fmt.Errorf("prompt with name %q does not exist", promptName)
|
err := fmt.Errorf("prompt with name %q does not exist", promptName)
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||||
@@ -29,6 +28,8 @@ import (
|
|||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util"
|
"github.com/googleapis/genai-toolbox/internal/util"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
|
"go.opentelemetry.io/otel/attribute"
|
||||||
|
"go.opentelemetry.io/otel/trace"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ProcessMethod returns a response for the request.
|
// ProcessMethod returns a response for the request.
|
||||||
@@ -102,6 +103,15 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
toolName := req.Params.Name
|
toolName := req.Params.Name
|
||||||
toolArgument := req.Params.Arguments
|
toolArgument := req.Params.Arguments
|
||||||
logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName))
|
logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName))
|
||||||
|
|
||||||
|
// Update span name and set gen_ai attributes
|
||||||
|
span := trace.SpanFromContext(ctx)
|
||||||
|
span.SetName(fmt.Sprintf("%s %s", TOOLS_CALL, toolName))
|
||||||
|
span.SetAttributes(
|
||||||
|
attribute.String("gen_ai.tool.name", toolName),
|
||||||
|
attribute.String("gen_ai.operation.name", "execute_tool"),
|
||||||
|
)
|
||||||
|
|
||||||
tool, ok := resourceMgr.GetTool(toolName)
|
tool, ok := resourceMgr.GetTool(toolName)
|
||||||
if !ok {
|
if !ok {
|
||||||
err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
|
err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
|
||||||
@@ -124,7 +134,12 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
}
|
}
|
||||||
if clientAuth {
|
if clientAuth {
|
||||||
if accessToken == "" {
|
if accessToken == "" {
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
|
err := util.NewClientServerError(
|
||||||
|
"missing access token in the 'Authorization' header",
|
||||||
|
http.StatusUnauthorized,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -172,7 +187,11 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
// Check if any of the specified auth services is verified
|
// Check if any of the specified auth services is verified
|
||||||
isAuthorized := tool.Authorized(verifiedAuthServices)
|
isAuthorized := tool.Authorized(verifiedAuthServices)
|
||||||
if !isAuthorized {
|
if !isAuthorized {
|
||||||
err = fmt.Errorf("unauthorized Tool call: Please make sure your specify correct auth headers: %w", util.ErrUnauthorized)
|
err = util.NewClientServerError(
|
||||||
|
"unauthorized Tool call: Please make sure you specify correct auth headers",
|
||||||
|
http.StatusUnauthorized,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
}
|
}
|
||||||
logger.DebugContext(ctx, "tool invocation authorized")
|
logger.DebugContext(ctx, "tool invocation authorized")
|
||||||
@@ -194,20 +213,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
// run tool invocation and generate response.
|
// run tool invocation and generate response.
|
||||||
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errStr := err.Error()
|
var tbErr util.ToolboxError
|
||||||
// Missing authService tokens.
|
|
||||||
if errors.Is(err, util.ErrUnauthorized) {
|
if errors.As(err, &tbErr) {
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
switch tbErr.Category() {
|
||||||
}
|
case util.CategoryAgent:
|
||||||
// Upstream auth error
|
// MCP - Tool execution error
|
||||||
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
|
// Return SUCCESS but with IsError: true
|
||||||
if clientAuth {
|
|
||||||
// Error with client credentials should pass down to the client
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
|
||||||
}
|
|
||||||
// Auth error with ADC should raise internal 500 error
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
|
||||||
}
|
|
||||||
text := TextContent{
|
text := TextContent{
|
||||||
Type: "text",
|
Type: "text",
|
||||||
Text: err.Error(),
|
Text: err.Error(),
|
||||||
@@ -217,8 +229,29 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
Id: id,
|
Id: id,
|
||||||
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
|
||||||
|
|
||||||
|
case util.CategoryServer:
|
||||||
|
// MCP Spec - Protocol error
|
||||||
|
// Return JSON-RPC ERROR
|
||||||
|
var clientServerErr *util.ClientServerError
|
||||||
|
rpcCode := jsonrpc.INTERNAL_ERROR // Default to Internal Error (-32603)
|
||||||
|
|
||||||
|
if errors.As(err, &clientServerErr) {
|
||||||
|
if clientServerErr.Code == http.StatusUnauthorized || clientServerErr.Code == http.StatusForbidden {
|
||||||
|
if clientAuth {
|
||||||
|
rpcCode = jsonrpc.INVALID_REQUEST
|
||||||
|
} else {
|
||||||
|
rpcCode = jsonrpc.INTERNAL_ERROR
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return jsonrpc.NewError(id, rpcCode, err.Error(), nil), err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Unknown error -> 500
|
||||||
|
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||||
|
}
|
||||||
|
}
|
||||||
content := make([]TextContent, 0)
|
content := make([]TextContent, 0)
|
||||||
|
|
||||||
sliceRes, ok := results.([]any)
|
sliceRes, ok := results.([]any)
|
||||||
@@ -287,6 +320,12 @@ func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *r
|
|||||||
|
|
||||||
promptName := req.Params.Name
|
promptName := req.Params.Name
|
||||||
logger.DebugContext(ctx, fmt.Sprintf("prompt name: %s", promptName))
|
logger.DebugContext(ctx, fmt.Sprintf("prompt name: %s", promptName))
|
||||||
|
|
||||||
|
// Update span name and set gen_ai attributes
|
||||||
|
span := trace.SpanFromContext(ctx)
|
||||||
|
span.SetName(fmt.Sprintf("%s %s", PROMPTS_GET, promptName))
|
||||||
|
span.SetAttributes(attribute.String("gen_ai.prompt.name", promptName))
|
||||||
|
|
||||||
prompt, ok := resourceMgr.GetPrompt(promptName)
|
prompt, ok := resourceMgr.GetPrompt(promptName)
|
||||||
if !ok {
|
if !ok {
|
||||||
err := fmt.Errorf("prompt with name %q does not exist", promptName)
|
err := fmt.Errorf("prompt with name %q does not exist", promptName)
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||||
@@ -29,6 +28,8 @@ import (
|
|||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util"
|
"github.com/googleapis/genai-toolbox/internal/util"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
|
"go.opentelemetry.io/otel/attribute"
|
||||||
|
"go.opentelemetry.io/otel/trace"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ProcessMethod returns a response for the request.
|
// ProcessMethod returns a response for the request.
|
||||||
@@ -95,6 +96,15 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
toolName := req.Params.Name
|
toolName := req.Params.Name
|
||||||
toolArgument := req.Params.Arguments
|
toolArgument := req.Params.Arguments
|
||||||
logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName))
|
logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName))
|
||||||
|
|
||||||
|
// Update span name and set gen_ai attributes
|
||||||
|
span := trace.SpanFromContext(ctx)
|
||||||
|
span.SetName(fmt.Sprintf("%s %s", TOOLS_CALL, toolName))
|
||||||
|
span.SetAttributes(
|
||||||
|
attribute.String("gen_ai.tool.name", toolName),
|
||||||
|
attribute.String("gen_ai.operation.name", "execute_tool"),
|
||||||
|
)
|
||||||
|
|
||||||
tool, ok := resourceMgr.GetTool(toolName)
|
tool, ok := resourceMgr.GetTool(toolName)
|
||||||
if !ok {
|
if !ok {
|
||||||
err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
|
err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
|
||||||
@@ -117,7 +127,12 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
}
|
}
|
||||||
if clientAuth {
|
if clientAuth {
|
||||||
if accessToken == "" {
|
if accessToken == "" {
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
|
err := util.NewClientServerError(
|
||||||
|
"missing access token in the 'Authorization' header",
|
||||||
|
http.StatusUnauthorized,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -165,7 +180,11 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
// Check if any of the specified auth services is verified
|
// Check if any of the specified auth services is verified
|
||||||
isAuthorized := tool.Authorized(verifiedAuthServices)
|
isAuthorized := tool.Authorized(verifiedAuthServices)
|
||||||
if !isAuthorized {
|
if !isAuthorized {
|
||||||
err = fmt.Errorf("unauthorized Tool call: Please make sure your specify correct auth headers: %w", util.ErrUnauthorized)
|
err = util.NewClientServerError(
|
||||||
|
"unauthorized Tool call: Please make sure you specify correct auth headers",
|
||||||
|
http.StatusUnauthorized,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
}
|
}
|
||||||
logger.DebugContext(ctx, "tool invocation authorized")
|
logger.DebugContext(ctx, "tool invocation authorized")
|
||||||
@@ -187,20 +206,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
// run tool invocation and generate response.
|
// run tool invocation and generate response.
|
||||||
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errStr := err.Error()
|
var tbErr util.ToolboxError
|
||||||
// Missing authService tokens.
|
|
||||||
if errors.Is(err, util.ErrUnauthorized) {
|
if errors.As(err, &tbErr) {
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
switch tbErr.Category() {
|
||||||
}
|
case util.CategoryAgent:
|
||||||
// Upstream auth error
|
// MCP - Tool execution error
|
||||||
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
|
// Return SUCCESS but with IsError: true
|
||||||
if clientAuth {
|
|
||||||
// Error with client credentials should pass down to the client
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
|
||||||
}
|
|
||||||
// Auth error with ADC should raise internal 500 error
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
|
||||||
}
|
|
||||||
text := TextContent{
|
text := TextContent{
|
||||||
Type: "text",
|
Type: "text",
|
||||||
Text: err.Error(),
|
Text: err.Error(),
|
||||||
@@ -210,6 +222,28 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
Id: id,
|
Id: id,
|
||||||
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
||||||
}, nil
|
}, nil
|
||||||
|
|
||||||
|
case util.CategoryServer:
|
||||||
|
// MCP Spec - Protocol error
|
||||||
|
// Return JSON-RPC ERROR
|
||||||
|
var clientServerErr *util.ClientServerError
|
||||||
|
rpcCode := jsonrpc.INTERNAL_ERROR // Default to Internal Error (-32603)
|
||||||
|
|
||||||
|
if errors.As(err, &clientServerErr) {
|
||||||
|
if clientServerErr.Code == http.StatusUnauthorized || clientServerErr.Code == http.StatusForbidden {
|
||||||
|
if clientAuth {
|
||||||
|
rpcCode = jsonrpc.INVALID_REQUEST
|
||||||
|
} else {
|
||||||
|
rpcCode = jsonrpc.INTERNAL_ERROR
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return jsonrpc.NewError(id, rpcCode, err.Error(), nil), err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Unknown error -> 500
|
||||||
|
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
content := make([]TextContent, 0)
|
content := make([]TextContent, 0)
|
||||||
@@ -280,6 +314,12 @@ func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *r
|
|||||||
|
|
||||||
promptName := req.Params.Name
|
promptName := req.Params.Name
|
||||||
logger.DebugContext(ctx, fmt.Sprintf("prompt name: %s", promptName))
|
logger.DebugContext(ctx, fmt.Sprintf("prompt name: %s", promptName))
|
||||||
|
|
||||||
|
// Update span name and set gen_ai attributes
|
||||||
|
span := trace.SpanFromContext(ctx)
|
||||||
|
span.SetName(fmt.Sprintf("%s %s", PROMPTS_GET, promptName))
|
||||||
|
span.SetAttributes(attribute.String("gen_ai.prompt.name", promptName))
|
||||||
|
|
||||||
prompt, ok := resourceMgr.GetPrompt(promptName)
|
prompt, ok := resourceMgr.GetPrompt(promptName)
|
||||||
if !ok {
|
if !ok {
|
||||||
err := fmt.Errorf("prompt with name %q does not exist", promptName)
|
err := fmt.Errorf("prompt with name %q does not exist", promptName)
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||||
@@ -29,6 +28,8 @@ import (
|
|||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util"
|
"github.com/googleapis/genai-toolbox/internal/util"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
|
"go.opentelemetry.io/otel/attribute"
|
||||||
|
"go.opentelemetry.io/otel/trace"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ProcessMethod returns a response for the request.
|
// ProcessMethod returns a response for the request.
|
||||||
@@ -95,6 +96,15 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
toolName := req.Params.Name
|
toolName := req.Params.Name
|
||||||
toolArgument := req.Params.Arguments
|
toolArgument := req.Params.Arguments
|
||||||
logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName))
|
logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName))
|
||||||
|
|
||||||
|
// Update span name and set gen_ai attributes
|
||||||
|
span := trace.SpanFromContext(ctx)
|
||||||
|
span.SetName(fmt.Sprintf("%s %s", TOOLS_CALL, toolName))
|
||||||
|
span.SetAttributes(
|
||||||
|
attribute.String("gen_ai.tool.name", toolName),
|
||||||
|
attribute.String("gen_ai.operation.name", "execute_tool"),
|
||||||
|
)
|
||||||
|
|
||||||
tool, ok := resourceMgr.GetTool(toolName)
|
tool, ok := resourceMgr.GetTool(toolName)
|
||||||
if !ok {
|
if !ok {
|
||||||
err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
|
err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
|
||||||
@@ -117,7 +127,12 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
}
|
}
|
||||||
if clientAuth {
|
if clientAuth {
|
||||||
if accessToken == "" {
|
if accessToken == "" {
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
|
err := util.NewClientServerError(
|
||||||
|
"missing access token in the 'Authorization' header",
|
||||||
|
http.StatusUnauthorized,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -165,7 +180,11 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
// Check if any of the specified auth services is verified
|
// Check if any of the specified auth services is verified
|
||||||
isAuthorized := tool.Authorized(verifiedAuthServices)
|
isAuthorized := tool.Authorized(verifiedAuthServices)
|
||||||
if !isAuthorized {
|
if !isAuthorized {
|
||||||
err = fmt.Errorf("unauthorized Tool call: Please make sure your specify correct auth headers: %w", util.ErrUnauthorized)
|
err = util.NewClientServerError(
|
||||||
|
"unauthorized Tool call: Please make sure you specify correct auth headers",
|
||||||
|
http.StatusUnauthorized,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||||
}
|
}
|
||||||
logger.DebugContext(ctx, "tool invocation authorized")
|
logger.DebugContext(ctx, "tool invocation authorized")
|
||||||
@@ -187,20 +206,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
// run tool invocation and generate response.
|
// run tool invocation and generate response.
|
||||||
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errStr := err.Error()
|
var tbErr util.ToolboxError
|
||||||
// Missing authService tokens.
|
|
||||||
if errors.Is(err, util.ErrUnauthorized) {
|
if errors.As(err, &tbErr) {
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
switch tbErr.Category() {
|
||||||
}
|
case util.CategoryAgent:
|
||||||
// Upstream auth error
|
// MCP - Tool execution error
|
||||||
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
|
// Return SUCCESS but with IsError: true
|
||||||
if clientAuth {
|
|
||||||
// Error with client credentials should pass down to the client
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
|
||||||
}
|
|
||||||
// Auth error with ADC should raise internal 500 error
|
|
||||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
|
||||||
}
|
|
||||||
text := TextContent{
|
text := TextContent{
|
||||||
Type: "text",
|
Type: "text",
|
||||||
Text: err.Error(),
|
Text: err.Error(),
|
||||||
@@ -210,6 +222,28 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
|||||||
Id: id,
|
Id: id,
|
||||||
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
||||||
}, nil
|
}, nil
|
||||||
|
|
||||||
|
case util.CategoryServer:
|
||||||
|
// MCP Spec - Protocol error
|
||||||
|
// Return JSON-RPC ERROR
|
||||||
|
var clientServerErr *util.ClientServerError
|
||||||
|
rpcCode := jsonrpc.INTERNAL_ERROR // Default to Internal Error (-32603)
|
||||||
|
|
||||||
|
if errors.As(err, &clientServerErr) {
|
||||||
|
if clientServerErr.Code == http.StatusUnauthorized || clientServerErr.Code == http.StatusForbidden {
|
||||||
|
if clientAuth {
|
||||||
|
rpcCode = jsonrpc.INVALID_REQUEST
|
||||||
|
} else {
|
||||||
|
rpcCode = jsonrpc.INTERNAL_ERROR
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return jsonrpc.NewError(id, rpcCode, err.Error(), nil), err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Unknown error -> 500
|
||||||
|
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
content := make([]TextContent, 0)
|
content := make([]TextContent, 0)
|
||||||
@@ -280,6 +314,12 @@ func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *r
|
|||||||
|
|
||||||
promptName := req.Params.Name
|
promptName := req.Params.Name
|
||||||
logger.DebugContext(ctx, fmt.Sprintf("prompt name: %s", promptName))
|
logger.DebugContext(ctx, fmt.Sprintf("prompt name: %s", promptName))
|
||||||
|
|
||||||
|
// Update span name and set gen_ai attributes
|
||||||
|
span := trace.SpanFromContext(ctx)
|
||||||
|
span.SetName(fmt.Sprintf("%s %s", PROMPTS_GET, promptName))
|
||||||
|
span.SetAttributes(attribute.String("gen_ai.prompt.name", promptName))
|
||||||
|
|
||||||
prompt, ok := resourceMgr.GetPrompt(promptName)
|
prompt, ok := resourceMgr.GetPrompt(promptName)
|
||||||
if !ok {
|
if !ok {
|
||||||
err := fmt.Errorf("prompt with name %q does not exist", promptName)
|
err := fmt.Errorf("prompt with name %q does not exist", promptName)
|
||||||
|
|||||||
@@ -231,7 +231,7 @@ func TestMcpEndpointWithoutInitialized(t *testing.T) {
|
|||||||
"id": "tools-call-tool4",
|
"id": "tools-call-tool4",
|
||||||
"error": map[string]any{
|
"error": map[string]any{
|
||||||
"code": -32600.0,
|
"code": -32600.0,
|
||||||
"message": "unauthorized Tool call: Please make sure your specify correct auth headers: unauthorized",
|
"message": "unauthorized Tool call: Please make sure you specify correct auth headers",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -320,7 +320,7 @@ func TestMcpEndpointWithoutInitialized(t *testing.T) {
|
|||||||
Params: map[string]any{
|
Params: map[string]any{
|
||||||
"name": "prompt2",
|
"name": "prompt2",
|
||||||
"arguments": map[string]any{
|
"arguments": map[string]any{
|
||||||
"arg1": 42, // prompt2 expects a string, we send a number
|
"arg1": 42,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -834,7 +834,7 @@ func TestMcpEndpoint(t *testing.T) {
|
|||||||
"id": "tools-call-tool4",
|
"id": "tools-call-tool4",
|
||||||
"error": map[string]any{
|
"error": map[string]any{
|
||||||
"code": -32600.0,
|
"code": -32600.0,
|
||||||
"message": "unauthorized Tool call: Please make sure your specify correct auth headers: unauthorized",
|
"message": "unauthorized Tool call: Please make sure you specify correct auth headers",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import (
|
|||||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/util"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -34,7 +35,7 @@ type MockTool struct {
|
|||||||
requiresClientAuthrorization bool
|
requiresClientAuthrorization bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t MockTool) Invoke(context.Context, tools.SourceProvider, parameters.ParamValues, tools.AccessToken) (any, error) {
|
func (t MockTool) Invoke(context.Context, tools.SourceProvider, parameters.ParamValues, tools.AccessToken) (any, util.ToolboxError) {
|
||||||
mock := []any{t.Name}
|
mock := []any{t.Name}
|
||||||
return mock, nil
|
return mock, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -361,7 +361,11 @@ func (s *Source) GetOperations(ctx context.Context, project, location, operation
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return string(opBytes), nil
|
var result any
|
||||||
|
if err := json.Unmarshal(opBytes, &result); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to unmarshal operation bytes: %w", err)
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
}
|
}
|
||||||
logger.DebugContext(ctx, fmt.Sprintf("Operation not complete, retrying in %v\n", delay))
|
logger.DebugContext(ctx, fmt.Sprintf("Operation not complete, retrying in %v\n", delay))
|
||||||
}
|
}
|
||||||
|
|||||||
421
internal/sources/cockroachdb/cockroachdb.go
Normal file
421
internal/sources/cockroachdb/cockroachdb.go
Normal file
@@ -0,0 +1,421 @@
|
|||||||
|
// Copyright 2026 Google LLC
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package cockroachdb
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"math"
|
||||||
|
"net/url"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
crdbpgx "github.com/cockroachdb/cockroach-go/v2/crdb/crdbpgxv5"
|
||||||
|
"github.com/goccy/go-yaml"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/util"
|
||||||
|
"github.com/jackc/pgx/v5"
|
||||||
|
"github.com/jackc/pgx/v5/pgxpool"
|
||||||
|
"go.opentelemetry.io/otel/trace"
|
||||||
|
)
|
||||||
|
|
||||||
|
const SourceType string = "cockroachdb"
|
||||||
|
|
||||||
|
var _ sources.SourceConfig = Config{}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
if !sources.Register(SourceType, newConfig) {
|
||||||
|
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources.SourceConfig, error) {
|
||||||
|
// MCP compliance: Read-only by default, require explicit opt-in for writes
|
||||||
|
actual := Config{
|
||||||
|
Name: name,
|
||||||
|
MaxRetries: 5,
|
||||||
|
RetryBaseDelay: "500ms",
|
||||||
|
ReadOnlyMode: true, // MCP requirement: read-only by default
|
||||||
|
EnableWriteMode: false, // Must be explicitly enabled
|
||||||
|
MaxRowLimit: 1000, // MCP requirement: limit query results
|
||||||
|
QueryTimeoutSec: 30, // MCP requirement: prevent long-running queries
|
||||||
|
EnableTelemetry: true, // MCP requirement: observability
|
||||||
|
TelemetryVerbose: false,
|
||||||
|
}
|
||||||
|
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Security validation: If EnableWriteMode is true, ReadOnlyMode should be false
|
||||||
|
if actual.EnableWriteMode {
|
||||||
|
actual.ReadOnlyMode = false
|
||||||
|
}
|
||||||
|
|
||||||
|
return actual, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
Name string `yaml:"name" validate:"required"`
|
||||||
|
Type string `yaml:"type" validate:"required"`
|
||||||
|
Host string `yaml:"host" validate:"required"`
|
||||||
|
Port string `yaml:"port" validate:"required"`
|
||||||
|
User string `yaml:"user" validate:"required"`
|
||||||
|
Password string `yaml:"password"`
|
||||||
|
Database string `yaml:"database" validate:"required"`
|
||||||
|
QueryParams map[string]string `yaml:"queryParams"`
|
||||||
|
MaxRetries int `yaml:"maxRetries"`
|
||||||
|
RetryBaseDelay string `yaml:"retryBaseDelay"`
|
||||||
|
|
||||||
|
// MCP Security Features
|
||||||
|
ReadOnlyMode bool `yaml:"readOnlyMode"` // Default: true (enforced in Initialize)
|
||||||
|
EnableWriteMode bool `yaml:"enableWriteMode"` // Explicit opt-in for write operations
|
||||||
|
MaxRowLimit int `yaml:"maxRowLimit"` // Default: 1000
|
||||||
|
QueryTimeoutSec int `yaml:"queryTimeoutSec"` // Default: 30
|
||||||
|
|
||||||
|
// Observability
|
||||||
|
EnableTelemetry bool `yaml:"enableTelemetry"` // Default: true
|
||||||
|
TelemetryVerbose bool `yaml:"telemetryVerbose"` // Default: false
|
||||||
|
ClusterID string `yaml:"clusterID"` // Optional cluster identifier for telemetry
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r Config) SourceConfigType() string {
|
||||||
|
return SourceType
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||||
|
retryBaseDelay, err := time.ParseDuration(r.RetryBaseDelay)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid retryBaseDelay: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pool, err := initCockroachDBConnectionPoolWithRetry(ctx, tracer, r.Name, r.Host, r.Port, r.User, r.Password, r.Database, r.QueryParams, r.MaxRetries, retryBaseDelay)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to create pool: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s := &Source{
|
||||||
|
Config: r,
|
||||||
|
Pool: pool,
|
||||||
|
}
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ sources.Source = &Source{}
|
||||||
|
|
||||||
|
type Source struct {
|
||||||
|
Config
|
||||||
|
Pool *pgxpool.Pool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Source) SourceType() string {
|
||||||
|
return SourceType
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Source) ToConfig() sources.SourceConfig {
|
||||||
|
return s.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Source) CockroachDBPool() *pgxpool.Pool {
|
||||||
|
return s.Pool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Source) PostgresPool() *pgxpool.Pool {
|
||||||
|
return s.Pool
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecuteTxWithRetry executes a function within a transaction with automatic retry logic
|
||||||
|
// using the official CockroachDB retry mechanism from cockroach-go/v2
|
||||||
|
func (s *Source) ExecuteTxWithRetry(ctx context.Context, fn func(pgx.Tx) error) error {
|
||||||
|
return crdbpgx.ExecuteTx(ctx, s.Pool, pgx.TxOptions{}, fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query executes a query using the connection pool with MCP security enforcement.
|
||||||
|
// For read-only queries, connection-level retry is sufficient.
|
||||||
|
// For write operations requiring transaction retry, use ExecuteTxWithRetry directly.
|
||||||
|
// Note: Callers should manage context timeouts as needed.
|
||||||
|
func (s *Source) Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) {
|
||||||
|
// MCP Security Check 1: Enforce write operation restrictions
|
||||||
|
if err := s.CanExecuteWrite(sql); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// MCP Security Check 2: Apply query limits (row limit)
|
||||||
|
modifiedSQL, err := s.ApplyQueryLimits(sql)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.Pool.Query(ctx, modifiedSQL, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
// MCP Security & Observability Features
// ============================================================================

// TelemetryEvent represents a structured telemetry event for MCP tool calls.
// Zero-valued ClusterID/Database/User fields are backfilled from the Source
// by EmitTelemetry before logging.
type TelemetryEvent struct {
	Timestamp    time.Time         `json:"timestamp"`
	ToolName     string            `json:"tool_name"`
	ClusterID    string            `json:"cluster_id"`
	Database     string            `json:"database"`
	User         string            `json:"user"`
	SQLRedacted  string            `json:"sql_redacted"` // Query with values redacted
	Status       string            `json:"status"`       // "success" | "failure"
	ErrorCode    string            `json:"error_code,omitempty"`
	ErrorMsg     string            `json:"error_msg,omitempty"`
	LatencyMs    int64             `json:"latency_ms"`
	RowsAffected int64             `json:"rows_affected,omitempty"`
	Metadata     map[string]string `json:"metadata,omitempty"` // free-form extra context
}
|
||||||
|
|
||||||
|
// StructuredError represents an MCP-compliant error with error codes.
type StructuredError struct {
	Code    string         `json:"error_code"` // one of the ErrCode* constants
	Message string         `json:"message"`
	Details map[string]any `json:"details,omitempty"` // optional structured context
}

// Error implements the error interface, formatting as "[CODE] message".
func (e *StructuredError) Error() string {
	return fmt.Sprintf("[%s] %s", e.Code, e.Message)
}
|
||||||
|
|
||||||
|
// MCP Error Codes returned in StructuredError.Code so callers can branch on
// failure class without parsing messages.
const (
	ErrCodeUnauthorized         = "CRDB_UNAUTHORIZED"
	ErrCodeReadOnlyViolation    = "CRDB_READONLY_VIOLATION"
	ErrCodeQueryTimeout         = "CRDB_QUERY_TIMEOUT"
	ErrCodeRowLimitExceeded     = "CRDB_ROW_LIMIT_EXCEEDED"
	ErrCodeInvalidSQL           = "CRDB_INVALID_SQL"
	ErrCodeConnectionFailed     = "CRDB_CONNECTION_FAILED"
	ErrCodeWriteModeRequired    = "CRDB_WRITE_MODE_REQUIRED"
	ErrCodeQueryExecutionFailed = "CRDB_QUERY_EXECUTION_FAILED"
)
|
||||||
|
|
||||||
|
// SQLStatementType represents the broad category of a SQL statement.
type SQLStatementType int

const (
	SQLTypeUnknown SQLStatementType = iota
	SQLTypeSelect
	SQLTypeInsert
	SQLTypeUpdate
	SQLTypeDelete
	SQLTypeDDL // CREATE, ALTER, DROP
	SQLTypeTruncate
	SQLTypeExplain
	SQLTypeShow
	SQLTypeSet
)

// Comment-stripping patterns, compiled once at package init instead of on
// every ClassifySQL call. The block-comment pattern uses (?s) so `.` matches
// newlines and multi-line /* ... */ comments are stripped as well.
var (
	sqlLineCommentRe  = regexp.MustCompile(`--.*`)
	sqlBlockCommentRe = regexp.MustCompile(`(?s)/\*.*?\*/`)
)

// ClassifySQL analyzes a SQL statement and returns its type.
// Classification is by leading keyword after trimming whitespace and
// stripping -- and /* */ comments; anything unrecognized (e.g. a query
// starting with WITH) is SQLTypeUnknown.
func ClassifySQL(sql string) SQLStatementType {
	// Normalize: trim and convert to uppercase for analysis.
	normalized := strings.TrimSpace(strings.ToUpper(sql))
	if normalized == "" {
		return SQLTypeUnknown
	}

	// Remove comments so a leading comment does not hide the keyword.
	normalized = sqlLineCommentRe.ReplaceAllString(normalized, "")
	normalized = sqlBlockCommentRe.ReplaceAllString(normalized, "")
	normalized = strings.TrimSpace(normalized)

	switch {
	case strings.HasPrefix(normalized, "SELECT"):
		return SQLTypeSelect
	case strings.HasPrefix(normalized, "INSERT"):
		return SQLTypeInsert
	case strings.HasPrefix(normalized, "UPDATE"):
		return SQLTypeUpdate
	case strings.HasPrefix(normalized, "DELETE"):
		return SQLTypeDelete
	case strings.HasPrefix(normalized, "TRUNCATE"):
		return SQLTypeTruncate
	case strings.HasPrefix(normalized, "CREATE"),
		strings.HasPrefix(normalized, "ALTER"),
		strings.HasPrefix(normalized, "DROP"):
		return SQLTypeDDL
	case strings.HasPrefix(normalized, "EXPLAIN"):
		return SQLTypeExplain
	case strings.HasPrefix(normalized, "SHOW"):
		return SQLTypeShow
	case strings.HasPrefix(normalized, "SET"):
		return SQLTypeSet
	default:
		return SQLTypeUnknown
	}
}

// IsWriteOperation returns true if the SQL statement modifies data or schema.
// NOTE(review): statements ClassifySQL cannot recognize (e.g. "WITH ... INSERT")
// come back as SQLTypeUnknown and therefore pass as non-writes — confirm this
// is acceptable for the read-only enforcement in CanExecuteWrite.
func IsWriteOperation(sqlType SQLStatementType) bool {
	switch sqlType {
	case SQLTypeInsert, SQLTypeUpdate, SQLTypeDelete, SQLTypeTruncate, SQLTypeDDL:
		return true
	default:
		return false
	}
}
|
||||||
|
|
||||||
|
// IsReadOnlyMode returns whether the source is in read-only mode
|
||||||
|
func (s *Source) IsReadOnlyMode() bool {
|
||||||
|
return s.ReadOnlyMode && !s.EnableWriteMode
|
||||||
|
}
|
||||||
|
|
||||||
|
// CanExecuteWrite checks if a write operation is allowed
|
||||||
|
func (s *Source) CanExecuteWrite(sql string) error {
|
||||||
|
sqlType := ClassifySQL(sql)
|
||||||
|
|
||||||
|
if IsWriteOperation(sqlType) && s.IsReadOnlyMode() {
|
||||||
|
return &StructuredError{
|
||||||
|
Code: ErrCodeReadOnlyViolation,
|
||||||
|
Message: "Write operations are not allowed in read-only mode. Set enableWriteMode: true to allow writes.",
|
||||||
|
Details: map[string]any{
|
||||||
|
"sql_type": sqlType,
|
||||||
|
"read_only_mode": s.ReadOnlyMode,
|
||||||
|
"enable_write_mode": s.EnableWriteMode,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ApplyQueryLimits applies row limits to a SQL query for MCP security compliance.
|
||||||
|
// Context timeout management is the responsibility of the caller (following Go best practices).
|
||||||
|
// Returns potentially modified SQL with LIMIT clause for SELECT queries.
|
||||||
|
func (s *Source) ApplyQueryLimits(sql string) (string, error) {
|
||||||
|
sqlType := ClassifySQL(sql)
|
||||||
|
|
||||||
|
// Apply row limit only to SELECT queries
|
||||||
|
if sqlType == SQLTypeSelect && s.MaxRowLimit > 0 {
|
||||||
|
// Check if query already has LIMIT clause
|
||||||
|
normalized := strings.ToUpper(sql)
|
||||||
|
if !strings.Contains(normalized, " LIMIT ") {
|
||||||
|
// Add LIMIT clause - trim trailing whitespace and semicolon
|
||||||
|
sql = strings.TrimSpace(sql)
|
||||||
|
sql = strings.TrimSuffix(sql, ";")
|
||||||
|
sql = fmt.Sprintf("%s LIMIT %d", sql, s.MaxRowLimit)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return sql, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Redaction patterns compiled once at package init; RedactSQL sits on the
// telemetry path of every tool call, so avoid per-call regexp compilation.
var (
	redactStringLitRe  = regexp.MustCompile(`'[^']*'`)
	redactLongNumberRe = regexp.MustCompile(`\b\d{10,}\b`)
)

// RedactSQL redacts sensitive values from SQL for telemetry: single-quoted
// string literals become '***' and standalone numbers of 10+ digits (IDs,
// phone numbers, SSNs) become ***. Short numbers are left intact.
func RedactSQL(sql string) string {
	sql = redactStringLitRe.ReplaceAllString(sql, "'***'")
	sql = redactLongNumberRe.ReplaceAllString(sql, "***")
	return sql
}
|
||||||
|
|
||||||
|
// EmitTelemetry logs a telemetry event in structured JSON format
|
||||||
|
func (s *Source) EmitTelemetry(ctx context.Context, event TelemetryEvent) {
|
||||||
|
if !s.EnableTelemetry {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set cluster ID if not already set
|
||||||
|
if event.ClusterID == "" {
|
||||||
|
event.ClusterID = s.ClusterID
|
||||||
|
if event.ClusterID == "" {
|
||||||
|
event.ClusterID = s.Database // Fallback to database name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set database and user
|
||||||
|
if event.Database == "" {
|
||||||
|
event.Database = s.Database
|
||||||
|
}
|
||||||
|
if event.User == "" {
|
||||||
|
event.User = s.User
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log as structured JSON
|
||||||
|
if s.TelemetryVerbose {
|
||||||
|
jsonBytes, _ := json.Marshal(event)
|
||||||
|
slog.Info("CockroachDB MCP Telemetry", "event", string(jsonBytes))
|
||||||
|
} else {
|
||||||
|
// Minimal logging
|
||||||
|
slog.Info("CockroachDB MCP",
|
||||||
|
"tool", event.ToolName,
|
||||||
|
"status", event.Status,
|
||||||
|
"latency_ms", event.LatencyMs,
|
||||||
|
"error_code", event.ErrorCode,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func initCockroachDBConnectionPoolWithRetry(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname string, queryParams map[string]string, maxRetries int, baseDelay time.Duration) (*pgxpool.Pool, error) {
|
||||||
|
//nolint:all
|
||||||
|
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||||
|
defer span.End()
|
||||||
|
|
||||||
|
userAgent, err := util.UserAgentFromContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
userAgent = "genai-toolbox"
|
||||||
|
}
|
||||||
|
if queryParams == nil {
|
||||||
|
queryParams = make(map[string]string)
|
||||||
|
}
|
||||||
|
if _, ok := queryParams["application_name"]; !ok {
|
||||||
|
queryParams["application_name"] = userAgent
|
||||||
|
}
|
||||||
|
|
||||||
|
connURL := &url.URL{
|
||||||
|
Scheme: "postgres",
|
||||||
|
User: url.UserPassword(user, pass),
|
||||||
|
Host: fmt.Sprintf("%s:%s", host, port),
|
||||||
|
Path: dbname,
|
||||||
|
RawQuery: ConvertParamMapToRawQuery(queryParams),
|
||||||
|
}
|
||||||
|
|
||||||
|
var pool *pgxpool.Pool
|
||||||
|
for attempt := 0; attempt <= maxRetries; attempt++ {
|
||||||
|
pool, err = pgxpool.New(ctx, connURL.String())
|
||||||
|
if err == nil {
|
||||||
|
err = pool.Ping(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
return pool, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if attempt < maxRetries {
|
||||||
|
backoff := baseDelay * time.Duration(math.Pow(2, float64(attempt)))
|
||||||
|
time.Sleep(backoff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("failed to connect to CockroachDB after %d retries: %w", maxRetries, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertParamMapToRawQuery encodes a flat key/value map as a URL query
// string (keys sorted, values percent-encoded), e.g. {"a": "b"} -> "a=b".
// A nil or empty map yields "".
func ConvertParamMapToRawQuery(queryParams map[string]string) string {
	encoded := make(url.Values, len(queryParams))
	for key, value := range queryParams {
		encoded[key] = append(encoded[key], value)
	}
	return encoded.Encode()
}
|
||||||
224
internal/sources/cockroachdb/cockroachdb_test.go
Normal file
224
internal/sources/cockroachdb/cockroachdb_test.go
Normal file
@@ -0,0 +1,224 @@
|
|||||||
|
// Copyright 2026 Google LLC
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package cockroachdb
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/goccy/go-yaml"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestCockroachDBSourceConfig verifies that representative YAML configs parse
// into a cockroachdb Config with the expected SourceConfigType.
// NOTE(review): the YAML indentation under queryParams was lost in extraction
// and is assumed to be two spaces — confirm against the original file.
func TestCockroachDBSourceConfig(t *testing.T) {
	tests := []struct {
		name string
		yaml string
	}{
		{
			name: "valid config",
			yaml: `
name: test-cockroachdb
type: cockroachdb
host: localhost
port: "26257"
user: root
password: ""
database: defaultdb
maxRetries: 5
retryBaseDelay: 500ms
queryParams:
  sslmode: disable
`,
		},
		{
			name: "with optional queryParams",
			yaml: `
name: test-cockroachdb
type: cockroachdb
host: localhost
port: "26257"
user: root
password: testpass
database: testdb
queryParams:
  sslmode: require
  sslcert: /path/to/cert
`,
		},
		{
			name: "with custom retry settings",
			yaml: `
name: test-cockroachdb
type: cockroachdb
host: localhost
port: "26257"
user: root
password: ""
database: defaultdb
maxRetries: 10
retryBaseDelay: 1s
`,
		},
		{
			name: "without password (insecure mode)",
			yaml: `
name: test-cockroachdb
type: cockroachdb
host: localhost
port: "26257"
user: root
database: defaultdb
`,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			decoder := yaml.NewDecoder(strings.NewReader(tt.yaml))
			cfg, err := newConfig(context.Background(), "test", decoder)

			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}

			if cfg == nil {
				t.Fatal("expected config but got nil")
			}

			// Verify it's the right type.
			cockroachCfg, ok := cfg.(Config)
			if !ok {
				t.Fatalf("expected Config type, got %T", cfg)
			}

			// Verify SourceConfigType matches the registered source type.
			if cockroachCfg.SourceConfigType() != SourceType {
				t.Errorf("expected SourceConfigType %q, got %q", SourceType, cockroachCfg.SourceConfigType())
			}

			t.Logf("✅ Config parsed successfully: %+v", cockroachCfg)
		})
	}
}
|
||||||
|
|
||||||
|
// TestCockroachDBSourceType verifies that a parsed config reports the
// literal source type "cockroachdb".
func TestCockroachDBSourceType(t *testing.T) {
	yamlContent := `
name: test-cockroachdb
type: cockroachdb
host: localhost
port: "26257"
user: root
password: ""
database: defaultdb
`
	decoder := yaml.NewDecoder(strings.NewReader(yamlContent))
	cfg, err := newConfig(context.Background(), "test", decoder)
	if err != nil {
		t.Fatalf("failed to create config: %v", err)
	}

	if cfg.SourceConfigType() != "cockroachdb" {
		t.Errorf("expected SourceConfigType 'cockroachdb', got %q", cfg.SourceConfigType())
	}
}
|
||||||
|
|
||||||
|
// TestCockroachDBDefaultValues verifies that newConfig applies the retry
// defaults (maxRetries=5, retryBaseDelay=500ms) when the YAML omits them.
func TestCockroachDBDefaultValues(t *testing.T) {
	yamlContent := `
name: test-cockroachdb
type: cockroachdb
host: localhost
port: "26257"
user: root
password: ""
database: defaultdb
`
	decoder := yaml.NewDecoder(strings.NewReader(yamlContent))
	cfg, err := newConfig(context.Background(), "test", decoder)
	if err != nil {
		t.Fatalf("failed to create config: %v", err)
	}

	cockroachCfg, ok := cfg.(Config)
	if !ok {
		t.Fatalf("expected Config type")
	}

	// Check default values injected by newConfig.
	if cockroachCfg.MaxRetries != 5 {
		t.Errorf("expected default MaxRetries 5, got %d", cockroachCfg.MaxRetries)
	}

	if cockroachCfg.RetryBaseDelay != "500ms" {
		t.Errorf("expected default RetryBaseDelay '500ms', got %q", cockroachCfg.RetryBaseDelay)
	}

	t.Logf("✅ Default values set correctly")
}
|
||||||
|
|
||||||
|
// TestConvertParamMapToRawQuery verifies query-string encoding of parameter
// maps. Because map iteration order is random and url.Values.Encode sorts
// keys, the test checks for expected substrings rather than exact equality.
func TestConvertParamMapToRawQuery(t *testing.T) {
	tests := []struct {
		name   string
		params map[string]string
		want   []string // Expected substrings in any order
	}{
		{
			name:   "empty params",
			params: map[string]string{},
			want:   []string{},
		},
		{
			name: "single param",
			params: map[string]string{
				"sslmode": "disable",
			},
			want: []string{"sslmode=disable"},
		},
		{
			name: "multiple params",
			params: map[string]string{
				"sslmode":          "require",
				"application_name": "test-app",
			},
			want: []string{"sslmode=require", "application_name=test-app"},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := ConvertParamMapToRawQuery(tt.params)

			// An empty map must encode to the empty string.
			if len(tt.want) == 0 {
				if result != "" {
					t.Errorf("expected empty string, got %q", result)
				}
				return
			}

			// Check that all expected substrings are in the result.
			for _, want := range tt.want {
				if !contains(result, want) {
					t.Errorf("expected result to contain %q, got %q", want, result)
				}
			}

			t.Logf("✅ Query string: %s", result)
		})
	}
}
|
||||||
|
|
||||||
|
// contains is a thin test-local wrapper around strings.Contains, kept for
// readability at call sites above.
func contains(s, substr string) bool {
	return strings.Contains(s, substr)
}
|
||||||
455
internal/sources/cockroachdb/security_test.go
Normal file
455
internal/sources/cockroachdb/security_test.go
Normal file
@@ -0,0 +1,455 @@
|
|||||||
|
// Copyright 2026 Google LLC
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package cockroachdb
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
yaml "github.com/goccy/go-yaml"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestClassifySQL tests SQL statement classification across the keyword,
// whitespace, comment, and case-insensitivity paths of ClassifySQL.
func TestClassifySQL(t *testing.T) {
	tests := []struct {
		name     string
		sql      string
		expected SQLStatementType
	}{
		{"SELECT", "SELECT * FROM users", SQLTypeSelect},
		{"SELECT with spaces", " SELECT * FROM users ", SQLTypeSelect},
		{"SELECT with comment", "-- comment\nSELECT * FROM users", SQLTypeSelect},
		{"INSERT", "INSERT INTO users (name) VALUES ('alice')", SQLTypeInsert},
		{"UPDATE", "UPDATE users SET name='bob' WHERE id=1", SQLTypeUpdate},
		{"DELETE", "DELETE FROM users WHERE id=1", SQLTypeDelete},
		{"CREATE TABLE", "CREATE TABLE users (id UUID PRIMARY KEY)", SQLTypeDDL},
		{"ALTER TABLE", "ALTER TABLE users ADD COLUMN email STRING", SQLTypeDDL},
		{"DROP TABLE", "DROP TABLE users", SQLTypeDDL},
		{"TRUNCATE", "TRUNCATE TABLE users", SQLTypeTruncate},
		{"EXPLAIN", "EXPLAIN SELECT * FROM users", SQLTypeExplain},
		{"SHOW", "SHOW TABLES", SQLTypeShow},
		{"SET", "SET application_name = 'myapp'", SQLTypeSet},
		{"Empty", "", SQLTypeUnknown},
		{"Lowercase select", "select * from users", SQLTypeSelect},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := ClassifySQL(tt.sql)
			if result != tt.expected {
				t.Errorf("ClassifySQL(%q) = %v, want %v", tt.sql, result, tt.expected)
			}
		})
	}
}
|
||||||
|
|
||||||
|
// TestIsWriteOperation tests write operation detection: INSERT/UPDATE/DELETE/
// TRUNCATE/DDL are writes; SELECT/EXPLAIN/SHOW/SET/Unknown are not.
func TestIsWriteOperation(t *testing.T) {
	tests := []struct {
		sqlType  SQLStatementType
		expected bool
	}{
		{SQLTypeSelect, false},
		{SQLTypeInsert, true},
		{SQLTypeUpdate, true},
		{SQLTypeDelete, true},
		{SQLTypeTruncate, true},
		{SQLTypeDDL, true},
		{SQLTypeExplain, false},
		{SQLTypeShow, false},
		{SQLTypeSet, false},
		{SQLTypeUnknown, false},
	}

	for _, tt := range tests {
		t.Run(tt.sqlType.String(), func(t *testing.T) {
			result := IsWriteOperation(tt.sqlType)
			if result != tt.expected {
				t.Errorf("IsWriteOperation(%v) = %v, want %v", tt.sqlType, result, tt.expected)
			}
		})
	}
}
|
||||||
|
|
||||||
|
// Helper for SQLStatementType to string
|
||||||
|
func (s SQLStatementType) String() string {
|
||||||
|
switch s {
|
||||||
|
case SQLTypeSelect:
|
||||||
|
return "SELECT"
|
||||||
|
case SQLTypeInsert:
|
||||||
|
return "INSERT"
|
||||||
|
case SQLTypeUpdate:
|
||||||
|
return "UPDATE"
|
||||||
|
case SQLTypeDelete:
|
||||||
|
return "DELETE"
|
||||||
|
case SQLTypeDDL:
|
||||||
|
return "DDL"
|
||||||
|
case SQLTypeTruncate:
|
||||||
|
return "TRUNCATE"
|
||||||
|
case SQLTypeExplain:
|
||||||
|
return "EXPLAIN"
|
||||||
|
case SQLTypeShow:
|
||||||
|
return "SHOW"
|
||||||
|
case SQLTypeSet:
|
||||||
|
return "SET"
|
||||||
|
default:
|
||||||
|
return "UNKNOWN"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCanExecuteWrite tests write operation enforcement: writes are rejected
// with ErrCodeReadOnlyViolation in read-only mode and permitted when write
// mode is enabled; reads are always allowed.
func TestCanExecuteWrite(t *testing.T) {
	tests := []struct {
		name            string
		readOnlyMode    bool
		enableWriteMode bool
		sql             string
		expectError     bool
		errorCode       string
	}{
		{
			name:            "SELECT in read-only mode",
			readOnlyMode:    true,
			enableWriteMode: false,
			sql:             "SELECT * FROM users",
			expectError:     false,
		},
		{
			name:            "INSERT in read-only mode",
			readOnlyMode:    true,
			enableWriteMode: false,
			sql:             "INSERT INTO users (name) VALUES ('alice')",
			expectError:     true,
			errorCode:       ErrCodeReadOnlyViolation,
		},
		{
			name:            "INSERT with write mode enabled",
			readOnlyMode:    false,
			enableWriteMode: true,
			sql:             "INSERT INTO users (name) VALUES ('alice')",
			expectError:     false,
		},
		{
			name:            "CREATE TABLE in read-only mode",
			readOnlyMode:    true,
			enableWriteMode: false,
			sql:             "CREATE TABLE test (id UUID PRIMARY KEY)",
			expectError:     true,
			errorCode:       ErrCodeReadOnlyViolation,
		},
		{
			name:            "CREATE TABLE with write mode enabled",
			readOnlyMode:    false,
			enableWriteMode: true,
			sql:             "CREATE TABLE test (id UUID PRIMARY KEY)",
			expectError:     false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// A bare Source is enough: CanExecuteWrite only reads config flags.
			source := &Source{
				Config: Config{
					ReadOnlyMode:    tt.readOnlyMode,
					EnableWriteMode: tt.enableWriteMode,
				},
			}

			err := source.CanExecuteWrite(tt.sql)

			if tt.expectError {
				if err == nil {
					t.Errorf("Expected error but got nil")
					return
				}

				// The error must be structured and carry the expected code.
				structErr, ok := err.(*StructuredError)
				if !ok {
					t.Errorf("Expected StructuredError but got %T", err)
					return
				}

				if structErr.Code != tt.errorCode {
					t.Errorf("Expected error code %s but got %s", tt.errorCode, structErr.Code)
				}
			} else {
				if err != nil {
					t.Errorf("Expected no error but got: %v", err)
				}
			}
		})
	}
}
|
||||||
|
|
||||||
|
// TestApplyQueryLimits tests query limit application: a LIMIT clause is
// appended to SELECTs lacking one (with trailing whitespace/semicolon
// trimmed), while queries with an existing LIMIT and non-SELECT statements
// are left untouched.
func TestApplyQueryLimits(t *testing.T) {
	tests := []struct {
		name           string
		sql            string
		maxRowLimit    int
		expectedSQL    string
		shouldAddLimit bool
	}{
		{
			name:           "SELECT without LIMIT",
			sql:            "SELECT * FROM users",
			maxRowLimit:    100,
			expectedSQL:    "SELECT * FROM users LIMIT 100",
			shouldAddLimit: true,
		},
		{
			name:           "SELECT with existing LIMIT",
			sql:            "SELECT * FROM users LIMIT 50",
			maxRowLimit:    100,
			expectedSQL:    "SELECT * FROM users LIMIT 50",
			shouldAddLimit: false,
		},
		{
			name:           "SELECT without LIMIT and semicolon",
			sql:            "SELECT * FROM users;",
			maxRowLimit:    100,
			expectedSQL:    "SELECT * FROM users LIMIT 100",
			shouldAddLimit: true,
		},
		{
			name:           "SELECT with trailing newline and semicolon",
			sql:            "SELECT * FROM users;\n",
			maxRowLimit:    100,
			expectedSQL:    "SELECT * FROM users LIMIT 100",
			shouldAddLimit: true,
		},
		{
			name:           "SELECT with multiline and semicolon",
			sql:            "\n\tSELECT *\n\tFROM users\n\tORDER BY id;\n",
			maxRowLimit:    100,
			expectedSQL:    "SELECT *\n\tFROM users\n\tORDER BY id LIMIT 100",
			shouldAddLimit: true,
		},
		{
			name:           "INSERT should not have LIMIT added",
			sql:            "INSERT INTO users (name) VALUES ('alice')",
			maxRowLimit:    100,
			expectedSQL:    "INSERT INTO users (name) VALUES ('alice')",
			shouldAddLimit: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			source := &Source{
				Config: Config{
					MaxRowLimit:     tt.maxRowLimit,
					QueryTimeoutSec: 0, // Timeout now managed by caller
				},
			}

			modifiedSQL, err := source.ApplyQueryLimits(tt.sql)

			if err != nil {
				t.Errorf("Unexpected error: %v", err)
				return
			}

			if modifiedSQL != tt.expectedSQL {
				t.Errorf("Expected SQL:\n%s\nGot:\n%s", tt.expectedSQL, modifiedSQL)
			}
		})
	}
}
|
||||||
|
|
||||||
|
// TestApplyQueryTimeout tests that timeout is managed by caller (not source):
// the caller wraps its context with the configured QueryTimeoutSec, and
// ApplyQueryLimits neither touches the context nor alters the SQL when
// MaxRowLimit is 0.
func TestApplyQueryTimeout(t *testing.T) {
	source := &Source{
		Config: Config{
			QueryTimeoutSec: 5, // Documented recommended timeout
			MaxRowLimit:     0, // Don't add LIMIT
		},
	}

	// Caller creates timeout context (following Go best practices).
	ctx := context.Background()
	ctx, cancel := context.WithTimeout(ctx, time.Duration(source.QueryTimeoutSec)*time.Second)
	defer cancel()

	// Apply query limits (doesn't modify context anymore).
	modifiedSQL, err := source.ApplyQueryLimits("SELECT * FROM users")
	if err != nil {
		t.Errorf("Unexpected error: %v", err)
		return
	}

	// Verify context has deadline (managed by caller).
	deadline, ok := ctx.Deadline()
	if !ok {
		t.Error("Expected deadline to be set but it wasn't")
		return
	}

	// Verify deadline is approximately 5 seconds from now.
	expectedDeadline := time.Now().Add(5 * time.Second)
	diff := deadline.Sub(expectedDeadline)
	if diff < 0 {
		diff = -diff
	}

	// Allow 1 second tolerance for scheduling jitter.
	if diff > time.Second {
		t.Errorf("Deadline diff too large: %v", diff)
	}

	// Verify SQL is unchanged (LIMIT not added since MaxRowLimit=0).
	if modifiedSQL != "SELECT * FROM users" {
		t.Errorf("Expected SQL unchanged, got: %s", modifiedSQL)
	}
}
|
||||||
|
|
||||||
|
// TestRedactSQL tests SQL redaction for telemetry: quoted string literals and
// 10+-digit numbers are masked, while short numbers pass through.
func TestRedactSQL(t *testing.T) {
	tests := []struct {
		name     string
		sql      string
		expected string
	}{
		{
			name:     "String literal redaction",
			sql:      "SELECT * FROM users WHERE name='alice' AND email='alice@example.com'",
			expected: "SELECT * FROM users WHERE name='***' AND email='***'",
		},
		{
			name:     "Long number redaction",
			sql:      "SELECT * FROM users WHERE ssn=1234567890123",
			expected: "SELECT * FROM users WHERE ssn=***",
		},
		{
			name:     "Short numbers not redacted",
			sql:      "SELECT * FROM users WHERE age=25",
			expected: "SELECT * FROM users WHERE age=25",
		},
		{
			name:     "Multiple sensitive values",
			sql:      "INSERT INTO users (name, email, phone) VALUES ('bob', 'bob@example.com', '5551234567')",
			expected: "INSERT INTO users (name, email, phone) VALUES ('***', '***', '***')",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := RedactSQL(tt.sql)
			if result != tt.expected {
				t.Errorf("RedactSQL:\nGot: %s\nExpected: %s", result, tt.expected)
			}
		})
	}
}
|
||||||
|
|
||||||
|
// TestIsReadOnlyMode tests read-only mode detection: the source is read-only
// only when ReadOnlyMode is set AND EnableWriteMode is not.
func TestIsReadOnlyMode(t *testing.T) {
	tests := []struct {
		name            string
		readOnlyMode    bool
		enableWriteMode bool
		expected        bool
	}{
		{"Read-only by default", true, false, true},
		{"Write mode enabled", false, true, false},
		{"Both false", false, false, false},
		{"Read-only overridden by write mode", true, true, false},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			source := &Source{
				Config: Config{
					ReadOnlyMode:    tt.readOnlyMode,
					EnableWriteMode: tt.enableWriteMode,
				},
			}

			result := source.IsReadOnlyMode()
			if result != tt.expected {
				t.Errorf("IsReadOnlyMode() = %v, want %v", result, tt.expected)
			}
		})
	}
}
|
||||||
|
|
||||||
|
// TestStructuredError tests error formatting: Error() must include both the
// machine-readable code and the human-readable message.
func TestStructuredError(t *testing.T) {
	err := &StructuredError{
		Code:    ErrCodeReadOnlyViolation,
		Message: "Write operations not allowed",
		Details: map[string]any{
			"sql_type": "INSERT",
		},
	}

	errorStr := err.Error()
	if !strings.Contains(errorStr, ErrCodeReadOnlyViolation) {
		t.Errorf("Error string should contain error code: %s", errorStr)
	}
	if !strings.Contains(errorStr, "Write operations not allowed") {
		t.Errorf("Error string should contain message: %s", errorStr)
	}
}
|
||||||
|
|
||||||
|
// TestDefaultSecuritySettings tests that security defaults are correct
|
||||||
|
func TestDefaultSecuritySettings(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create a minimal YAML config
|
||||||
|
yamlData := `name: test
|
||||||
|
type: cockroachdb
|
||||||
|
host: localhost
|
||||||
|
port: "26257"
|
||||||
|
user: root
|
||||||
|
database: defaultdb
|
||||||
|
`
|
||||||
|
|
||||||
|
var cfg Config
|
||||||
|
if err := yaml.Unmarshal([]byte(yamlData), &cfg); err != nil {
|
||||||
|
t.Fatalf("Failed to unmarshal YAML: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply defaults through newConfig logic manually
|
||||||
|
cfg.MaxRetries = 5
|
||||||
|
cfg.RetryBaseDelay = "500ms"
|
||||||
|
cfg.ReadOnlyMode = true
|
||||||
|
cfg.EnableWriteMode = false
|
||||||
|
cfg.MaxRowLimit = 1000
|
||||||
|
cfg.QueryTimeoutSec = 30
|
||||||
|
cfg.EnableTelemetry = true
|
||||||
|
cfg.TelemetryVerbose = false
|
||||||
|
|
||||||
|
_ = ctx // prevent unused
|
||||||
|
|
||||||
|
// Verify MCP security defaults
|
||||||
|
if !cfg.ReadOnlyMode {
|
||||||
|
t.Error("ReadOnlyMode should be true by default")
|
||||||
|
}
|
||||||
|
if cfg.EnableWriteMode {
|
||||||
|
t.Error("EnableWriteMode should be false by default")
|
||||||
|
}
|
||||||
|
if cfg.MaxRowLimit != 1000 {
|
||||||
|
t.Errorf("MaxRowLimit should be 1000, got %d", cfg.MaxRowLimit)
|
||||||
|
}
|
||||||
|
if cfg.QueryTimeoutSec != 30 {
|
||||||
|
t.Errorf("QueryTimeoutSec should be 30, got %d", cfg.QueryTimeoutSec)
|
||||||
|
}
|
||||||
|
if !cfg.EnableTelemetry {
|
||||||
|
t.Error("EnableTelemetry should be true by default")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
// Copyright © 2025, Oracle and/or its affiliates.
|
// Copyright © 2025, Oracle and/or its affiliates.
|
||||||
|
|
||||||
package oracle
|
package oracle
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -17,11 +17,13 @@ package alloydbcreatecluster
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
yaml "github.com/goccy/go-yaml"
|
yaml "github.com/goccy/go-yaml"
|
||||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/util"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -122,44 +124,49 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Invoke executes the tool's logic.
|
// Invoke executes the tool's logic.
|
||||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
paramsMap := params.AsMap()
|
paramsMap := params.AsMap()
|
||||||
project, ok := paramsMap["project"].(string)
|
project, ok := paramsMap["project"].(string)
|
||||||
if !ok || project == "" {
|
if !ok || project == "" {
|
||||||
return nil, fmt.Errorf("invalid or missing 'project' parameter; expected a non-empty string")
|
return nil, util.NewAgentError("invalid or missing 'project' parameter; expected a non-empty string", nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
location, ok := paramsMap["location"].(string)
|
location, ok := paramsMap["location"].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("invalid 'location' parameter; expected a string")
|
return nil, util.NewAgentError("invalid 'location' parameter; expected a string", nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
clusterID, ok := paramsMap["cluster"].(string)
|
clusterID, ok := paramsMap["cluster"].(string)
|
||||||
if !ok || clusterID == "" {
|
if !ok || clusterID == "" {
|
||||||
return nil, fmt.Errorf("invalid or missing 'cluster' parameter; expected a non-empty string")
|
return nil, util.NewAgentError("invalid or missing 'cluster' parameter; expected a non-empty string", nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
password, ok := paramsMap["password"].(string)
|
password, ok := paramsMap["password"].(string)
|
||||||
if !ok || password == "" {
|
if !ok || password == "" {
|
||||||
return nil, fmt.Errorf("invalid or missing 'password' parameter; expected a non-empty string")
|
return nil, util.NewAgentError("invalid or missing 'password' parameter; expected a non-empty string", nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
network, ok := paramsMap["network"].(string)
|
network, ok := paramsMap["network"].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("invalid 'network' parameter; expected a string")
|
return nil, util.NewAgentError("invalid 'network' parameter; expected a string", nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
user, ok := paramsMap["user"].(string)
|
user, ok := paramsMap["user"].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("invalid 'user' parameter; expected a string")
|
return nil, util.NewAgentError("invalid 'user' parameter; expected a string", nil)
|
||||||
|
}
|
||||||
|
resp, err := source.CreateCluster(ctx, project, location, network, user, password, clusterID, string(accessToken))
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, util.ProcessGcpError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return source.CreateCluster(ctx, project, location, network, user, password, clusterID, string(accessToken))
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
||||||
|
|||||||
@@ -17,11 +17,13 @@ package alloydbcreateinstance
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
yaml "github.com/goccy/go-yaml"
|
yaml "github.com/goccy/go-yaml"
|
||||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/util"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -123,36 +125,36 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Invoke executes the tool's logic.
|
// Invoke executes the tool's logic.
|
||||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
paramsMap := params.AsMap()
|
paramsMap := params.AsMap()
|
||||||
project, ok := paramsMap["project"].(string)
|
project, ok := paramsMap["project"].(string)
|
||||||
if !ok || project == "" {
|
if !ok || project == "" {
|
||||||
return nil, fmt.Errorf("invalid or missing 'project' parameter; expected a non-empty string")
|
return nil, util.NewAgentError("invalid or missing 'project' parameter; expected a non-empty string", nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
location, ok := paramsMap["location"].(string)
|
location, ok := paramsMap["location"].(string)
|
||||||
if !ok || location == "" {
|
if !ok || location == "" {
|
||||||
return nil, fmt.Errorf("invalid or missing 'location' parameter; expected a non-empty string")
|
return nil, util.NewAgentError("invalid or missing 'location' parameter; expected a non-empty string", nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
cluster, ok := paramsMap["cluster"].(string)
|
cluster, ok := paramsMap["cluster"].(string)
|
||||||
if !ok || cluster == "" {
|
if !ok || cluster == "" {
|
||||||
return nil, fmt.Errorf("invalid or missing 'cluster' parameter; expected a non-empty string")
|
return nil, util.NewAgentError("invalid or missing 'cluster' parameter; expected a non-empty string", nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
instanceID, ok := paramsMap["instance"].(string)
|
instanceID, ok := paramsMap["instance"].(string)
|
||||||
if !ok || instanceID == "" {
|
if !ok || instanceID == "" {
|
||||||
return nil, fmt.Errorf("invalid or missing 'instance' parameter; expected a non-empty string")
|
return nil, util.NewAgentError("invalid or missing 'instance' parameter; expected a non-empty string", nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
instanceType, ok := paramsMap["instanceType"].(string)
|
instanceType, ok := paramsMap["instanceType"].(string)
|
||||||
if !ok || (instanceType != "READ_POOL" && instanceType != "PRIMARY") {
|
if !ok || (instanceType != "READ_POOL" && instanceType != "PRIMARY") {
|
||||||
return nil, fmt.Errorf("invalid 'instanceType' parameter; expected 'PRIMARY' or 'READ_POOL'")
|
return nil, util.NewAgentError("invalid 'instanceType' parameter; expected 'PRIMARY' or 'READ_POOL'", nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
displayName, _ := paramsMap["displayName"].(string)
|
displayName, _ := paramsMap["displayName"].(string)
|
||||||
@@ -161,11 +163,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
if instanceType == "READ_POOL" {
|
if instanceType == "READ_POOL" {
|
||||||
nodeCount, ok = paramsMap["nodeCount"].(int)
|
nodeCount, ok = paramsMap["nodeCount"].(int)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("invalid 'nodeCount' parameter; expected an integer for READ_POOL")
|
return nil, util.NewAgentError("invalid 'nodeCount' parameter; expected an integer for READ_POOL", nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return source.CreateInstance(ctx, project, location, cluster, instanceID, instanceType, displayName, nodeCount, string(accessToken))
|
resp, err := source.CreateInstance(ctx, project, location, cluster, instanceID, instanceType, displayName, nodeCount, string(accessToken))
|
||||||
|
if err != nil {
|
||||||
|
return nil, util.ProcessGcpError(err)
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
||||||
|
|||||||
@@ -17,11 +17,13 @@ package alloydbcreateuser
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
yaml "github.com/goccy/go-yaml"
|
yaml "github.com/goccy/go-yaml"
|
||||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/util"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -122,43 +124,43 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Invoke executes the tool's logic.
|
// Invoke executes the tool's logic.
|
||||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
paramsMap := params.AsMap()
|
paramsMap := params.AsMap()
|
||||||
project, ok := paramsMap["project"].(string)
|
project, ok := paramsMap["project"].(string)
|
||||||
if !ok || project == "" {
|
if !ok || project == "" {
|
||||||
return nil, fmt.Errorf("invalid or missing 'project' parameter; expected a non-empty string")
|
return nil, util.NewAgentError("invalid or missing 'project' parameter; expected a non-empty string", nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
location, ok := paramsMap["location"].(string)
|
location, ok := paramsMap["location"].(string)
|
||||||
if !ok || location == "" {
|
if !ok || location == "" {
|
||||||
return nil, fmt.Errorf("invalid or missing'location' parameter; expected a non-empty string")
|
return nil, util.NewAgentError("invalid or missing'location' parameter; expected a non-empty string", nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
cluster, ok := paramsMap["cluster"].(string)
|
cluster, ok := paramsMap["cluster"].(string)
|
||||||
if !ok || cluster == "" {
|
if !ok || cluster == "" {
|
||||||
return nil, fmt.Errorf("invalid or missing 'cluster' parameter; expected a non-empty string")
|
return nil, util.NewAgentError("invalid or missing 'cluster' parameter; expected a non-empty string", nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
userID, ok := paramsMap["user"].(string)
|
userID, ok := paramsMap["user"].(string)
|
||||||
if !ok || userID == "" {
|
if !ok || userID == "" {
|
||||||
return nil, fmt.Errorf("invalid or missing 'user' parameter; expected a non-empty string")
|
return nil, util.NewAgentError("invalid or missing 'user' parameter; expected a non-empty string", nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
userType, ok := paramsMap["userType"].(string)
|
userType, ok := paramsMap["userType"].(string)
|
||||||
if !ok || (userType != "ALLOYDB_BUILT_IN" && userType != "ALLOYDB_IAM_USER") {
|
if !ok || (userType != "ALLOYDB_BUILT_IN" && userType != "ALLOYDB_IAM_USER") {
|
||||||
return nil, fmt.Errorf("invalid or missing 'userType' parameter; expected 'ALLOYDB_BUILT_IN' or 'ALLOYDB_IAM_USER'")
|
return nil, util.NewAgentError("invalid or missing 'userType' parameter; expected 'ALLOYDB_BUILT_IN' or 'ALLOYDB_IAM_USER'", nil)
|
||||||
}
|
}
|
||||||
var password string
|
var password string
|
||||||
|
|
||||||
if userType == "ALLOYDB_BUILT_IN" {
|
if userType == "ALLOYDB_BUILT_IN" {
|
||||||
password, ok = paramsMap["password"].(string)
|
password, ok = paramsMap["password"].(string)
|
||||||
if !ok || password == "" {
|
if !ok || password == "" {
|
||||||
return nil, fmt.Errorf("password is required when userType is ALLOYDB_BUILT_IN")
|
return nil, util.NewAgentError("password is required when userType is ALLOYDB_BUILT_IN", nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -170,7 +172,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return source.CreateUser(ctx, userType, password, roles, string(accessToken), project, location, cluster, userID)
|
resp, err := source.CreateUser(ctx, userType, password, roles, string(accessToken), project, location, cluster, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, util.ProcessGcpError(err)
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
||||||
|
|||||||
@@ -17,11 +17,13 @@ package alloydbgetcluster
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
yaml "github.com/goccy/go-yaml"
|
yaml "github.com/goccy/go-yaml"
|
||||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/util"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -120,28 +122,32 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Invoke executes the tool's logic.
|
// Invoke executes the tool's logic.
|
||||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
paramsMap := params.AsMap()
|
paramsMap := params.AsMap()
|
||||||
|
|
||||||
project, ok := paramsMap["project"].(string)
|
project, ok := paramsMap["project"].(string)
|
||||||
if !ok {
|
if !ok || project == "" {
|
||||||
return nil, fmt.Errorf("invalid or missing 'project' parameter; expected a string")
|
return nil, util.NewAgentError("invalid or missing 'project' parameter; expected a string", nil)
|
||||||
}
|
}
|
||||||
location, ok := paramsMap["location"].(string)
|
location, ok := paramsMap["location"].(string)
|
||||||
if !ok {
|
if !ok || location == "" {
|
||||||
return nil, fmt.Errorf("invalid 'location' parameter; expected a string")
|
return nil, util.NewAgentError("invalid or missing 'location' parameter; expected a string", nil)
|
||||||
}
|
}
|
||||||
cluster, ok := paramsMap["cluster"].(string)
|
cluster, ok := paramsMap["cluster"].(string)
|
||||||
if !ok {
|
if !ok || cluster == "" {
|
||||||
return nil, fmt.Errorf("invalid 'cluster' parameter; expected a string")
|
return nil, util.NewAgentError("invalid or missing 'cluster' parameter; expected a string", nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
return source.GetCluster(ctx, project, location, cluster, string(accessToken))
|
resp, err := source.GetCluster(ctx, project, location, cluster, string(accessToken))
|
||||||
|
if err != nil {
|
||||||
|
return nil, util.ProcessGcpError(err)
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
||||||
|
|||||||
@@ -17,11 +17,13 @@ package alloydbgetinstance
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
yaml "github.com/goccy/go-yaml"
|
yaml "github.com/goccy/go-yaml"
|
||||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/util"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -120,32 +122,36 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Invoke executes the tool's logic.
|
// Invoke executes the tool's logic.
|
||||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
paramsMap := params.AsMap()
|
paramsMap := params.AsMap()
|
||||||
|
|
||||||
project, ok := paramsMap["project"].(string)
|
project, ok := paramsMap["project"].(string)
|
||||||
if !ok {
|
if !ok || project == "" {
|
||||||
return nil, fmt.Errorf("invalid or missing 'project' parameter; expected a string")
|
return nil, util.NewAgentError("invalid or missing 'project' parameter; expected a string", nil)
|
||||||
}
|
}
|
||||||
location, ok := paramsMap["location"].(string)
|
location, ok := paramsMap["location"].(string)
|
||||||
if !ok {
|
if !ok || location == "" {
|
||||||
return nil, fmt.Errorf("invalid 'location' parameter; expected a string")
|
return nil, util.NewAgentError("invalid or missing 'location' parameter; expected a string", nil)
|
||||||
}
|
}
|
||||||
cluster, ok := paramsMap["cluster"].(string)
|
cluster, ok := paramsMap["cluster"].(string)
|
||||||
if !ok {
|
if !ok || cluster == "" {
|
||||||
return nil, fmt.Errorf("invalid 'cluster' parameter; expected a string")
|
return nil, util.NewAgentError("invalid or missing 'cluster' parameter; expected a string", nil)
|
||||||
}
|
}
|
||||||
instance, ok := paramsMap["instance"].(string)
|
instance, ok := paramsMap["instance"].(string)
|
||||||
if !ok {
|
if !ok || instance == "" {
|
||||||
return nil, fmt.Errorf("invalid 'instance' parameter; expected a string")
|
return nil, util.NewAgentError("invalid or missing 'instance' parameter; expected a string", nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
return source.GetInstance(ctx, project, location, cluster, instance, string(accessToken))
|
resp, err := source.GetInstance(ctx, project, location, cluster, instance, string(accessToken))
|
||||||
|
if err != nil {
|
||||||
|
return nil, util.ProcessGcpError(err)
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
||||||
|
|||||||
@@ -17,11 +17,13 @@ package alloydbgetuser
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
yaml "github.com/goccy/go-yaml"
|
yaml "github.com/goccy/go-yaml"
|
||||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/util"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -120,32 +122,36 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Invoke executes the tool's logic.
|
// Invoke executes the tool's logic.
|
||||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
paramsMap := params.AsMap()
|
paramsMap := params.AsMap()
|
||||||
|
|
||||||
project, ok := paramsMap["project"].(string)
|
project, ok := paramsMap["project"].(string)
|
||||||
if !ok {
|
if !ok || project == "" {
|
||||||
return nil, fmt.Errorf("invalid or missing 'project' parameter; expected a string")
|
return nil, util.NewAgentError("invalid or missing 'project' parameter; expected a string", nil)
|
||||||
}
|
}
|
||||||
location, ok := paramsMap["location"].(string)
|
location, ok := paramsMap["location"].(string)
|
||||||
if !ok {
|
if !ok || location == "" {
|
||||||
return nil, fmt.Errorf("invalid 'location' parameter; expected a string")
|
return nil, util.NewAgentError("invalid or missing 'location' parameter; expected a string", nil)
|
||||||
}
|
}
|
||||||
cluster, ok := paramsMap["cluster"].(string)
|
cluster, ok := paramsMap["cluster"].(string)
|
||||||
if !ok {
|
if !ok || cluster == "" {
|
||||||
return nil, fmt.Errorf("invalid 'cluster' parameter; expected a string")
|
return nil, util.NewAgentError("invalid or missing 'cluster' parameter; expected a string", nil)
|
||||||
}
|
}
|
||||||
user, ok := paramsMap["user"].(string)
|
user, ok := paramsMap["user"].(string)
|
||||||
if !ok {
|
if !ok || user == "" {
|
||||||
return nil, fmt.Errorf("invalid 'user' parameter; expected a string")
|
return nil, util.NewAgentError("invalid or missing 'user' parameter; expected a string", nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
return source.GetUsers(ctx, project, location, cluster, user, string(accessToken))
|
resp, err := source.GetUsers(ctx, project, location, cluster, user, string(accessToken))
|
||||||
|
if err != nil {
|
||||||
|
return nil, util.ProcessGcpError(err)
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
||||||
|
|||||||
@@ -17,11 +17,13 @@ package alloydblistclusters
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
yaml "github.com/goccy/go-yaml"
|
yaml "github.com/goccy/go-yaml"
|
||||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/util"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -118,24 +120,28 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Invoke executes the tool's logic.
|
// Invoke executes the tool's logic.
|
||||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
paramsMap := params.AsMap()
|
paramsMap := params.AsMap()
|
||||||
|
|
||||||
project, ok := paramsMap["project"].(string)
|
project, ok := paramsMap["project"].(string)
|
||||||
if !ok {
|
if !ok || project == "" {
|
||||||
return nil, fmt.Errorf("invalid or missing 'project' parameter; expected a string")
|
return nil, util.NewAgentError("invalid or missing 'project' parameter; expected a string", nil)
|
||||||
}
|
}
|
||||||
location, ok := paramsMap["location"].(string)
|
location, ok := paramsMap["location"].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("invalid 'location' parameter; expected a string")
|
return nil, util.NewAgentError("invalid 'location' parameter; expected a string", nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
return source.ListCluster(ctx, project, location, string(accessToken))
|
resp, err := source.ListCluster(ctx, project, location, string(accessToken))
|
||||||
|
if err != nil {
|
||||||
|
return nil, util.ProcessGcpError(err)
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
||||||
|
|||||||
@@ -17,11 +17,13 @@ package alloydblistinstances
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
yaml "github.com/goccy/go-yaml"
|
yaml "github.com/goccy/go-yaml"
|
||||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/util"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -119,28 +121,32 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Invoke executes the tool's logic.
|
// Invoke executes the tool's logic.
|
||||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
paramsMap := params.AsMap()
|
paramsMap := params.AsMap()
|
||||||
|
|
||||||
project, ok := paramsMap["project"].(string)
|
project, ok := paramsMap["project"].(string)
|
||||||
if !ok {
|
if !ok || project == "" {
|
||||||
return nil, fmt.Errorf("invalid or missing 'project' parameter; expected a string")
|
return nil, util.NewAgentError("invalid or missing 'project' parameter; expected a string", nil)
|
||||||
}
|
}
|
||||||
location, ok := paramsMap["location"].(string)
|
location, ok := paramsMap["location"].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("invalid 'location' parameter; expected a string")
|
return nil, util.NewAgentError("invalid 'location' parameter; expected a string", nil)
|
||||||
}
|
}
|
||||||
cluster, ok := paramsMap["cluster"].(string)
|
cluster, ok := paramsMap["cluster"].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("invalid 'cluster' parameter; expected a string")
|
return nil, util.NewAgentError("invalid 'cluster' parameter; expected a string", nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
return source.ListInstance(ctx, project, location, cluster, string(accessToken))
|
resp, err := source.ListInstance(ctx, project, location, cluster, string(accessToken))
|
||||||
|
if err != nil {
|
||||||
|
return nil, util.ProcessGcpError(err)
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
||||||
|
|||||||
@@ -17,11 +17,13 @@ package alloydblistusers
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
yaml "github.com/goccy/go-yaml"
|
yaml "github.com/goccy/go-yaml"
|
||||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/util"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -119,28 +121,32 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Invoke executes the tool's logic.
|
// Invoke executes the tool's logic.
|
||||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
paramsMap := params.AsMap()
|
paramsMap := params.AsMap()
|
||||||
|
|
||||||
project, ok := paramsMap["project"].(string)
|
project, ok := paramsMap["project"].(string)
|
||||||
if !ok {
|
if !ok || project == "" {
|
||||||
return nil, fmt.Errorf("invalid or missing 'project' parameter; expected a string")
|
return nil, util.NewAgentError("invalid or missing 'project' parameter; expected a string", nil)
|
||||||
}
|
}
|
||||||
location, ok := paramsMap["location"].(string)
|
location, ok := paramsMap["location"].(string)
|
||||||
if !ok {
|
if !ok || location == "" {
|
||||||
return nil, fmt.Errorf("invalid 'location' parameter; expected a string")
|
return nil, util.NewAgentError("invalid or missing 'location' parameter; expected a string", nil)
|
||||||
}
|
}
|
||||||
cluster, ok := paramsMap["cluster"].(string)
|
cluster, ok := paramsMap["cluster"].(string)
|
||||||
if !ok {
|
if !ok || cluster == "" {
|
||||||
return nil, fmt.Errorf("invalid 'cluster' parameter; expected a string")
|
return nil, util.NewAgentError("invalid or missing 'cluster' parameter; expected a string", nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
return source.ListUsers(ctx, project, location, cluster, string(accessToken))
|
resp, err := source.ListUsers(ctx, project, location, cluster, string(accessToken))
|
||||||
|
if err != nil {
|
||||||
|
return nil, util.ProcessGcpError(err)
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ import (
|
|||||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/util"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -213,25 +214,25 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Invoke executes the tool's logic.
|
// Invoke executes the tool's logic.
|
||||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
paramsMap := params.AsMap()
|
paramsMap := params.AsMap()
|
||||||
|
|
||||||
project, ok := paramsMap["project"].(string)
|
project, ok := paramsMap["project"].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("missing 'project' parameter")
|
return nil, util.NewAgentError("missing 'project' parameter", nil)
|
||||||
}
|
}
|
||||||
location, ok := paramsMap["location"].(string)
|
location, ok := paramsMap["location"].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("missing 'location' parameter")
|
return nil, util.NewAgentError("missing 'location' parameter", nil)
|
||||||
}
|
}
|
||||||
operation, ok := paramsMap["operation"].(string)
|
operation, ok := paramsMap["operation"].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("missing 'operation' parameter")
|
return nil, util.NewAgentError("missing 'operation' parameter", nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Minute)
|
ctx, cancel := context.WithTimeout(ctx, 30*time.Minute)
|
||||||
@@ -246,14 +247,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
for retries < maxRetries {
|
for retries < maxRetries {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return nil, fmt.Errorf("timed out waiting for operation: %w", ctx.Err())
|
return nil, util.NewAgentError("timed out waiting for operation", ctx.Err())
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
|
|
||||||
op, err := source.GetOperations(ctx, project, location, operation, alloyDBConnectionMessageTemplate, delay, string(accessToken))
|
op, err := source.GetOperations(ctx, project, location, operation, alloyDBConnectionMessageTemplate, delay, string(accessToken))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, util.ProcessGeneralError(err)
|
||||||
} else if op != nil {
|
}
|
||||||
|
if op != nil {
|
||||||
return op, nil
|
return op, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -264,7 +266,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
}
|
}
|
||||||
retries++
|
retries++
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("exceeded max retries waiting for operation")
|
return nil, util.NewAgentError("exceeded max retries waiting for operation", nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
||||||
|
|||||||
@@ -17,12 +17,14 @@ package alloydbainl
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
yaml "github.com/goccy/go-yaml"
|
yaml "github.com/goccy/go-yaml"
|
||||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/util"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
"github.com/jackc/pgx/v5/pgxpool"
|
"github.com/jackc/pgx/v5/pgxpool"
|
||||||
)
|
)
|
||||||
@@ -127,10 +129,10 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
|||||||
return t.Config
|
return t.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
sliceParams := params.AsSlice()
|
sliceParams := params.AsSlice()
|
||||||
@@ -143,7 +145,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
|
|
||||||
resp, err := source.RunSQL(ctx, t.Statement, allParamValues)
|
resp, err := source.RunSQL(ctx, t.Statement, allParamValues)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("%w. Query: %v , Values: %v. Toolbox v0.19.0+ is only compatible with AlloyDB AI NL v1.0.3+. Please ensure that you are using the latest AlloyDB AI NL extension", err, t.Statement, allParamValues)
|
return nil, util.NewClientServerError(fmt.Sprintf("error running SQL query: %v. Query: %v , Values: %v. Toolbox v0.19.0+ is only compatible with AlloyDB AI NL v1.0.3+. Please ensure that you are using the latest AlloyDB AI NL extension", err, t.Statement, allParamValues), http.StatusBadRequest, err)
|
||||||
}
|
}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ package bigqueryanalyzecontribution
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
bigqueryapi "cloud.google.com/go/bigquery"
|
bigqueryapi "cloud.google.com/go/bigquery"
|
||||||
@@ -27,6 +28,7 @@ import (
|
|||||||
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
|
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon"
|
bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/util"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
bigqueryrestapi "google.golang.org/api/bigquery/v2"
|
bigqueryrestapi "google.golang.org/api/bigquery/v2"
|
||||||
)
|
)
|
||||||
@@ -154,21 +156,21 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Invoke runs the contribution analysis.
|
// Invoke runs the contribution analysis.
|
||||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
paramsMap := params.AsMap()
|
paramsMap := params.AsMap()
|
||||||
inputData, ok := paramsMap["input_data"].(string)
|
inputData, ok := paramsMap["input_data"].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("unable to cast input_data parameter %s", paramsMap["input_data"])
|
return nil, util.NewAgentError(fmt.Sprintf("unable to cast input_data parameter %s", paramsMap["input_data"]), nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
bqClient, restService, err := source.RetrieveClientAndService(accessToken)
|
bqClient, restService, err := source.RetrieveClientAndService(accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, util.NewClientServerError("failed to retrieve BigQuery client", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
modelID := fmt.Sprintf("contribution_analysis_model_%s", strings.ReplaceAll(uuid.New().String(), "-", ""))
|
modelID := fmt.Sprintf("contribution_analysis_model_%s", strings.ReplaceAll(uuid.New().String(), "-", ""))
|
||||||
@@ -186,7 +188,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
}
|
}
|
||||||
options = append(options, fmt.Sprintf("DIMENSION_ID_COLS = [%s]", strings.Join(strCols, ", ")))
|
options = append(options, fmt.Sprintf("DIMENSION_ID_COLS = [%s]", strings.Join(strCols, ", ")))
|
||||||
} else {
|
} else {
|
||||||
return nil, fmt.Errorf("unable to cast dimension_id_cols parameter %s", paramsMap["dimension_id_cols"])
|
return nil, util.NewAgentError(fmt.Sprintf("unable to cast dimension_id_cols parameter %s", paramsMap["dimension_id_cols"]), nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if val, ok := paramsMap["top_k_insights_by_apriori_support"]; ok {
|
if val, ok := paramsMap["top_k_insights_by_apriori_support"]; ok {
|
||||||
@@ -195,7 +197,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
if val, ok := paramsMap["pruning_method"].(string); ok {
|
if val, ok := paramsMap["pruning_method"].(string); ok {
|
||||||
upperVal := strings.ToUpper(val)
|
upperVal := strings.ToUpper(val)
|
||||||
if upperVal != "NO_PRUNING" && upperVal != "PRUNE_REDUNDANT_INSIGHTS" {
|
if upperVal != "NO_PRUNING" && upperVal != "PRUNE_REDUNDANT_INSIGHTS" {
|
||||||
return nil, fmt.Errorf("invalid pruning_method: %s", val)
|
return nil, util.NewAgentError(fmt.Sprintf("invalid pruning_method: %s", val), nil)
|
||||||
}
|
}
|
||||||
options = append(options, fmt.Sprintf("PRUNING_METHOD = '%s'", upperVal))
|
options = append(options, fmt.Sprintf("PRUNING_METHOD = '%s'", upperVal))
|
||||||
}
|
}
|
||||||
@@ -207,7 +209,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
var connProps []*bigqueryapi.ConnectionProperty
|
var connProps []*bigqueryapi.ConnectionProperty
|
||||||
session, err := source.BigQuerySession()(ctx)
|
session, err := source.BigQuerySession()(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
return nil, util.NewClientServerError("failed to get BigQuery session", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
if session != nil {
|
if session != nil {
|
||||||
connProps = []*bigqueryapi.ConnectionProperty{
|
connProps = []*bigqueryapi.ConnectionProperty{
|
||||||
@@ -216,22 +218,22 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
}
|
}
|
||||||
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, inputData, nil, connProps)
|
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, inputData, nil, connProps)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("query validation failed: %w", err)
|
return nil, util.ProcessGcpError(err)
|
||||||
}
|
}
|
||||||
statementType := dryRunJob.Statistics.Query.StatementType
|
statementType := dryRunJob.Statistics.Query.StatementType
|
||||||
if statementType != "SELECT" {
|
if statementType != "SELECT" {
|
||||||
return nil, fmt.Errorf("the 'input_data' parameter only supports a table ID or a SELECT query. The provided query has statement type '%s'", statementType)
|
return nil, util.NewAgentError(fmt.Sprintf("the 'input_data' parameter only supports a table ID or a SELECT query. The provided query has statement type '%s'", statementType), nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
queryStats := dryRunJob.Statistics.Query
|
queryStats := dryRunJob.Statistics.Query
|
||||||
if queryStats != nil {
|
if queryStats != nil {
|
||||||
for _, tableRef := range queryStats.ReferencedTables {
|
for _, tableRef := range queryStats.ReferencedTables {
|
||||||
if !source.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) {
|
if !source.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) {
|
||||||
return nil, fmt.Errorf("query in input_data accesses dataset '%s.%s', which is not in the allowed list", tableRef.ProjectId, tableRef.DatasetId)
|
return nil, util.NewAgentError(fmt.Sprintf("query in input_data accesses dataset '%s.%s', which is not in the allowed list", tableRef.ProjectId, tableRef.DatasetId), nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return nil, fmt.Errorf("could not analyze query in input_data to validate against allowed datasets")
|
return nil, util.NewAgentError("could not analyze query in input_data to validate against allowed datasets", nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
inputDataSource = fmt.Sprintf("(%s)", inputData)
|
inputDataSource = fmt.Sprintf("(%s)", inputData)
|
||||||
@@ -245,10 +247,10 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
case 2: // dataset.table
|
case 2: // dataset.table
|
||||||
projectID, datasetID = source.BigQueryClient().Project(), parts[0]
|
projectID, datasetID = source.BigQueryClient().Project(), parts[0]
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("invalid table ID format for 'input_data': %q. Expected 'dataset.table' or 'project.dataset.table'", inputData)
|
return nil, util.NewAgentError(fmt.Sprintf("invalid table ID format for 'input_data': %q. Expected 'dataset.table' or 'project.dataset.table'", inputData), nil)
|
||||||
}
|
}
|
||||||
if !source.IsDatasetAllowed(projectID, datasetID) {
|
if !source.IsDatasetAllowed(projectID, datasetID) {
|
||||||
return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", projectID, datasetID, inputData)
|
return nil, util.NewAgentError(fmt.Sprintf("access to dataset '%s.%s' (from table '%s') is not allowed", projectID, datasetID, inputData), nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
inputDataSource = fmt.Sprintf("SELECT * FROM `%s`", inputData)
|
inputDataSource = fmt.Sprintf("SELECT * FROM `%s`", inputData)
|
||||||
@@ -268,7 +270,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
// Otherwise, a new session will be created by the first query.
|
// Otherwise, a new session will be created by the first query.
|
||||||
session, err := source.BigQuerySession()(ctx)
|
session, err := source.BigQuerySession()(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
return nil, util.NewClientServerError("failed to get BigQuery session", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if session != nil {
|
if session != nil {
|
||||||
@@ -281,15 +283,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
}
|
}
|
||||||
createModelJob, err := createModelQuery.Run(ctx)
|
createModelJob, err := createModelQuery.Run(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to start create model job: %w", err)
|
return nil, util.ProcessGcpError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
status, err := createModelJob.Wait(ctx)
|
status, err := createModelJob.Wait(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to wait for create model job: %w", err)
|
return nil, util.ProcessGcpError(err)
|
||||||
}
|
}
|
||||||
if err := status.Err(); err != nil {
|
if err := status.Err(); err != nil {
|
||||||
return nil, fmt.Errorf("create model job failed: %w", err)
|
return nil, util.ProcessGcpError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Determine the session ID to use for subsequent queries.
|
// Determine the session ID to use for subsequent queries.
|
||||||
@@ -300,12 +302,17 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
} else if status.Statistics != nil && status.Statistics.SessionInfo != nil {
|
} else if status.Statistics != nil && status.Statistics.SessionInfo != nil {
|
||||||
sessionID = status.Statistics.SessionInfo.SessionID
|
sessionID = status.Statistics.SessionInfo.SessionID
|
||||||
} else {
|
} else {
|
||||||
return nil, fmt.Errorf("failed to get or create a BigQuery session ID")
|
return nil, util.NewClientServerError("failed to get or create a BigQuery session ID", http.StatusInternalServerError, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
getInsightsSQL := fmt.Sprintf("SELECT * FROM ML.GET_INSIGHTS(MODEL %s)", modelID)
|
getInsightsSQL := fmt.Sprintf("SELECT * FROM ML.GET_INSIGHTS(MODEL %s)", modelID)
|
||||||
connProps := []*bigqueryapi.ConnectionProperty{{Key: "session_id", Value: sessionID}}
|
connProps := []*bigqueryapi.ConnectionProperty{{Key: "session_id", Value: sessionID}}
|
||||||
return source.RunSQL(ctx, bqClient, getInsightsSQL, "SELECT", nil, connProps)
|
|
||||||
|
resp, err := source.RunSQL(ctx, bqClient, getInsightsSQL, "SELECT", nil, connProps)
|
||||||
|
if err != nil {
|
||||||
|
return nil, util.ProcessGcpError(err)
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
||||||
|
|||||||
@@ -35,6 +35,8 @@ import (
|
|||||||
|
|
||||||
const resourceType string = "bigquery-conversational-analytics"
|
const resourceType string = "bigquery-conversational-analytics"
|
||||||
|
|
||||||
|
const gdaURLFormat = "https://geminidataanalytics.googleapis.com/v1beta/projects/%s/locations/%s:chat"
|
||||||
|
|
||||||
const instructions = `**INSTRUCTIONS - FOLLOW THESE RULES:**
|
const instructions = `**INSTRUCTIONS - FOLLOW THESE RULES:**
|
||||||
1. **CONTENT:** Your answer should present the supporting data and then provide a conclusion based on that data.
|
1. **CONTENT:** Your answer should present the supporting data and then provide a conclusion based on that data.
|
||||||
2. **OUTPUT FORMAT:** Your entire response MUST be in plain text format ONLY.
|
2. **OUTPUT FORMAT:** Your entire response MUST be in plain text format ONLY.
|
||||||
@@ -172,10 +174,10 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
|||||||
return t.Config
|
return t.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var tokenStr string
|
var tokenStr string
|
||||||
@@ -184,26 +186,26 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
if source.UseClientAuthorization() {
|
if source.UseClientAuthorization() {
|
||||||
// Use client-side access token
|
// Use client-side access token
|
||||||
if accessToken == "" {
|
if accessToken == "" {
|
||||||
return nil, fmt.Errorf("tool is configured for client OAuth but no token was provided in the request header: %w", util.ErrUnauthorized)
|
return nil, util.NewClientServerError("tool is configured for client OAuth but no token was provided in the request header", http.StatusUnauthorized, nil)
|
||||||
}
|
}
|
||||||
tokenStr, err = accessToken.ParseBearerToken()
|
tokenStr, err = accessToken.ParseBearerToken()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
return nil, util.NewClientServerError("error parsing access token", http.StatusUnauthorized, err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Get a token source for the Gemini Data Analytics API.
|
// Get a token source for the Gemini Data Analytics API.
|
||||||
tokenSource, err := source.BigQueryTokenSourceWithScope(ctx, nil)
|
tokenSource, err := source.BigQueryTokenSourceWithScope(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get token source: %w", err)
|
return nil, util.NewClientServerError("failed to get token source", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use cloud-platform token source for Gemini Data Analytics API
|
// Use cloud-platform token source for Gemini Data Analytics API
|
||||||
if tokenSource == nil {
|
if tokenSource == nil {
|
||||||
return nil, fmt.Errorf("cloud-platform token source is missing")
|
return nil, util.NewClientServerError("cloud-platform token source is missing", http.StatusInternalServerError, nil)
|
||||||
}
|
}
|
||||||
token, err := tokenSource.Token()
|
token, err := tokenSource.Token()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get token from cloud-platform token source: %w", err)
|
return nil, util.NewClientServerError("failed to get token from cloud-platform token source", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
tokenStr = token.AccessToken
|
tokenStr = token.AccessToken
|
||||||
}
|
}
|
||||||
@@ -218,14 +220,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
var tableRefs []BQTableReference
|
var tableRefs []BQTableReference
|
||||||
if tableRefsJSON != "" {
|
if tableRefsJSON != "" {
|
||||||
if err := json.Unmarshal([]byte(tableRefsJSON), &tableRefs); err != nil {
|
if err := json.Unmarshal([]byte(tableRefsJSON), &tableRefs); err != nil {
|
||||||
return nil, fmt.Errorf("failed to parse 'table_references' JSON string: %w", err)
|
return nil, util.NewAgentError("failed to parse 'table_references' JSON string", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(source.BigQueryAllowedDatasets()) > 0 {
|
if len(source.BigQueryAllowedDatasets()) > 0 {
|
||||||
for _, tableRef := range tableRefs {
|
for _, tableRef := range tableRefs {
|
||||||
if !source.IsDatasetAllowed(tableRef.ProjectID, tableRef.DatasetID) {
|
if !source.IsDatasetAllowed(tableRef.ProjectID, tableRef.DatasetID) {
|
||||||
return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", tableRef.ProjectID, tableRef.DatasetID, tableRef.TableID)
|
return nil, util.NewAgentError(fmt.Sprintf("access to dataset '%s.%s' (from table '%s') is not allowed", tableRef.ProjectID, tableRef.DatasetID, tableRef.TableID), nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -236,11 +238,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
if location == "" {
|
if location == "" {
|
||||||
location = "us"
|
location = "us"
|
||||||
}
|
}
|
||||||
caURL := fmt.Sprintf("https://geminidataanalytics.googleapis.com/v1alpha/projects/%s/locations/%s:chat", projectID, location)
|
caURL := fmt.Sprintf(gdaURLFormat, projectID, location)
|
||||||
|
|
||||||
headers := map[string]string{
|
headers := map[string]string{
|
||||||
"Authorization": fmt.Sprintf("Bearer %s", tokenStr),
|
"Authorization": fmt.Sprintf("Bearer %s", tokenStr),
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
|
"X-Goog-API-Client": util.GDAClientID,
|
||||||
}
|
}
|
||||||
|
|
||||||
payload := CAPayload{
|
payload := CAPayload{
|
||||||
@@ -252,13 +255,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
},
|
},
|
||||||
Options: Options{Chart: ChartOptions{Image: ImageOptions{NoImage: map[string]any{}}}},
|
Options: Options{Chart: ChartOptions{Image: ImageOptions{NoImage: map[string]any{}}}},
|
||||||
},
|
},
|
||||||
ClientIdEnum: "GENAI_TOOLBOX",
|
ClientIdEnum: util.GDAClientID,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Call the streaming API
|
// Call the streaming API
|
||||||
response, err := getStream(caURL, payload, headers, source.GetMaxQueryResultRows())
|
response, err := getStream(caURL, payload, headers, source.GetMaxQueryResultRows())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get response from conversational analytics API: %w", err)
|
// getStream wraps network errors or non-200 responses
|
||||||
|
return nil, util.NewClientServerError("failed to get response from conversational analytics API", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return response, nil
|
return response, nil
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
bigqueryapi "cloud.google.com/go/bigquery"
|
bigqueryapi "cloud.google.com/go/bigquery"
|
||||||
@@ -152,25 +153,25 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
|||||||
return t.Config
|
return t.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
paramsMap := params.AsMap()
|
paramsMap := params.AsMap()
|
||||||
sql, ok := paramsMap["sql"].(string)
|
sql, ok := paramsMap["sql"].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("unable to cast sql parameter %s", paramsMap["sql"])
|
return nil, util.NewAgentError(fmt.Sprintf("unable to cast sql parameter %s", paramsMap["sql"]), nil)
|
||||||
}
|
}
|
||||||
dryRun, ok := paramsMap["dry_run"].(bool)
|
dryRun, ok := paramsMap["dry_run"].(bool)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("unable to cast dry_run parameter %s", paramsMap["dry_run"])
|
return nil, util.NewAgentError(fmt.Sprintf("unable to cast dry_run parameter %s", paramsMap["dry_run"]), nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
bqClient, restService, err := source.RetrieveClientAndService(accessToken)
|
bqClient, restService, err := source.RetrieveClientAndService(accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, util.NewClientServerError("failed to retrieve BigQuery client", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var connProps []*bigqueryapi.ConnectionProperty
|
var connProps []*bigqueryapi.ConnectionProperty
|
||||||
@@ -178,7 +179,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
if source.BigQueryWriteMode() == bigqueryds.WriteModeProtected {
|
if source.BigQueryWriteMode() == bigqueryds.WriteModeProtected {
|
||||||
session, err = source.BigQuerySession()(ctx)
|
session, err = source.BigQuerySession()(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get BigQuery session for protected mode: %w", err)
|
return nil, util.NewClientServerError("failed to get BigQuery session for protected mode", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
connProps = []*bigqueryapi.ConnectionProperty{
|
connProps = []*bigqueryapi.ConnectionProperty{
|
||||||
{Key: "session_id", Value: session.ID},
|
{Key: "session_id", Value: session.ID},
|
||||||
@@ -187,7 +188,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
|
|
||||||
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, sql, nil, connProps)
|
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, sql, nil, connProps)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("query validation failed: %w", err)
|
return nil, util.NewClientServerError("query validation failed", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
statementType := dryRunJob.Statistics.Query.StatementType
|
statementType := dryRunJob.Statistics.Query.StatementType
|
||||||
@@ -195,13 +196,13 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
switch source.BigQueryWriteMode() {
|
switch source.BigQueryWriteMode() {
|
||||||
case bigqueryds.WriteModeBlocked:
|
case bigqueryds.WriteModeBlocked:
|
||||||
if statementType != "SELECT" {
|
if statementType != "SELECT" {
|
||||||
return nil, fmt.Errorf("write mode is 'blocked', only SELECT statements are allowed")
|
return nil, util.NewAgentError("write mode is 'blocked', only SELECT statements are allowed", nil)
|
||||||
}
|
}
|
||||||
case bigqueryds.WriteModeProtected:
|
case bigqueryds.WriteModeProtected:
|
||||||
if dryRunJob.Configuration != nil && dryRunJob.Configuration.Query != nil {
|
if dryRunJob.Configuration != nil && dryRunJob.Configuration.Query != nil {
|
||||||
if dest := dryRunJob.Configuration.Query.DestinationTable; dest != nil && dest.DatasetId != session.DatasetID {
|
if dest := dryRunJob.Configuration.Query.DestinationTable; dest != nil && dest.DatasetId != session.DatasetID {
|
||||||
return nil, fmt.Errorf("protected write mode only supports SELECT statements, or write operations in the anonymous "+
|
return nil, util.NewAgentError(fmt.Sprintf("protected write mode only supports SELECT statements, or write operations in the anonymous "+
|
||||||
"dataset of a BigQuery session, but destination was %q", dest.DatasetId)
|
"dataset of a BigQuery session, but destination was %q", dest.DatasetId), nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -209,11 +210,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
if len(source.BigQueryAllowedDatasets()) > 0 {
|
if len(source.BigQueryAllowedDatasets()) > 0 {
|
||||||
switch statementType {
|
switch statementType {
|
||||||
case "CREATE_SCHEMA", "DROP_SCHEMA", "ALTER_SCHEMA":
|
case "CREATE_SCHEMA", "DROP_SCHEMA", "ALTER_SCHEMA":
|
||||||
return nil, fmt.Errorf("dataset-level operations like '%s' are not allowed when dataset restrictions are in place", statementType)
|
return nil, util.NewAgentError(fmt.Sprintf("dataset-level operations like '%s' are not allowed when dataset restrictions are in place", statementType), nil)
|
||||||
case "CREATE_FUNCTION", "CREATE_TABLE_FUNCTION", "CREATE_PROCEDURE":
|
case "CREATE_FUNCTION", "CREATE_TABLE_FUNCTION", "CREATE_PROCEDURE":
|
||||||
return nil, fmt.Errorf("creating stored routines ('%s') is not allowed when dataset restrictions are in place, as their contents cannot be safely analyzed", statementType)
|
return nil, util.NewAgentError(fmt.Sprintf("creating stored routines ('%s') is not allowed when dataset restrictions are in place, as their contents cannot be safely analyzed", statementType), nil)
|
||||||
case "CALL":
|
case "CALL":
|
||||||
return nil, fmt.Errorf("calling stored procedures ('%s') is not allowed when dataset restrictions are in place, as their contents cannot be safely analyzed", statementType)
|
return nil, util.NewAgentError(fmt.Sprintf("calling stored procedures ('%s') is not allowed when dataset restrictions are in place, as their contents cannot be safely analyzed", statementType), nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use a map to avoid duplicate table names.
|
// Use a map to avoid duplicate table names.
|
||||||
@@ -244,7 +245,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
parsedTables, parseErr := bqutil.TableParser(sql, source.BigQueryClient().Project())
|
parsedTables, parseErr := bqutil.TableParser(sql, source.BigQueryClient().Project())
|
||||||
if parseErr != nil {
|
if parseErr != nil {
|
||||||
// If parsing fails (e.g., EXECUTE IMMEDIATE), we cannot guarantee safety, so we must fail.
|
// If parsing fails (e.g., EXECUTE IMMEDIATE), we cannot guarantee safety, so we must fail.
|
||||||
return nil, fmt.Errorf("could not parse tables from query to validate against allowed datasets: %w", parseErr)
|
return nil, util.NewAgentError("could not parse tables from query to validate against allowed datasets", parseErr)
|
||||||
}
|
}
|
||||||
tableNames = parsedTables
|
tableNames = parsedTables
|
||||||
}
|
}
|
||||||
@@ -254,7 +255,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
if len(parts) == 3 {
|
if len(parts) == 3 {
|
||||||
projectID, datasetID := parts[0], parts[1]
|
projectID, datasetID := parts[0], parts[1]
|
||||||
if !source.IsDatasetAllowed(projectID, datasetID) {
|
if !source.IsDatasetAllowed(projectID, datasetID) {
|
||||||
return nil, fmt.Errorf("query accesses dataset '%s.%s', which is not in the allowed list", projectID, datasetID)
|
return nil, util.NewAgentError(fmt.Sprintf("query accesses dataset '%s.%s', which is not in the allowed list", projectID, datasetID), nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -264,7 +265,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
if dryRunJob != nil {
|
if dryRunJob != nil {
|
||||||
jobJSON, err := json.MarshalIndent(dryRunJob, "", " ")
|
jobJSON, err := json.MarshalIndent(dryRunJob, "", " ")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to marshal dry run job to JSON: %w", err)
|
return nil, util.NewClientServerError("failed to marshal dry run job to JSON", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
return string(jobJSON), nil
|
return string(jobJSON), nil
|
||||||
}
|
}
|
||||||
@@ -275,10 +276,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
// Log the query executed for debugging.
|
// Log the query executed for debugging.
|
||||||
logger, err := util.LoggerFromContext(ctx)
|
logger, err := util.LoggerFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error getting logger: %s", err)
|
return nil, util.NewClientServerError("error getting logger", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", resourceType, sql))
|
logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", resourceType, sql))
|
||||||
return source.RunSQL(ctx, bqClient, sql, statementType, nil, connProps)
|
resp, err := source.RunSQL(ctx, bqClient, sql, statementType, nil, connProps)
|
||||||
|
if err != nil {
|
||||||
|
return nil, util.NewClientServerError("error running sql", http.StatusInternalServerError, err)
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ package bigqueryforecast
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
bigqueryapi "cloud.google.com/go/bigquery"
|
bigqueryapi "cloud.google.com/go/bigquery"
|
||||||
@@ -133,34 +134,34 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
|||||||
return t.Config
|
return t.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
paramsMap := params.AsMap()
|
paramsMap := params.AsMap()
|
||||||
historyData, ok := paramsMap["history_data"].(string)
|
historyData, ok := paramsMap["history_data"].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("unable to cast history_data parameter %v", paramsMap["history_data"])
|
return nil, util.NewAgentError(fmt.Sprintf("unable to cast history_data parameter %v", paramsMap["history_data"]), nil)
|
||||||
}
|
}
|
||||||
timestampCol, ok := paramsMap["timestamp_col"].(string)
|
timestampCol, ok := paramsMap["timestamp_col"].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("unable to cast timestamp_col parameter %v", paramsMap["timestamp_col"])
|
return nil, util.NewAgentError(fmt.Sprintf("unable to cast timestamp_col parameter %v", paramsMap["timestamp_col"]), nil)
|
||||||
}
|
}
|
||||||
dataCol, ok := paramsMap["data_col"].(string)
|
dataCol, ok := paramsMap["data_col"].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("unable to cast data_col parameter %v", paramsMap["data_col"])
|
return nil, util.NewAgentError(fmt.Sprintf("unable to cast data_col parameter %v", paramsMap["data_col"]), nil)
|
||||||
}
|
}
|
||||||
idColsRaw, ok := paramsMap["id_cols"].([]any)
|
idColsRaw, ok := paramsMap["id_cols"].([]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("unable to cast id_cols parameter %v", paramsMap["id_cols"])
|
return nil, util.NewAgentError(fmt.Sprintf("unable to cast id_cols parameter %v", paramsMap["id_cols"]), nil)
|
||||||
}
|
}
|
||||||
var idCols []string
|
var idCols []string
|
||||||
for _, v := range idColsRaw {
|
for _, v := range idColsRaw {
|
||||||
s, ok := v.(string)
|
s, ok := v.(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("id_cols contains non-string value: %v", v)
|
return nil, util.NewAgentError(fmt.Sprintf("id_cols contains non-string value: %v", v), nil)
|
||||||
}
|
}
|
||||||
idCols = append(idCols, s)
|
idCols = append(idCols, s)
|
||||||
}
|
}
|
||||||
@@ -169,13 +170,13 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
if h, ok := paramsMap["horizon"].(float64); ok {
|
if h, ok := paramsMap["horizon"].(float64); ok {
|
||||||
horizon = int(h)
|
horizon = int(h)
|
||||||
} else {
|
} else {
|
||||||
return nil, fmt.Errorf("unable to cast horizon parameter %v", paramsMap["horizon"])
|
return nil, util.NewAgentError(fmt.Sprintf("unable to cast horizon parameter %v", paramsMap["horizon"]), nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bqClient, restService, err := source.RetrieveClientAndService(accessToken)
|
bqClient, restService, err := source.RetrieveClientAndService(accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, util.NewClientServerError("failed to retrieve BigQuery client", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var historyDataSource string
|
var historyDataSource string
|
||||||
@@ -185,7 +186,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
var connProps []*bigqueryapi.ConnectionProperty
|
var connProps []*bigqueryapi.ConnectionProperty
|
||||||
session, err := source.BigQuerySession()(ctx)
|
session, err := source.BigQuerySession()(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
return nil, util.NewClientServerError("failed to get BigQuery session", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
if session != nil {
|
if session != nil {
|
||||||
connProps = []*bigqueryapi.ConnectionProperty{
|
connProps = []*bigqueryapi.ConnectionProperty{
|
||||||
@@ -194,22 +195,22 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
}
|
}
|
||||||
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, historyData, nil, connProps)
|
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, historyData, nil, connProps)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("query validation failed: %w", err)
|
return nil, util.ProcessGcpError(err)
|
||||||
}
|
}
|
||||||
statementType := dryRunJob.Statistics.Query.StatementType
|
statementType := dryRunJob.Statistics.Query.StatementType
|
||||||
if statementType != "SELECT" {
|
if statementType != "SELECT" {
|
||||||
return nil, fmt.Errorf("the 'history_data' parameter only supports a table ID or a SELECT query. The provided query has statement type '%s'", statementType)
|
return nil, util.NewAgentError(fmt.Sprintf("the 'history_data' parameter only supports a table ID or a SELECT query. The provided query has statement type '%s'", statementType), nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
queryStats := dryRunJob.Statistics.Query
|
queryStats := dryRunJob.Statistics.Query
|
||||||
if queryStats != nil {
|
if queryStats != nil {
|
||||||
for _, tableRef := range queryStats.ReferencedTables {
|
for _, tableRef := range queryStats.ReferencedTables {
|
||||||
if !source.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) {
|
if !source.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) {
|
||||||
return nil, fmt.Errorf("query in history_data accesses dataset '%s.%s', which is not in the allowed list", tableRef.ProjectId, tableRef.DatasetId)
|
return nil, util.NewAgentError(fmt.Sprintf("query in history_data accesses dataset '%s.%s', which is not in the allowed list", tableRef.ProjectId, tableRef.DatasetId), nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return nil, fmt.Errorf("could not analyze query in history_data to validate against allowed datasets")
|
return nil, util.NewAgentError("could not analyze query in history_data to validate against allowed datasets", nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
historyDataSource = fmt.Sprintf("(%s)", historyData)
|
historyDataSource = fmt.Sprintf("(%s)", historyData)
|
||||||
@@ -226,11 +227,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
projectID = source.BigQueryClient().Project()
|
projectID = source.BigQueryClient().Project()
|
||||||
datasetID = parts[0]
|
datasetID = parts[0]
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("invalid table ID format for 'history_data': %q. Expected 'dataset.table' or 'project.dataset.table'", historyData)
|
return nil, util.NewAgentError(fmt.Sprintf("invalid table ID format for 'history_data': %q. Expected 'dataset.table' or 'project.dataset.table'", historyData), nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !source.IsDatasetAllowed(projectID, datasetID) {
|
if !source.IsDatasetAllowed(projectID, datasetID) {
|
||||||
return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", projectID, datasetID, historyData)
|
return nil, util.NewAgentError(fmt.Sprintf("access to dataset '%s.%s' (from table '%s') is not allowed", projectID, datasetID, historyData), nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
historyDataSource = fmt.Sprintf("TABLE `%s`", historyData)
|
historyDataSource = fmt.Sprintf("TABLE `%s`", historyData)
|
||||||
@@ -251,7 +252,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
|
|
||||||
session, err := source.BigQuerySession()(ctx)
|
session, err := source.BigQuerySession()(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
return nil, util.NewClientServerError("failed to get BigQuery session", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
var connProps []*bigqueryapi.ConnectionProperty
|
var connProps []*bigqueryapi.ConnectionProperty
|
||||||
if session != nil {
|
if session != nil {
|
||||||
@@ -264,11 +265,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
// Log the query executed for debugging.
|
// Log the query executed for debugging.
|
||||||
logger, err := util.LoggerFromContext(ctx)
|
logger, err := util.LoggerFromContext(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error getting logger: %s", err)
|
return nil, util.NewClientServerError("error getting logger", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", resourceType, sql))
|
logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", resourceType, sql))
|
||||||
|
|
||||||
return source.RunSQL(ctx, bqClient, sql, "SELECT", nil, connProps)
|
resp, err := source.RunSQL(ctx, bqClient, sql, "SELECT", nil, connProps)
|
||||||
|
if err != nil {
|
||||||
|
return nil, util.ProcessGcpError(err)
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ package bigquerygetdatasetinfo
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
bigqueryapi "cloud.google.com/go/bigquery"
|
bigqueryapi "cloud.google.com/go/bigquery"
|
||||||
yaml "github.com/goccy/go-yaml"
|
yaml "github.com/goccy/go-yaml"
|
||||||
@@ -24,6 +25,7 @@ import (
|
|||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon"
|
bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/util"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
bigqueryrestapi "google.golang.org/api/bigquery/v2"
|
bigqueryrestapi "google.golang.org/api/bigquery/v2"
|
||||||
)
|
)
|
||||||
@@ -120,38 +122,38 @@ type Tool struct {
|
|||||||
func (t Tool) ToConfig() tools.ToolConfig {
|
func (t Tool) ToConfig() tools.ToolConfig {
|
||||||
return t.Config
|
return t.Config
|
||||||
}
|
}
|
||||||
|
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
|
||||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
mapParams := params.AsMap()
|
mapParams := params.AsMap()
|
||||||
projectId, ok := mapParams[projectKey].(string)
|
projectId, ok := mapParams[projectKey].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", projectKey)
|
// Updated: Use fmt.Sprintf for formatting, pass nil as cause
|
||||||
|
return nil, util.NewAgentError(fmt.Sprintf("invalid or missing '%s' parameter; expected a string", projectKey), nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
datasetId, ok := mapParams[datasetKey].(string)
|
datasetId, ok := mapParams[datasetKey].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", datasetKey)
|
return nil, util.NewAgentError(fmt.Sprintf("invalid or missing '%s' parameter; expected a string", datasetKey), nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
bqClient, _, err := source.RetrieveClientAndService(accessToken)
|
bqClient, _, err := source.RetrieveClientAndService(accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, util.NewClientServerError("failed to retrieve BigQuery client", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !source.IsDatasetAllowed(projectId, datasetId) {
|
if !source.IsDatasetAllowed(projectId, datasetId) {
|
||||||
return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId)
|
return nil, util.NewAgentError(fmt.Sprintf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId), nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
dsHandle := bqClient.DatasetInProject(projectId, datasetId)
|
dsHandle := bqClient.DatasetInProject(projectId, datasetId)
|
||||||
|
|
||||||
metadata, err := dsHandle.Metadata(ctx)
|
metadata, err := dsHandle.Metadata(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get metadata for dataset %s (in project %s): %w", datasetId, projectId, err)
|
return nil, util.ProcessGcpError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return metadata, nil
|
return metadata, nil
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ package bigquerygettableinfo
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
bigqueryapi "cloud.google.com/go/bigquery"
|
bigqueryapi "cloud.google.com/go/bigquery"
|
||||||
yaml "github.com/goccy/go-yaml"
|
yaml "github.com/goccy/go-yaml"
|
||||||
@@ -24,6 +25,7 @@ import (
|
|||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon"
|
bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/util"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
bigqueryrestapi "google.golang.org/api/bigquery/v2"
|
bigqueryrestapi "google.golang.org/api/bigquery/v2"
|
||||||
)
|
)
|
||||||
@@ -125,35 +127,35 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
|||||||
return t.Config
|
return t.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
mapParams := params.AsMap()
|
mapParams := params.AsMap()
|
||||||
projectId, ok := mapParams[projectKey].(string)
|
projectId, ok := mapParams[projectKey].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", projectKey)
|
return nil, util.NewAgentError(fmt.Sprintf("invalid or missing '%s' parameter; expected a string", projectKey), nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
datasetId, ok := mapParams[datasetKey].(string)
|
datasetId, ok := mapParams[datasetKey].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", datasetKey)
|
return nil, util.NewAgentError(fmt.Sprintf("invalid or missing '%s' parameter; expected a string", datasetKey), nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
tableId, ok := mapParams[tableKey].(string)
|
tableId, ok := mapParams[tableKey].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", tableKey)
|
return nil, util.NewAgentError(fmt.Sprintf("invalid or missing '%s' parameter; expected a string", tableKey), nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !source.IsDatasetAllowed(projectId, datasetId) {
|
if !source.IsDatasetAllowed(projectId, datasetId) {
|
||||||
return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId)
|
return nil, util.NewAgentError(fmt.Sprintf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId), nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
bqClient, _, err := source.RetrieveClientAndService(accessToken)
|
bqClient, _, err := source.RetrieveClientAndService(accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, util.NewClientServerError("failed to retrieve BigQuery client", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
dsHandle := bqClient.DatasetInProject(projectId, datasetId)
|
dsHandle := bqClient.DatasetInProject(projectId, datasetId)
|
||||||
@@ -161,7 +163,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
|
|
||||||
metadata, err := tableHandle.Metadata(ctx)
|
metadata, err := tableHandle.Metadata(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get metadata for table %s.%s.%s: %w", projectId, datasetId, tableId, err)
|
return nil, util.ProcessGcpError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return metadata, nil
|
return metadata, nil
|
||||||
|
|||||||
@@ -17,12 +17,14 @@ package bigquerylistdatasetids
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
bigqueryapi "cloud.google.com/go/bigquery"
|
bigqueryapi "cloud.google.com/go/bigquery"
|
||||||
yaml "github.com/goccy/go-yaml"
|
yaml "github.com/goccy/go-yaml"
|
||||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/util"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
bigqueryrestapi "google.golang.org/api/bigquery/v2"
|
bigqueryrestapi "google.golang.org/api/bigquery/v2"
|
||||||
"google.golang.org/api/iterator"
|
"google.golang.org/api/iterator"
|
||||||
@@ -120,10 +122,10 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
|||||||
return t.Config
|
return t.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(source.BigQueryAllowedDatasets()) > 0 {
|
if len(source.BigQueryAllowedDatasets()) > 0 {
|
||||||
@@ -132,12 +134,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
mapParams := params.AsMap()
|
mapParams := params.AsMap()
|
||||||
projectId, ok := mapParams[projectKey].(string)
|
projectId, ok := mapParams[projectKey].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", projectKey)
|
return nil, util.NewAgentError(fmt.Sprintf("invalid or missing '%s' parameter; expected a string", projectKey), nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
bqClient, _, err := source.RetrieveClientAndService(accessToken)
|
bqClient, _, err := source.RetrieveClientAndService(accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, util.NewClientServerError("failed to retrieve BigQuery client", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
datasetIterator := bqClient.Datasets(ctx)
|
datasetIterator := bqClient.Datasets(ctx)
|
||||||
datasetIterator.ProjectID = projectId
|
datasetIterator.ProjectID = projectId
|
||||||
@@ -149,7 +151,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to iterate through datasets: %w", err)
|
return nil, util.ProcessGcpError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove leading and trailing quotes
|
// Remove leading and trailing quotes
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ package bigquerylisttableids
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
bigqueryapi "cloud.google.com/go/bigquery"
|
bigqueryapi "cloud.google.com/go/bigquery"
|
||||||
yaml "github.com/goccy/go-yaml"
|
yaml "github.com/goccy/go-yaml"
|
||||||
@@ -24,6 +25,7 @@ import (
|
|||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon"
|
bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/util"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
bigqueryrestapi "google.golang.org/api/bigquery/v2"
|
bigqueryrestapi "google.golang.org/api/bigquery/v2"
|
||||||
"google.golang.org/api/iterator"
|
"google.golang.org/api/iterator"
|
||||||
@@ -123,31 +125,30 @@ type Tool struct {
|
|||||||
func (t Tool) ToConfig() tools.ToolConfig {
|
func (t Tool) ToConfig() tools.ToolConfig {
|
||||||
return t.Config
|
return t.Config
|
||||||
}
|
}
|
||||||
|
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
|
||||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
mapParams := params.AsMap()
|
mapParams := params.AsMap()
|
||||||
projectId, ok := mapParams[projectKey].(string)
|
projectId, ok := mapParams[projectKey].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", projectKey)
|
return nil, util.NewAgentError(fmt.Sprintf("invalid or missing '%s' parameter; expected a string", projectKey), nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
datasetId, ok := mapParams[datasetKey].(string)
|
datasetId, ok := mapParams[datasetKey].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", datasetKey)
|
return nil, util.NewAgentError(fmt.Sprintf("invalid or missing '%s' parameter; expected a string", datasetKey), nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !source.IsDatasetAllowed(projectId, datasetId) {
|
if !source.IsDatasetAllowed(projectId, datasetId) {
|
||||||
return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId)
|
return nil, util.NewAgentError(fmt.Sprintf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId), nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
bqClient, _, err := source.RetrieveClientAndService(accessToken)
|
bqClient, _, err := source.RetrieveClientAndService(accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, util.NewClientServerError("failed to retrieve BigQuery client", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
dsHandle := bqClient.DatasetInProject(projectId, datasetId)
|
dsHandle := bqClient.DatasetInProject(projectId, datasetId)
|
||||||
@@ -160,7 +161,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to iterate through tables in dataset %s.%s: %w", projectId, datasetId, err)
|
return nil, util.ProcessGcpError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove leading and trailing quotes
|
// Remove leading and trailing quotes
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ package bigquerysearchcatalog
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
dataplexapi "cloud.google.com/go/dataplex/apiv1"
|
dataplexapi "cloud.google.com/go/dataplex/apiv1"
|
||||||
@@ -26,6 +27,7 @@ import (
|
|||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
|
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/util"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
"google.golang.org/api/iterator"
|
"google.golang.org/api/iterator"
|
||||||
)
|
)
|
||||||
@@ -186,28 +188,31 @@ func ExtractType(resourceString string) string {
|
|||||||
return typeMap[resourceString[lastIndex+1:]]
|
return typeMap[resourceString[lastIndex+1:]]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
paramsMap := params.AsMap()
|
paramsMap := params.AsMap()
|
||||||
pageSize := int32(paramsMap["pageSize"].(int))
|
pageSize := int32(paramsMap["pageSize"].(int))
|
||||||
prompt, _ := paramsMap["prompt"].(string)
|
prompt, _ := paramsMap["prompt"].(string)
|
||||||
|
|
||||||
projectIdSlice, err := parameters.ConvertAnySliceToTyped(paramsMap["projectIds"].([]any), "string")
|
projectIdSlice, err := parameters.ConvertAnySliceToTyped(paramsMap["projectIds"].([]any), "string")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("can't convert projectIds to array of strings: %s", err)
|
return nil, util.NewAgentError(fmt.Sprintf("can't convert projectIds to array of strings: %s", err), err)
|
||||||
}
|
}
|
||||||
projectIds := projectIdSlice.([]string)
|
projectIds := projectIdSlice.([]string)
|
||||||
|
|
||||||
datasetIdSlice, err := parameters.ConvertAnySliceToTyped(paramsMap["datasetIds"].([]any), "string")
|
datasetIdSlice, err := parameters.ConvertAnySliceToTyped(paramsMap["datasetIds"].([]any), "string")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("can't convert datasetIds to array of strings: %s", err)
|
return nil, util.NewAgentError(fmt.Sprintf("can't convert datasetIds to array of strings: %s", err), err)
|
||||||
}
|
}
|
||||||
datasetIds := datasetIdSlice.([]string)
|
datasetIds := datasetIdSlice.([]string)
|
||||||
|
|
||||||
typesSlice, err := parameters.ConvertAnySliceToTyped(paramsMap["types"].([]any), "string")
|
typesSlice, err := parameters.ConvertAnySliceToTyped(paramsMap["types"].([]any), "string")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("can't convert types to array of strings: %s", err)
|
return nil, util.NewAgentError(fmt.Sprintf("can't convert types to array of strings: %s", err), err)
|
||||||
}
|
}
|
||||||
types := typesSlice.([]string)
|
types := typesSlice.([]string)
|
||||||
|
|
||||||
@@ -223,17 +228,17 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
if source.UseClientAuthorization() {
|
if source.UseClientAuthorization() {
|
||||||
tokenStr, err := accessToken.ParseBearerToken()
|
tokenStr, err := accessToken.ParseBearerToken()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
return nil, util.NewClientServerError("error parsing access token", http.StatusUnauthorized, err)
|
||||||
}
|
}
|
||||||
catalogClient, err = dataplexClientCreator(tokenStr)
|
catalogClient, err = dataplexClientCreator(tokenStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
|
return nil, util.NewClientServerError("error creating client from OAuth access token", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
it := catalogClient.SearchEntries(ctx, req)
|
it := catalogClient.SearchEntries(ctx, req)
|
||||||
if it == nil {
|
if it == nil {
|
||||||
return nil, fmt.Errorf("failed to create search entries iterator for project %q", source.BigQueryProject())
|
return nil, util.NewClientServerError(fmt.Sprintf("failed to create search entries iterator for project %q", source.BigQueryProject()), http.StatusInternalServerError, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
var results []Response
|
var results []Response
|
||||||
@@ -243,7 +248,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
break
|
return nil, util.ProcessGcpError(err)
|
||||||
}
|
}
|
||||||
entrySource := entry.DataplexEntry.GetEntrySource()
|
entrySource := entry.DataplexEntry.GetEntrySource()
|
||||||
resp := Response{
|
resp := Response{
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ package bigquerysql
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -27,6 +28,7 @@ import (
|
|||||||
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
|
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon"
|
bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/util"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
bigqueryrestapi "google.golang.org/api/bigquery/v2"
|
bigqueryrestapi "google.golang.org/api/bigquery/v2"
|
||||||
)
|
)
|
||||||
@@ -103,11 +105,10 @@ type Tool struct {
|
|||||||
func (t Tool) ToConfig() tools.ToolConfig {
|
func (t Tool) ToConfig() tools.ToolConfig {
|
||||||
return t.Config
|
return t.Config
|
||||||
}
|
}
|
||||||
|
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
|
||||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
highLevelParams := make([]bigqueryapi.QueryParameter, 0, len(t.Parameters))
|
highLevelParams := make([]bigqueryapi.QueryParameter, 0, len(t.Parameters))
|
||||||
@@ -116,7 +117,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
paramsMap := params.AsMap()
|
paramsMap := params.AsMap()
|
||||||
newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
|
newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to extract template params %w", err)
|
return nil, util.NewAgentError("unable to extract template params", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, p := range t.Parameters {
|
for _, p := range t.Parameters {
|
||||||
@@ -127,13 +128,13 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
if arrayParam, ok := p.(*parameters.ArrayParameter); ok {
|
if arrayParam, ok := p.(*parameters.ArrayParameter); ok {
|
||||||
arrayParamValue, ok := value.([]any)
|
arrayParamValue, ok := value.([]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("unable to convert parameter `%s` to []any", name)
|
return nil, util.NewAgentError(fmt.Sprintf("unable to convert parameter `%s` to []any", name), nil)
|
||||||
}
|
}
|
||||||
itemType := arrayParam.GetItems().GetType()
|
itemType := arrayParam.GetItems().GetType()
|
||||||
var err error
|
var err error
|
||||||
value, err = parameters.ConvertAnySliceToTyped(arrayParamValue, itemType)
|
value, err = parameters.ConvertAnySliceToTyped(arrayParamValue, itemType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to convert parameter `%s` from []any to typed slice: %w", name, err)
|
return nil, util.NewAgentError(fmt.Sprintf("unable to convert parameter `%s` from []any to typed slice", name), err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -161,7 +162,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
lowLevelParam.ParameterType.Type = "ARRAY"
|
lowLevelParam.ParameterType.Type = "ARRAY"
|
||||||
itemType, err := bqutil.BQTypeStringFromToolType(arrayParam.GetItems().GetType())
|
itemType, err := bqutil.BQTypeStringFromToolType(arrayParam.GetItems().GetType())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, util.NewAgentError("unable to get BigQuery type from tool parameter type", err)
|
||||||
}
|
}
|
||||||
lowLevelParam.ParameterType.ArrayType = &bigqueryrestapi.QueryParameterType{Type: itemType}
|
lowLevelParam.ParameterType.ArrayType = &bigqueryrestapi.QueryParameterType{Type: itemType}
|
||||||
|
|
||||||
@@ -178,7 +179,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
// Handle scalar types based on their defined type.
|
// Handle scalar types based on their defined type.
|
||||||
bqType, err := bqutil.BQTypeStringFromToolType(p.GetType())
|
bqType, err := bqutil.BQTypeStringFromToolType(p.GetType())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, util.NewAgentError("unable to get BigQuery type from tool parameter type", err)
|
||||||
}
|
}
|
||||||
lowLevelParam.ParameterType.Type = bqType
|
lowLevelParam.ParameterType.Type = bqType
|
||||||
lowLevelParam.ParameterValue.Value = fmt.Sprintf("%v", value)
|
lowLevelParam.ParameterValue.Value = fmt.Sprintf("%v", value)
|
||||||
@@ -190,7 +191,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
if source.BigQuerySession() != nil {
|
if source.BigQuerySession() != nil {
|
||||||
session, err := source.BigQuerySession()(ctx)
|
session, err := source.BigQuerySession()(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
return nil, util.NewClientServerError("failed to get BigQuery session", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
if session != nil {
|
if session != nil {
|
||||||
// Add session ID to the connection properties for subsequent calls.
|
// Add session ID to the connection properties for subsequent calls.
|
||||||
@@ -200,17 +201,20 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
|
|
||||||
bqClient, restService, err := source.RetrieveClientAndService(accessToken)
|
bqClient, restService, err := source.RetrieveClientAndService(accessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, util.NewClientServerError("failed to retrieve BigQuery client", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, newStatement, lowLevelParams, connProps)
|
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, newStatement, lowLevelParams, connProps)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("query validation failed: %w", err)
|
return nil, util.ProcessGcpError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
statementType := dryRunJob.Statistics.Query.StatementType
|
statementType := dryRunJob.Statistics.Query.StatementType
|
||||||
|
resp, err := source.RunSQL(ctx, bqClient, newStatement, statementType, highLevelParams, connProps)
|
||||||
return source.RunSQL(ctx, bqClient, newStatement, statementType, highLevelParams, connProps)
|
if err != nil {
|
||||||
|
return nil, util.ProcessGcpError(err)
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
||||||
|
|||||||
@@ -17,12 +17,14 @@ package bigtable
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
"cloud.google.com/go/bigtable"
|
"cloud.google.com/go/bigtable"
|
||||||
yaml "github.com/goccy/go-yaml"
|
yaml "github.com/goccy/go-yaml"
|
||||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/util"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -96,24 +98,28 @@ type Tool struct {
|
|||||||
func (t Tool) ToConfig() tools.ToolConfig {
|
func (t Tool) ToConfig() tools.ToolConfig {
|
||||||
return t.Config
|
return t.Config
|
||||||
}
|
}
|
||||||
|
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
|
||||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
paramsMap := params.AsMap()
|
paramsMap := params.AsMap()
|
||||||
newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
|
newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to extract template params %w", err)
|
return nil, util.NewAgentError("unable to extract template params", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
newParams, err := parameters.GetParams(t.Parameters, paramsMap)
|
newParams, err := parameters.GetParams(t.Parameters, paramsMap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to extract standard params %w", err)
|
return nil, util.NewAgentError("unable to extract standard params", err)
|
||||||
}
|
}
|
||||||
return source.RunSQL(ctx, newStatement, t.Parameters, newParams)
|
|
||||||
|
resp, err := source.RunSQL(ctx, newStatement, t.Parameters, newParams)
|
||||||
|
if err != nil {
|
||||||
|
return nil, util.ProcessGcpError(err)
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
||||||
|
|||||||
@@ -17,12 +17,14 @@ package cassandracql
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
gocql "github.com/apache/cassandra-gocql-driver/v2"
|
gocql "github.com/apache/cassandra-gocql-driver/v2"
|
||||||
yaml "github.com/goccy/go-yaml"
|
yaml "github.com/goccy/go-yaml"
|
||||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/util"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -107,23 +109,27 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Invoke implements tools.Tool.
|
// Invoke implements tools.Tool.
|
||||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
paramsMap := params.AsMap()
|
paramsMap := params.AsMap()
|
||||||
newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
|
newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to extract template params %w", err)
|
return nil, util.NewAgentError("unable to extract template params", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
newParams, err := parameters.GetParams(t.Parameters, paramsMap)
|
newParams, err := parameters.GetParams(t.Parameters, paramsMap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to extract standard params %w", err)
|
return nil, util.NewAgentError("unable to extract standard params", err)
|
||||||
}
|
}
|
||||||
return source.RunSQL(ctx, newStatement, newParams)
|
resp, err := source.RunSQL(ctx, newStatement, newParams)
|
||||||
|
if err != nil {
|
||||||
|
return nil, util.ProcessGeneralError(err)
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Manifest implements tools.Tool.
|
// Manifest implements tools.Tool.
|
||||||
|
|||||||
@@ -17,11 +17,13 @@ package clickhouse
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
yaml "github.com/goccy/go-yaml"
|
yaml "github.com/goccy/go-yaml"
|
||||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/util"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -87,18 +89,22 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
|||||||
return t.Config
|
return t.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, error) {
|
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, util.ToolboxError) {
|
||||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
paramsMap := params.AsMap()
|
paramsMap := params.AsMap()
|
||||||
sql, ok := paramsMap["sql"].(string)
|
sql, ok := paramsMap["sql"].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("unable to cast sql parameter %s", paramsMap["sql"])
|
return nil, util.NewAgentError(fmt.Sprintf("unable to cast sql parameter %s", paramsMap["sql"]), nil)
|
||||||
}
|
}
|
||||||
return source.RunSQL(ctx, sql, nil)
|
resp, err := source.RunSQL(ctx, sql, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, util.ProcessGeneralError(err)
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
||||||
|
|||||||
@@ -17,11 +17,13 @@ package clickhouse
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
yaml "github.com/goccy/go-yaml"
|
yaml "github.com/goccy/go-yaml"
|
||||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/util"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -86,10 +88,10 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
|||||||
return t.Config
|
return t.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, error) {
|
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, util.ToolboxError) {
|
||||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Query to list all databases
|
// Query to list all databases
|
||||||
@@ -97,7 +99,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
|
|
||||||
out, err := source.RunSQL(ctx, query, nil)
|
out, err := source.RunSQL(ctx, query, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, util.ProcessGeneralError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return out, nil
|
return out, nil
|
||||||
|
|||||||
@@ -17,11 +17,13 @@ package clickhouse
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
yaml "github.com/goccy/go-yaml"
|
yaml "github.com/goccy/go-yaml"
|
||||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/util"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -90,34 +92,37 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
|||||||
return t.Config
|
return t.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, error) {
|
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, util.ToolboxError) {
|
||||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
mapParams := params.AsMap()
|
mapParams := params.AsMap()
|
||||||
database, ok := mapParams[databaseKey].(string)
|
database, ok := mapParams[databaseKey].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", databaseKey)
|
return nil, util.NewAgentError(fmt.Sprintf("invalid or missing '%s' parameter; expected a string", databaseKey), nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Query to list all tables in the specified database
|
// Query to list all tables in the specified database
|
||||||
|
// Note: formatting identifier directly is risky if input is untrusted, but standard for this tool structure.
|
||||||
query := fmt.Sprintf("SHOW TABLES FROM %s", database)
|
query := fmt.Sprintf("SHOW TABLES FROM %s", database)
|
||||||
|
|
||||||
out, err := source.RunSQL(ctx, query, nil)
|
out, err := source.RunSQL(ctx, query, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, util.ProcessGeneralError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
res, ok := out.([]any)
|
res, ok := out.([]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("unable to convert result to list")
|
return nil, util.NewClientServerError("unable to convert result to list", http.StatusInternalServerError, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
var tables []map[string]any
|
var tables []map[string]any
|
||||||
for _, item := range res {
|
for _, item := range res {
|
||||||
tableMap, ok := item.(map[string]any)
|
tableMap, ok := item.(map[string]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("unexpected type in result: got %T, want map[string]any", item)
|
return nil, util.NewClientServerError(fmt.Sprintf("unexpected type in result: got %T, want map[string]any", item), http.StatusInternalServerError, nil)
|
||||||
}
|
}
|
||||||
tableMap["database"] = database
|
tableMap["database"] = database
|
||||||
tables = append(tables, tableMap)
|
tables = append(tables, tableMap)
|
||||||
|
|||||||
@@ -17,11 +17,13 @@ package clickhouse
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
yaml "github.com/goccy/go-yaml"
|
yaml "github.com/goccy/go-yaml"
|
||||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/util"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -88,24 +90,28 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
|||||||
return t.Config
|
return t.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, error) {
|
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, util.ToolboxError) {
|
||||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
paramsMap := params.AsMap()
|
paramsMap := params.AsMap()
|
||||||
newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
|
newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to extract template params: %w", err)
|
return nil, util.NewAgentError("unable to extract template params", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
newParams, err := parameters.GetParams(t.Parameters, paramsMap)
|
newParams, err := parameters.GetParams(t.Parameters, paramsMap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to extract standard params: %w", err)
|
return nil, util.NewAgentError("unable to extract standard params", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return source.RunSQL(ctx, newStatement, newParams)
|
resp, err := source.RunSQL(ctx, newStatement, newParams)
|
||||||
|
if err != nil {
|
||||||
|
return nil, util.ProcessGeneralError(err)
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
||||||
|
|||||||
@@ -18,11 +18,13 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
"github.com/goccy/go-yaml"
|
"github.com/goccy/go-yaml"
|
||||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/util"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -119,17 +121,16 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
|||||||
return t.Config
|
return t.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
// Invoke executes the tool logic
|
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
|
||||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
paramsMap := params.AsMap()
|
paramsMap := params.AsMap()
|
||||||
query, ok := paramsMap["query"].(string)
|
query, ok := paramsMap["query"].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("query parameter not found or not a string")
|
return nil, util.NewAgentError("query parameter not found or not a string", nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse the access token if provided
|
// Parse the access token if provided
|
||||||
@@ -138,7 +139,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
var err error
|
var err error
|
||||||
tokenStr, err = accessToken.ParseBearerToken()
|
tokenStr, err = accessToken.ParseBearerToken()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
return nil, util.NewClientServerError("error parsing access token", http.StatusUnauthorized, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -154,9 +155,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
|
|
||||||
bodyBytes, err := json.Marshal(payload)
|
bodyBytes, err := json.Marshal(payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to marshal request payload: %w", err)
|
return nil, util.NewClientServerError("failed to marshal request payload", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
return source.RunQuery(ctx, tokenStr, bodyBytes)
|
|
||||||
|
resp, err := source.RunQuery(ctx, tokenStr, bodyBytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, util.ProcessGcpError(err)
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
||||||
|
|||||||
@@ -17,11 +17,13 @@ package fhirfetchpage
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
"github.com/goccy/go-yaml"
|
"github.com/goccy/go-yaml"
|
||||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/util"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -93,24 +95,31 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
|||||||
return t.Config
|
return t.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
url, ok := params.AsMap()[pageURLKey].(string)
|
url, ok := params.AsMap()[pageURLKey].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", pageURLKey)
|
return nil, util.NewAgentError(fmt.Sprintf("invalid or missing '%s' parameter; expected a string", pageURLKey), nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
var tokenStr string
|
var tokenStr string
|
||||||
if source.UseClientAuthorization() {
|
if source.UseClientAuthorization() {
|
||||||
tokenStr, err = accessToken.ParseBearerToken()
|
tokenStr, err = accessToken.ParseBearerToken()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
return nil, util.NewClientServerError("error parsing access token", http.StatusUnauthorized, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return source.FHIRFetchPage(ctx, url, tokenStr)
|
|
||||||
|
resp, err := source.FHIRFetchPage(ctx, url, tokenStr)
|
||||||
|
if err != nil {
|
||||||
|
|
||||||
|
return nil, util.ProcessGcpError(err)
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ package fhirpatienteverything
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/goccy/go-yaml"
|
"github.com/goccy/go-yaml"
|
||||||
@@ -24,6 +25,7 @@ import (
|
|||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
|
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/util"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
"google.golang.org/api/googleapi"
|
"google.golang.org/api/googleapi"
|
||||||
)
|
)
|
||||||
@@ -116,26 +118,27 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
|||||||
return t.Config
|
return t.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores())
|
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
// ValidateAndFetchStoreID usually returns input validation errors
|
||||||
|
return nil, util.NewAgentError("failed to validate store ID", err)
|
||||||
}
|
}
|
||||||
patientID, ok := params.AsMap()[patientIDKey].(string)
|
patientID, ok := params.AsMap()[patientIDKey].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", patientIDKey)
|
return nil, util.NewAgentError(fmt.Sprintf("invalid or missing '%s' parameter; expected a string", patientIDKey), nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
var tokenStr string
|
var tokenStr string
|
||||||
if source.UseClientAuthorization() {
|
if source.UseClientAuthorization() {
|
||||||
tokenStr, err = accessToken.ParseBearerToken()
|
tokenStr, err = accessToken.ParseBearerToken()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
return nil, util.NewClientServerError("error parsing access token", http.StatusUnauthorized, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -143,11 +146,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
if val, ok := params.AsMap()[typeFilterKey]; ok {
|
if val, ok := params.AsMap()[typeFilterKey]; ok {
|
||||||
types, ok := val.([]any)
|
types, ok := val.([]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("invalid '%s' parameter; expected a string array", typeFilterKey)
|
return nil, util.NewAgentError(fmt.Sprintf("invalid '%s' parameter; expected a string array", typeFilterKey), nil)
|
||||||
}
|
}
|
||||||
typeFilterSlice, err := parameters.ConvertAnySliceToTyped(types, "string")
|
typeFilterSlice, err := parameters.ConvertAnySliceToTyped(types, "string")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("can't convert '%s' to array of strings: %s", typeFilterKey, err)
|
return nil, util.NewAgentError(fmt.Sprintf("can't convert '%s' to array of strings: %s", typeFilterKey, err), err)
|
||||||
}
|
}
|
||||||
if len(typeFilterSlice.([]string)) != 0 {
|
if len(typeFilterSlice.([]string)) != 0 {
|
||||||
opts = append(opts, googleapi.QueryParameter("_type", strings.Join(typeFilterSlice.([]string), ",")))
|
opts = append(opts, googleapi.QueryParameter("_type", strings.Join(typeFilterSlice.([]string), ",")))
|
||||||
@@ -156,13 +159,18 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
if since, ok := params.AsMap()[sinceFilterKey]; ok {
|
if since, ok := params.AsMap()[sinceFilterKey]; ok {
|
||||||
sinceStr, ok := since.(string)
|
sinceStr, ok := since.(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("invalid '%s' parameter; expected a string", sinceFilterKey)
|
return nil, util.NewAgentError(fmt.Sprintf("invalid '%s' parameter; expected a string", sinceFilterKey), nil)
|
||||||
}
|
}
|
||||||
if sinceStr != "" {
|
if sinceStr != "" {
|
||||||
opts = append(opts, googleapi.QueryParameter("_since", sinceStr))
|
opts = append(opts, googleapi.QueryParameter("_since", sinceStr))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return source.FHIRPatientEverything(storeID, patientID, tokenStr, opts)
|
|
||||||
|
resp, err := source.FHIRPatientEverything(storeID, patientID, tokenStr, opts)
|
||||||
|
if err != nil {
|
||||||
|
return nil, util.ProcessGcpError(err)
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user