mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-02-13 16:45:01 -05:00
Compare commits
31 Commits
adk-python
...
pr/dumians
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
251ef22839 | ||
|
|
8dc4bd7dd6 | ||
|
|
a00d0edcd5 | ||
|
|
d1eb1799a0 | ||
|
|
2b81d6099a | ||
|
|
53865b6e21 | ||
|
|
6843b46328 | ||
|
|
05152f732d | ||
|
|
a48101b3c5 | ||
|
|
418d6d791e | ||
|
|
452d686750 | ||
|
|
8fb74263a7 | ||
|
|
09f3bc7959 | ||
|
|
7970b8787e | ||
|
|
0b7a86ae58 | ||
|
|
0797142103 | ||
|
|
f26750e834 | ||
|
|
c46d7d6fa0 | ||
|
|
5e2034d146 | ||
|
|
e2272ccdbc | ||
|
|
97f68129f5 | ||
|
|
fea96fed03 | ||
|
|
195767bdcd | ||
|
|
c5524d32f5 | ||
|
|
e1739abd81 | ||
|
|
478a0bdb59 | ||
|
|
2d341acaa6 | ||
|
|
f032389a07 | ||
|
|
32610d71a3 | ||
|
|
32cb4db712 | ||
|
|
1fdd99a9b6 |
@@ -354,6 +354,30 @@ steps:
|
||||
postgressql \
|
||||
postgresexecutesql
|
||||
|
||||
- id: "cockroachdb"
|
||||
name: golang:1
|
||||
waitFor: ["compile-test-binary"]
|
||||
entrypoint: /bin/bash
|
||||
env:
|
||||
- "GOPATH=/gopath"
|
||||
- "COCKROACHDB_DATABASE=$_DATABASE_NAME"
|
||||
- "COCKROACHDB_PORT=$_COCKROACHDB_PORT"
|
||||
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
|
||||
secretEnv: ["COCKROACHDB_USER", "COCKROACHDB_HOST","CLIENT_ID"]
|
||||
volumes:
|
||||
- name: "go"
|
||||
path: "/gopath"
|
||||
args:
|
||||
- -c
|
||||
- |
|
||||
.ci/test_with_coverage.sh \
|
||||
"CockroachDB" \
|
||||
cockroachdb \
|
||||
cockroachdbsql \
|
||||
cockroachdbexecutesql \
|
||||
cockroachdblisttables \
|
||||
cockroachdblistschemas
|
||||
|
||||
- id: "spanner"
|
||||
name: golang:1
|
||||
waitFor: ["compile-test-binary"]
|
||||
@@ -919,7 +943,7 @@ steps:
|
||||
# Install the C compiler and Oracle SDK headers needed for cgo
|
||||
dnf install -y gcc oracle-instantclient-devel
|
||||
# Install Go
|
||||
curl -L -o go.tar.gz "https://go.dev/dl/go1.25.1.linux-amd64.tar.gz"
|
||||
curl -L -o go.tar.gz "https://go.dev/dl/go1.25.5.linux-amd64.tar.gz"
|
||||
tar -C /usr/local -xzf go.tar.gz
|
||||
export PATH="/usr/local/go/bin:$$PATH"
|
||||
|
||||
@@ -1129,6 +1153,11 @@ availableSecrets:
|
||||
env: MARIADB_HOST
|
||||
- versionName: projects/$PROJECT_ID/secrets/mongodb_uri/versions/latest
|
||||
env: MONGODB_URI
|
||||
- versionName: projects/$PROJECT_ID/secrets/cockroachdb_user/versions/latest
|
||||
env: COCKROACHDB_USER
|
||||
- versionName: projects/$PROJECT_ID/secrets/cockroachdb_host/versions/latest
|
||||
env: COCKROACHDB_HOST
|
||||
|
||||
|
||||
options:
|
||||
logging: CLOUD_LOGGING_ONLY
|
||||
@@ -1189,6 +1218,9 @@ substitutions:
|
||||
_SINGLESTORE_PORT: "3308"
|
||||
_SINGLESTORE_DATABASE: "singlestore"
|
||||
_SINGLESTORE_USER: "root"
|
||||
_COCKROACHDB_HOST: 127.0.0.1
|
||||
_COCKROACHDB_PORT: "26257"
|
||||
_COCKROACHDB_USER: "root"
|
||||
_MARIADB_PORT: "3307"
|
||||
_MARIADB_DATABASE: test_database
|
||||
_SNOWFLAKE_DATABASE: "test"
|
||||
|
||||
1
.github/release-please.yml
vendored
1
.github/release-please.yml
vendored
@@ -37,6 +37,7 @@ extraFiles: [
|
||||
"docs/en/how-to/connect-ide/postgres_mcp.md",
|
||||
"docs/en/how-to/connect-ide/neo4j_mcp.md",
|
||||
"docs/en/how-to/connect-ide/sqlite_mcp.md",
|
||||
"docs/en/how-to/connect-ide/oracle_mcp.md",
|
||||
"gemini-extension.json",
|
||||
{
|
||||
"type": "json",
|
||||
|
||||
@@ -51,6 +51,10 @@ ignoreFiles = ["quickstart/shared", "quickstart/python", "quickstart/js", "quick
|
||||
# Add a new version block here before every release
|
||||
# The order of versions in this file is mirrored into the dropdown
|
||||
|
||||
[[params.versions]]
|
||||
version = "v0.27.0"
|
||||
url = "https://googleapis.github.io/genai-toolbox/v0.27.0/"
|
||||
|
||||
[[params.versions]]
|
||||
version = "v0.26.0"
|
||||
url = "https://googleapis.github.io/genai-toolbox/v0.26.0/"
|
||||
|
||||
26
CHANGELOG.md
26
CHANGELOG.md
@@ -1,5 +1,31 @@
|
||||
# Changelog
|
||||
|
||||
## [0.27.0](https://github.com/googleapis/genai-toolbox/compare/v0.26.0...v0.27.0) (2026-02-12)
|
||||
|
||||
|
||||
### ⚠ BREAKING CHANGES
|
||||
|
||||
* Update configuration file v2 ([#2369](https://github.com/googleapis/genai-toolbox/issues/2369))([293c1d6](https://github.com/googleapis/genai-toolbox/commit/293c1d6889c39807855ba5e01d4c13ba2a4c50ce))
|
||||
* Update/add detailed telemetry for mcp endpoint compliant with OTEL semantic convention ([#1987](https://github.com/googleapis/genai-toolbox/issues/1987)) ([478a0bd](https://github.com/googleapis/genai-toolbox/commit/478a0bdb59288c1213f83862f95a698b4c2c0aab))
|
||||
|
||||
### Features
|
||||
|
||||
* **cli/invoke:** Add support for direct tool invocation from CLI ([#2353](https://github.com/googleapis/genai-toolbox/issues/2353)) ([6e49ba4](https://github.com/googleapis/genai-toolbox/commit/6e49ba436ef2390c13feaf902b29f5907acffb57))
|
||||
* **cli/skills:** Add support for generating agent skills from toolset ([#2392](https://github.com/googleapis/genai-toolbox/issues/2392)) ([80ef346](https://github.com/googleapis/genai-toolbox/commit/80ef34621453b77bdf6a6016c354f102a17ada04))
|
||||
* **cloud-logging-admin:** Add source, tools, integration test and docs ([#2137](https://github.com/googleapis/genai-toolbox/issues/2137)) ([252fc30](https://github.com/googleapis/genai-toolbox/commit/252fc3091af10d25d8d7af7e047b5ac87a5dd041))
|
||||
* **cockroachdb:** Add CockroachDB integration with cockroach-go ([#2006](https://github.com/googleapis/genai-toolbox/issues/2006)) ([1fdd99a](https://github.com/googleapis/genai-toolbox/commit/1fdd99a9b609a5e906acce414226ff44d75d5975))
|
||||
* **prebuiltconfigs/alloydb-omni:** Implement Alloydb omni dataplane tools ([#2340](https://github.com/googleapis/genai-toolbox/issues/2340)) ([e995349](https://github.com/googleapis/genai-toolbox/commit/e995349ea0756c700d188b8f04e9459121219f0c))
|
||||
* **server:** Add Tool call error categories ([#2387](https://github.com/googleapis/genai-toolbox/issues/2387)) ([32cb4db](https://github.com/googleapis/genai-toolbox/commit/32cb4db712d27579c1bf29e61cbd0bed02286c28))
|
||||
* **tools/looker:** support `looker-validate-project` tool ([#2430](https://github.com/googleapis/genai-toolbox/issues/2430)) ([a15a128](https://github.com/googleapis/genai-toolbox/commit/a15a12873f936b0102aeb9500cc3bcd71bb38c34))
|
||||
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* **dataplex:** Capture GCP HTTP errors in MCP Toolbox ([#2347](https://github.com/googleapis/genai-toolbox/issues/2347)) ([1d7c498](https://github.com/googleapis/genai-toolbox/commit/1d7c4981164c34b4d7bc8edecfd449f57ad11e15))
|
||||
* **sources/cockroachdb:** Update kind to type ([#2465](https://github.com/googleapis/genai-toolbox/issues/2465)) ([2d341ac](https://github.com/googleapis/genai-toolbox/commit/2d341acaa61c3c1fe908fceee8afbd90fb646d3a))
|
||||
* Surface Dataplex API errors in MCP results ([#2347](https://github.com/googleapis/genai-toolbox/pull/2347))([1d7c498](https://github.com/googleapis/genai-toolbox/commit/1d7c4981164c34b4d7bc8edecfd449f57ad11e15))
|
||||
|
||||
## [0.26.0](https://github.com/googleapis/genai-toolbox/compare/v0.25.0...v0.26.0) (2026-01-22)
|
||||
|
||||
|
||||
|
||||
14
README.md
14
README.md
@@ -142,7 +142,7 @@ To install Toolbox as a binary:
|
||||
>
|
||||
> ```sh
|
||||
> # see releases page for other versions
|
||||
> export VERSION=0.26.0
|
||||
> export VERSION=0.27.0
|
||||
> curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/linux/amd64/toolbox
|
||||
> chmod +x toolbox
|
||||
> ```
|
||||
@@ -155,7 +155,7 @@ To install Toolbox as a binary:
|
||||
>
|
||||
> ```sh
|
||||
> # see releases page for other versions
|
||||
> export VERSION=0.26.0
|
||||
> export VERSION=0.27.0
|
||||
> curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/darwin/arm64/toolbox
|
||||
> chmod +x toolbox
|
||||
> ```
|
||||
@@ -168,7 +168,7 @@ To install Toolbox as a binary:
|
||||
>
|
||||
> ```sh
|
||||
> # see releases page for other versions
|
||||
> export VERSION=0.26.0
|
||||
> export VERSION=0.27.0
|
||||
> curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/darwin/amd64/toolbox
|
||||
> chmod +x toolbox
|
||||
> ```
|
||||
@@ -181,7 +181,7 @@ To install Toolbox as a binary:
|
||||
>
|
||||
> ```cmd
|
||||
> :: see releases page for other versions
|
||||
> set VERSION=0.26.0
|
||||
> set VERSION=0.27.0
|
||||
> curl -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v%VERSION%/windows/amd64/toolbox.exe"
|
||||
> ```
|
||||
>
|
||||
@@ -193,7 +193,7 @@ To install Toolbox as a binary:
|
||||
>
|
||||
> ```powershell
|
||||
> # see releases page for other versions
|
||||
> $VERSION = "0.26.0"
|
||||
> $VERSION = "0.27.0"
|
||||
> curl.exe -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v$VERSION/windows/amd64/toolbox.exe"
|
||||
> ```
|
||||
>
|
||||
@@ -206,7 +206,7 @@ You can also install Toolbox as a container:
|
||||
|
||||
```sh
|
||||
# see releases page for other versions
|
||||
export VERSION=0.26.0
|
||||
export VERSION=0.27.0
|
||||
docker pull us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:$VERSION
|
||||
```
|
||||
|
||||
@@ -230,7 +230,7 @@ To install from source, ensure you have the latest version of
|
||||
[Go installed](https://go.dev/doc/install), and then run the following command:
|
||||
|
||||
```sh
|
||||
go install github.com/googleapis/genai-toolbox@v0.26.0
|
||||
go install github.com/googleapis/genai-toolbox@v0.27.0
|
||||
```
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
|
||||
257
cmd/internal/imports.go
Normal file
257
cmd/internal/imports.go
Normal file
@@ -0,0 +1,257 @@
|
||||
// Copyright 2026 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
// Import prompt packages for side effect of registration
|
||||
_ "github.com/googleapis/genai-toolbox/internal/prompts/custom"
|
||||
|
||||
// Import tool packages for side effect of registration
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydbcreatecluster"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydbcreateinstance"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydbcreateuser"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydbgetcluster"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydbgetinstance"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydbgetuser"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydblistclusters"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydblistinstances"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydblistusers"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydbwaitforoperation"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/alloydbainl"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigqueryanalyzecontribution"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigqueryconversationalanalytics"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigqueryexecutesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigqueryforecast"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerygetdatasetinfo"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerygettableinfo"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerylistdatasetids"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerylisttableids"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerysearchcatalog"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerysql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/bigtable"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cassandra/cassandracql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouseexecutesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouselistdatabases"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouselisttables"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhousesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudgda"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcarefhirfetchpage"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcarefhirpatienteverything"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcarefhirpatientsearch"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaregetdataset"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaregetdicomstore"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaregetdicomstoremetrics"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaregetfhirresource"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaregetfhirstore"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaregetfhirstoremetrics"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcarelistdicomstores"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcarelistfhirstores"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcareretrieverendereddicominstance"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaresearchdicominstances"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaresearchdicomseries"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaresearchdicomstudies"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudloggingadmin/cloudloggingadminlistlognames"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudloggingadmin/cloudloggingadminlistresourcetypes"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudloggingadmin/cloudloggingadminquerylogs"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudmonitoring"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcloneinstance"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcreatebackup"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcreatedatabase"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcreateusers"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlgetinstances"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqllistdatabases"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqllistinstances"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlrestorebackup"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlwaitforoperation"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsqlmssql/cloudsqlmssqlcreateinstance"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsqlmysql/cloudsqlmysqlcreateinstance"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsqlpg/cloudsqlpgcreateinstances"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsqlpg/cloudsqlpgupgradeprecheck"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cockroachdb/cockroachdbexecutesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cockroachdb/cockroachdblistschemas"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cockroachdb/cockroachdblisttables"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cockroachdb/cockroachdbsql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/couchbase"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/dataform/dataformcompilelocal"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/dataplex/dataplexlookupentry"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/dataplex/dataplexsearchaspecttypes"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/dataplex/dataplexsearchentries"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/dgraph"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/elasticsearch/elasticsearchesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/firebird/firebirdexecutesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/firebird/firebirdsql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestoreadddocuments"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestoredeletedocuments"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestoregetdocuments"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestoregetrules"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestorelistcollections"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestorequery"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestorequerycollection"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestoreupdatedocument"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestorevalidaterules"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/http"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookeradddashboardelement"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookeradddashboardfilter"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerconversationalanalytics"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercreateprojectfile"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerdeleteprojectfile"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerdevmode"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergenerateembedurl"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnectiondatabases"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnections"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnectionschemas"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnectiontablecolumns"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetconnectiontables"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetdashboards"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetdimensions"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetexplores"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetfilters"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetlooks"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetmeasures"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetmodels"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetparameters"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetprojectfile"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetprojectfiles"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookergetprojects"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerhealthanalyze"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerhealthpulse"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerhealthvacuum"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookermakedashboard"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookermakelook"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerquery"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerquerysql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerqueryurl"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerrundashboard"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerrunlook"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerupdateprojectfile"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookervalidateproject"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mindsdb/mindsdbexecutesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mindsdb/mindsdbsql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbaggregate"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbdeletemany"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbdeleteone"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbfind"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbfindone"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbinsertmany"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbinsertone"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbupdatemany"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mongodb/mongodbupdateone"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mssql/mssqlexecutesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mssql/mssqllisttables"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mssql/mssqlsql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlexecutesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlgetqueryplan"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqllistactivequeries"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqllisttablefragmentation"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqllisttables"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqllisttablesmissinguniqueindexes"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlsql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jcypher"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jexecutecypher"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jschema"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/oceanbase/oceanbaseexecutesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/oceanbase/oceanbasesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/oracle/oracleexecutesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/oracle/oraclesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgresdatabaseoverview"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgresexecutesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgresgetcolumncardinality"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistactivequeries"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistavailableextensions"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistdatabasestats"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistindexes"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistinstalledextensions"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistlocks"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistpgsettings"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistpublicationtables"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistquerystats"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistroles"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistschemas"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistsequences"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgresliststoredprocedure"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslisttables"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslisttablespaces"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslisttablestats"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslisttriggers"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistviews"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslongrunningtransactions"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgresreplicationstats"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgressql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/redis"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkcancelbatch"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkcreatepysparkbatch"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkcreatesparkbatch"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkgetbatch"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparklistbatches"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/singlestore/singlestoreexecutesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/singlestore/singlestoresql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/snowflake/snowflakeexecutesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/snowflake/snowflakesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/spanner/spannerexecutesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/spanner/spannerlistgraphs"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/spanner/spannerlisttables"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/spanner/spannersql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/sqlite/sqliteexecutesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/sqlite/sqlitesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/tidb/tidbexecutesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/tidb/tidbsql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/trino/trinoexecutesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/trino/trinosql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/utility/wait"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/valkey"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/yugabytedbsql"
|
||||
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/bigtable"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/cassandra"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/clickhouse"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/cloudgda"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/cloudloggingadmin"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/cloudmonitoring"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmssql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmysql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/cockroachdb"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/couchbase"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/dataplex"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/dgraph"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/elasticsearch"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/firebird"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/firestore"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/http"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/looker"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/mindsdb"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/mongodb"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/mssql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/mysql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/neo4j"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/oceanbase"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/oracle"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/postgres"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/redis"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/serverlessspark"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/singlestore"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/snowflake"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/spanner"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/sqlite"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/tidb"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/trino"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/valkey"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/yugabytedb"
|
||||
)
|
||||
@@ -18,37 +18,15 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/log"
|
||||
"github.com/googleapis/genai-toolbox/cmd/internal"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/server/resources"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// RootCommand defines the interface for required by invoke subcommand.
|
||||
// This allows subcommands to access shared resources and functionality without
|
||||
// direct coupling to the root command's implementation.
|
||||
type RootCommand interface {
|
||||
// Config returns a copy of the current server configuration.
|
||||
Config() server.ServerConfig
|
||||
|
||||
// Out returns the writer used for standard output.
|
||||
Out() io.Writer
|
||||
|
||||
// LoadConfig loads and merges the configuration from files, folders, and prebuilts.
|
||||
LoadConfig(ctx context.Context) error
|
||||
|
||||
// Setup initializes the runtime environment, including logging and telemetry.
|
||||
// It returns the updated context and a shutdown function to be called when finished.
|
||||
Setup(ctx context.Context) (context.Context, func(context.Context) error, error)
|
||||
|
||||
// Logger returns the logger instance.
|
||||
Logger() log.Logger
|
||||
}
|
||||
|
||||
func NewCommand(rootCmd RootCommand) *cobra.Command {
|
||||
func NewCommand(opts *internal.ToolboxOptions) *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "invoke <tool-name> [params]",
|
||||
Short: "Execute a tool directly",
|
||||
@@ -58,17 +36,17 @@ Example:
|
||||
toolbox invoke my-tool '{"param1": "value1"}'`,
|
||||
Args: cobra.MinimumNArgs(1),
|
||||
RunE: func(c *cobra.Command, args []string) error {
|
||||
return runInvoke(c, args, rootCmd)
|
||||
return runInvoke(c, args, opts)
|
||||
},
|
||||
}
|
||||
return cmd
|
||||
}
|
||||
|
||||
func runInvoke(cmd *cobra.Command, args []string, rootCmd RootCommand) error {
|
||||
func runInvoke(cmd *cobra.Command, args []string, opts *internal.ToolboxOptions) error {
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
defer cancel()
|
||||
|
||||
ctx, shutdown, err := rootCmd.Setup(ctx)
|
||||
ctx, shutdown, err := opts.Setup(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -76,16 +54,16 @@ func runInvoke(cmd *cobra.Command, args []string, rootCmd RootCommand) error {
|
||||
_ = shutdown(ctx)
|
||||
}()
|
||||
|
||||
// Load and merge tool configurations
|
||||
if err := rootCmd.LoadConfig(ctx); err != nil {
|
||||
_, err = opts.LoadConfig(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initialize Resources
|
||||
sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := server.InitializeConfigs(ctx, rootCmd.Config())
|
||||
sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := server.InitializeConfigs(ctx, opts.Cfg)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("failed to initialize resources: %w", err)
|
||||
rootCmd.Logger().ErrorContext(ctx, errMsg.Error())
|
||||
opts.Logger.ErrorContext(ctx, errMsg.Error())
|
||||
return errMsg
|
||||
}
|
||||
|
||||
@@ -96,7 +74,7 @@ func runInvoke(cmd *cobra.Command, args []string, rootCmd RootCommand) error {
|
||||
tool, ok := resourceMgr.GetTool(toolName)
|
||||
if !ok {
|
||||
errMsg := fmt.Errorf("tool %q not found", toolName)
|
||||
rootCmd.Logger().ErrorContext(ctx, errMsg.Error())
|
||||
opts.Logger.ErrorContext(ctx, errMsg.Error())
|
||||
return errMsg
|
||||
}
|
||||
|
||||
@@ -109,7 +87,7 @@ func runInvoke(cmd *cobra.Command, args []string, rootCmd RootCommand) error {
|
||||
if paramsInput != "" {
|
||||
if err := json.Unmarshal([]byte(paramsInput), ¶ms); err != nil {
|
||||
errMsg := fmt.Errorf("params must be a valid JSON string: %w", err)
|
||||
rootCmd.Logger().ErrorContext(ctx, errMsg.Error())
|
||||
opts.Logger.ErrorContext(ctx, errMsg.Error())
|
||||
return errMsg
|
||||
}
|
||||
}
|
||||
@@ -117,14 +95,14 @@ func runInvoke(cmd *cobra.Command, args []string, rootCmd RootCommand) error {
|
||||
parsedParams, err := parameters.ParseParams(tool.GetParameters(), params, nil)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("invalid parameters: %w", err)
|
||||
rootCmd.Logger().ErrorContext(ctx, errMsg.Error())
|
||||
opts.Logger.ErrorContext(ctx, errMsg.Error())
|
||||
return errMsg
|
||||
}
|
||||
|
||||
parsedParams, err = tool.EmbedParams(ctx, parsedParams, resourceMgr.GetEmbeddingModelMap())
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("error embedding parameters: %w", err)
|
||||
rootCmd.Logger().ErrorContext(ctx, errMsg.Error())
|
||||
opts.Logger.ErrorContext(ctx, errMsg.Error())
|
||||
return errMsg
|
||||
}
|
||||
|
||||
@@ -132,19 +110,19 @@ func runInvoke(cmd *cobra.Command, args []string, rootCmd RootCommand) error {
|
||||
requiresAuth, err := tool.RequiresClientAuthorization(resourceMgr)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("failed to check auth requirements: %w", err)
|
||||
rootCmd.Logger().ErrorContext(ctx, errMsg.Error())
|
||||
opts.Logger.ErrorContext(ctx, errMsg.Error())
|
||||
return errMsg
|
||||
}
|
||||
if requiresAuth {
|
||||
errMsg := fmt.Errorf("client authorization is not supported")
|
||||
rootCmd.Logger().ErrorContext(ctx, errMsg.Error())
|
||||
opts.Logger.ErrorContext(ctx, errMsg.Error())
|
||||
return errMsg
|
||||
}
|
||||
|
||||
result, err := tool.Invoke(ctx, resourceMgr, parsedParams, "")
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("tool execution failed: %w", err)
|
||||
rootCmd.Logger().ErrorContext(ctx, errMsg.Error())
|
||||
opts.Logger.ErrorContext(ctx, errMsg.Error())
|
||||
return errMsg
|
||||
}
|
||||
|
||||
@@ -152,10 +130,10 @@ func runInvoke(cmd *cobra.Command, args []string, rootCmd RootCommand) error {
|
||||
output, err := json.MarshalIndent(result, "", " ")
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("failed to marshal result: %w", err)
|
||||
rootCmd.Logger().ErrorContext(ctx, errMsg.Error())
|
||||
opts.Logger.ErrorContext(ctx, errMsg.Error())
|
||||
return errMsg
|
||||
}
|
||||
fmt.Fprintln(rootCmd.Out(), string(output))
|
||||
fmt.Fprintln(opts.IOStreams.Out, string(output))
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -12,16 +12,38 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cmd
|
||||
package invoke
|
||||
|
||||
import (
|
||||
"context"
|
||||
"bytes"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/cmd/internal"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/sqlite"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerysql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/sqlite/sqlitesql"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func invokeCommand(args []string) (string, error) {
|
||||
parentCmd := &cobra.Command{Use: "toolbox"}
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
opts := internal.NewToolboxOptions(internal.WithIOStreams(buf, buf))
|
||||
internal.PersistentFlags(parentCmd, opts)
|
||||
|
||||
cmd := NewCommand(opts)
|
||||
parentCmd.AddCommand(cmd)
|
||||
parentCmd.SetArgs(args)
|
||||
|
||||
err := parentCmd.Execute()
|
||||
return buf.String(), err
|
||||
}
|
||||
|
||||
func TestInvokeTool(t *testing.T) {
|
||||
// Create a temporary tools file
|
||||
tmpDir := t.TempDir()
|
||||
@@ -86,7 +108,7 @@ tools:
|
||||
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
_, got, err := invokeCommandWithContext(context.Background(), tc.args)
|
||||
got, err := invokeCommand(tc.args)
|
||||
if (err != nil) != tc.wantErr {
|
||||
t.Fatalf("got error %v, wantErr %v", err, tc.wantErr)
|
||||
}
|
||||
@@ -121,7 +143,7 @@ tools:
|
||||
}
|
||||
|
||||
args := []string{"invoke", "bq-tool", "--tools-file", toolsFilePath}
|
||||
_, _, err := invokeCommandWithContext(context.Background(), args)
|
||||
_, err := invokeCommand(args)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for tool requiring client auth, but got nil")
|
||||
}
|
||||
251
cmd/internal/options.go
Normal file
251
cmd/internal/options.go
Normal file
@@ -0,0 +1,251 @@
|
||||
// Copyright 2026 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/log"
|
||||
"github.com/googleapis/genai-toolbox/internal/prebuiltconfigs"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/telemetry"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
)
|
||||
|
||||
type IOStreams struct {
|
||||
In io.Reader
|
||||
Out io.Writer
|
||||
ErrOut io.Writer
|
||||
}
|
||||
|
||||
// ToolboxOptions holds dependencies shared by all commands.
|
||||
type ToolboxOptions struct {
|
||||
IOStreams IOStreams
|
||||
Logger log.Logger
|
||||
Cfg server.ServerConfig
|
||||
ToolsFile string
|
||||
ToolsFiles []string
|
||||
ToolsFolder string
|
||||
PrebuiltConfigs []string
|
||||
}
|
||||
|
||||
// Option defines a function that modifies the ToolboxOptions struct.
|
||||
type Option func(*ToolboxOptions)
|
||||
|
||||
// NewToolboxOptions creates a new instance with defaults, then applies any
|
||||
// provided options.
|
||||
func NewToolboxOptions(opts ...Option) *ToolboxOptions {
|
||||
o := &ToolboxOptions{
|
||||
IOStreams: IOStreams{
|
||||
In: os.Stdin,
|
||||
Out: os.Stdout,
|
||||
ErrOut: os.Stderr,
|
||||
},
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(o)
|
||||
}
|
||||
return o
|
||||
}
|
||||
|
||||
// Apply allows you to update an EXISTING ToolboxOptions instance.
|
||||
// This is useful for "late binding".
|
||||
func (o *ToolboxOptions) Apply(opts ...Option) {
|
||||
for _, opt := range opts {
|
||||
opt(o)
|
||||
}
|
||||
}
|
||||
|
||||
// WithIOStreams updates the IO streams.
|
||||
func WithIOStreams(out, err io.Writer) Option {
|
||||
return func(o *ToolboxOptions) {
|
||||
o.IOStreams.Out = out
|
||||
o.IOStreams.ErrOut = err
|
||||
}
|
||||
}
|
||||
|
||||
// Setup create logger and telemetry instrumentations.
|
||||
func (opts *ToolboxOptions) Setup(ctx context.Context) (context.Context, func(context.Context) error, error) {
|
||||
// If stdio, set logger's out stream (usually DEBUG and INFO logs) to
|
||||
// errStream
|
||||
loggerOut := opts.IOStreams.Out
|
||||
if opts.Cfg.Stdio {
|
||||
loggerOut = opts.IOStreams.ErrOut
|
||||
}
|
||||
|
||||
// Handle logger separately from config
|
||||
logger, err := log.NewLogger(opts.Cfg.LoggingFormat.String(), opts.Cfg.LogLevel.String(), loggerOut, opts.IOStreams.ErrOut)
|
||||
if err != nil {
|
||||
return ctx, nil, fmt.Errorf("unable to initialize logger: %w", err)
|
||||
}
|
||||
|
||||
ctx = util.WithLogger(ctx, logger)
|
||||
opts.Logger = logger
|
||||
|
||||
// Set up OpenTelemetry
|
||||
otelShutdown, err := telemetry.SetupOTel(ctx, opts.Cfg.Version, opts.Cfg.TelemetryOTLP, opts.Cfg.TelemetryGCP, opts.Cfg.TelemetryServiceName)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("error setting up OpenTelemetry: %w", err)
|
||||
logger.ErrorContext(ctx, errMsg.Error())
|
||||
return ctx, nil, errMsg
|
||||
}
|
||||
|
||||
shutdownFunc := func(ctx context.Context) error {
|
||||
err := otelShutdown(ctx)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("error shutting down OpenTelemetry: %w", err)
|
||||
logger.ErrorContext(ctx, errMsg.Error())
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
instrumentation, err := telemetry.CreateTelemetryInstrumentation(opts.Cfg.Version)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("unable to create telemetry instrumentation: %w", err)
|
||||
logger.ErrorContext(ctx, errMsg.Error())
|
||||
return ctx, shutdownFunc, errMsg
|
||||
}
|
||||
|
||||
ctx = util.WithInstrumentation(ctx, instrumentation)
|
||||
|
||||
return ctx, shutdownFunc, nil
|
||||
}
|
||||
|
||||
// LoadConfig checks and merge files that should be loaded into the server
|
||||
func (opts *ToolboxOptions) LoadConfig(ctx context.Context) (bool, error) {
|
||||
// Determine if Custom Files should be loaded
|
||||
// Check for explicit custom flags
|
||||
isCustomConfigured := opts.ToolsFile != "" || len(opts.ToolsFiles) > 0 || opts.ToolsFolder != ""
|
||||
|
||||
// Determine if default 'tools.yaml' should be used (No prebuilt AND No custom flags)
|
||||
useDefaultToolsFile := len(opts.PrebuiltConfigs) == 0 && !isCustomConfigured
|
||||
|
||||
if useDefaultToolsFile {
|
||||
opts.ToolsFile = "tools.yaml"
|
||||
isCustomConfigured = true
|
||||
}
|
||||
|
||||
logger, err := util.LoggerFromContext(ctx)
|
||||
if err != nil {
|
||||
return isCustomConfigured, err
|
||||
}
|
||||
|
||||
var allToolsFiles []ToolsFile
|
||||
|
||||
// Load Prebuilt Configuration
|
||||
|
||||
if len(opts.PrebuiltConfigs) > 0 {
|
||||
slices.Sort(opts.PrebuiltConfigs)
|
||||
sourcesList := strings.Join(opts.PrebuiltConfigs, ", ")
|
||||
logMsg := fmt.Sprintf("Using prebuilt tool configurations for: %s", sourcesList)
|
||||
logger.InfoContext(ctx, logMsg)
|
||||
|
||||
for _, configName := range opts.PrebuiltConfigs {
|
||||
buf, err := prebuiltconfigs.Get(configName)
|
||||
if err != nil {
|
||||
logger.ErrorContext(ctx, err.Error())
|
||||
return isCustomConfigured, err
|
||||
}
|
||||
|
||||
// Parse into ToolsFile struct
|
||||
parsed, err := parseToolsFile(ctx, buf)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("unable to parse prebuilt tool configuration for '%s': %w", configName, err)
|
||||
logger.ErrorContext(ctx, errMsg.Error())
|
||||
return isCustomConfigured, errMsg
|
||||
}
|
||||
allToolsFiles = append(allToolsFiles, parsed)
|
||||
}
|
||||
}
|
||||
|
||||
// Load Custom Configurations
|
||||
if isCustomConfigured {
|
||||
// Enforce exclusivity among custom flags (tools-file vs tools-files vs tools-folder)
|
||||
if (opts.ToolsFile != "" && len(opts.ToolsFiles) > 0) ||
|
||||
(opts.ToolsFile != "" && opts.ToolsFolder != "") ||
|
||||
(len(opts.ToolsFiles) > 0 && opts.ToolsFolder != "") {
|
||||
errMsg := fmt.Errorf("--tools-file, --tools-files, and --tools-folder flags cannot be used simultaneously")
|
||||
logger.ErrorContext(ctx, errMsg.Error())
|
||||
return isCustomConfigured, errMsg
|
||||
}
|
||||
|
||||
var customTools ToolsFile
|
||||
var err error
|
||||
|
||||
if len(opts.ToolsFiles) > 0 {
|
||||
// Use tools-files
|
||||
logger.InfoContext(ctx, fmt.Sprintf("Loading and merging %d tool configuration files", len(opts.ToolsFiles)))
|
||||
customTools, err = LoadAndMergeToolsFiles(ctx, opts.ToolsFiles)
|
||||
} else if opts.ToolsFolder != "" {
|
||||
// Use tools-folder
|
||||
logger.InfoContext(ctx, fmt.Sprintf("Loading and merging all YAML files from directory: %s", opts.ToolsFolder))
|
||||
customTools, err = LoadAndMergeToolsFolder(ctx, opts.ToolsFolder)
|
||||
} else {
|
||||
// Use single file (tools-file or default `tools.yaml`)
|
||||
buf, readFileErr := os.ReadFile(opts.ToolsFile)
|
||||
if readFileErr != nil {
|
||||
errMsg := fmt.Errorf("unable to read tool file at %q: %w", opts.ToolsFile, readFileErr)
|
||||
logger.ErrorContext(ctx, errMsg.Error())
|
||||
return isCustomConfigured, errMsg
|
||||
}
|
||||
customTools, err = parseToolsFile(ctx, buf)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("unable to parse tool file at %q: %w", opts.ToolsFile, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
logger.ErrorContext(ctx, err.Error())
|
||||
return isCustomConfigured, err
|
||||
}
|
||||
allToolsFiles = append(allToolsFiles, customTools)
|
||||
}
|
||||
|
||||
// Modify version string based on loaded configurations
|
||||
if len(opts.PrebuiltConfigs) > 0 {
|
||||
tag := "prebuilt"
|
||||
if isCustomConfigured {
|
||||
tag = "custom"
|
||||
}
|
||||
// prebuiltConfigs is already sorted above
|
||||
for _, configName := range opts.PrebuiltConfigs {
|
||||
opts.Cfg.Version += fmt.Sprintf("+%s.%s", tag, configName)
|
||||
}
|
||||
}
|
||||
|
||||
// Merge Everything
|
||||
// This will error if custom tools collide with prebuilt tools
|
||||
finalToolsFile, err := mergeToolsFiles(allToolsFiles...)
|
||||
if err != nil {
|
||||
logger.ErrorContext(ctx, err.Error())
|
||||
return isCustomConfigured, err
|
||||
}
|
||||
|
||||
opts.Cfg.SourceConfigs = finalToolsFile.Sources
|
||||
opts.Cfg.AuthServiceConfigs = finalToolsFile.AuthServices
|
||||
opts.Cfg.EmbeddingModelConfigs = finalToolsFile.EmbeddingModels
|
||||
opts.Cfg.ToolConfigs = finalToolsFile.Tools
|
||||
opts.Cfg.ToolsetConfigs = finalToolsFile.Toolsets
|
||||
opts.Cfg.PromptConfigs = finalToolsFile.Prompts
|
||||
|
||||
return isCustomConfigured, nil
|
||||
}
|
||||
@@ -12,57 +12,38 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cmd
|
||||
package internal
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func TestCommandOptions(t *testing.T) {
|
||||
func TestToolboxOptions(t *testing.T) {
|
||||
w := io.Discard
|
||||
tcs := []struct {
|
||||
desc string
|
||||
isValid func(*Command) error
|
||||
isValid func(*ToolboxOptions) error
|
||||
option Option
|
||||
}{
|
||||
{
|
||||
desc: "with logger",
|
||||
isValid: func(c *Command) error {
|
||||
if c.outStream != w || c.errStream != w {
|
||||
isValid: func(o *ToolboxOptions) error {
|
||||
if o.IOStreams.Out != w || o.IOStreams.ErrOut != w {
|
||||
return errors.New("loggers do not match")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
option: WithStreams(w, w),
|
||||
option: WithIOStreams(w, w),
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got, err := invokeProxyWithOption(tc.option)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got := NewToolboxOptions(tc.option)
|
||||
if err := tc.isValid(got); err != nil {
|
||||
t.Errorf("option did not initialize command correctly: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func invokeProxyWithOption(o Option) (*Command, error) {
|
||||
c := NewCommand(o)
|
||||
// Keep the test output quiet
|
||||
c.SilenceUsage = true
|
||||
c.SilenceErrors = true
|
||||
// Disable execute behavior
|
||||
c.RunE = func(*cobra.Command, []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := c.Execute()
|
||||
return c, err
|
||||
}
|
||||
46
cmd/internal/persistent_flags.go
Normal file
46
cmd/internal/persistent_flags.go
Normal file
@@ -0,0 +1,46 @@
|
||||
// Copyright 2026 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/prebuiltconfigs"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// PersistentFlags sets up flags that are available for all commands and
|
||||
// subcommands
|
||||
// It is also used to set up persistent flags during subcommand unit tests
|
||||
func PersistentFlags(parentCmd *cobra.Command, opts *ToolboxOptions) {
|
||||
persistentFlags := parentCmd.PersistentFlags()
|
||||
|
||||
persistentFlags.StringVar(&opts.ToolsFile, "tools-file", "", "File path specifying the tool configuration. Cannot be used with --tools-files, or --tools-folder.")
|
||||
persistentFlags.StringSliceVar(&opts.ToolsFiles, "tools-files", []string{}, "Multiple file paths specifying tool configurations. Files will be merged. Cannot be used with --tools-file, or --tools-folder.")
|
||||
persistentFlags.StringVar(&opts.ToolsFolder, "tools-folder", "", "Directory path containing YAML tool configuration files. All .yaml and .yml files in the directory will be loaded and merged. Cannot be used with --tools-file, or --tools-files.")
|
||||
persistentFlags.Var(&opts.Cfg.LogLevel, "log-level", "Specify the minimum level logged. Allowed: 'DEBUG', 'INFO', 'WARN', 'ERROR'.")
|
||||
persistentFlags.Var(&opts.Cfg.LoggingFormat, "logging-format", "Specify logging format to use. Allowed: 'standard' or 'JSON'.")
|
||||
persistentFlags.BoolVar(&opts.Cfg.TelemetryGCP, "telemetry-gcp", false, "Enable exporting directly to Google Cloud Monitoring.")
|
||||
persistentFlags.StringVar(&opts.Cfg.TelemetryOTLP, "telemetry-otlp", "", "Enable exporting using OpenTelemetry Protocol (OTLP) to the specified endpoint (e.g. 'http://127.0.0.1:4318')")
|
||||
persistentFlags.StringVar(&opts.Cfg.TelemetryServiceName, "telemetry-service-name", "toolbox", "Sets the value of the service.name resource attribute for telemetry data.")
|
||||
// Fetch prebuilt tools sources to customize the help description
|
||||
prebuiltHelp := fmt.Sprintf(
|
||||
"Use a prebuilt tool configuration by source type. Allowed: '%s'. Can be specified multiple times.",
|
||||
strings.Join(prebuiltconfigs.GetPrebuiltSources(), "', '"),
|
||||
)
|
||||
persistentFlags.StringSliceVar(&opts.PrebuiltConfigs, "prebuilt", []string{}, prebuiltHelp)
|
||||
persistentFlags.StringSliceVar(&opts.Cfg.UserAgentMetadata, "user-agent-metadata", []string{}, "Appends additional metadata to the User-Agent.")
|
||||
}
|
||||
@@ -22,7 +22,7 @@ import (
|
||||
"path/filepath"
|
||||
"sort"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/log"
|
||||
"github.com/googleapis/genai-toolbox/cmd/internal"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/server/resources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
@@ -30,28 +30,9 @@ import (
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// RootCommand defines the interface for required by skills-generate subcommand.
|
||||
// This allows subcommands to access shared resources and functionality without
|
||||
// direct coupling to the root command's implementation.
|
||||
type RootCommand interface {
|
||||
// Config returns a copy of the current server configuration.
|
||||
Config() server.ServerConfig
|
||||
|
||||
// LoadConfig loads and merges the configuration from files, folders, and prebuilts.
|
||||
LoadConfig(ctx context.Context) error
|
||||
|
||||
// Setup initializes the runtime environment, including logging and telemetry.
|
||||
// It returns the updated context and a shutdown function to be called when finished.
|
||||
Setup(ctx context.Context) (context.Context, func(context.Context) error, error)
|
||||
|
||||
// Logger returns the logger instance.
|
||||
Logger() log.Logger
|
||||
}
|
||||
|
||||
// Command is the command for generating skills.
|
||||
type Command struct {
|
||||
// skillsCmd is the command for generating skills.
|
||||
type skillsCmd struct {
|
||||
*cobra.Command
|
||||
rootCmd RootCommand
|
||||
name string
|
||||
description string
|
||||
toolset string
|
||||
@@ -59,15 +40,13 @@ type Command struct {
|
||||
}
|
||||
|
||||
// NewCommand creates a new Command.
|
||||
func NewCommand(rootCmd RootCommand) *cobra.Command {
|
||||
cmd := &Command{
|
||||
rootCmd: rootCmd,
|
||||
}
|
||||
func NewCommand(opts *internal.ToolboxOptions) *cobra.Command {
|
||||
cmd := &skillsCmd{}
|
||||
cmd.Command = &cobra.Command{
|
||||
Use: "skills-generate",
|
||||
Short: "Generate skills from tool configurations",
|
||||
RunE: func(c *cobra.Command, args []string) error {
|
||||
return cmd.run(c)
|
||||
return run(cmd, opts)
|
||||
},
|
||||
}
|
||||
|
||||
@@ -81,11 +60,11 @@ func NewCommand(rootCmd RootCommand) *cobra.Command {
|
||||
return cmd.Command
|
||||
}
|
||||
|
||||
func (c *Command) run(cmd *cobra.Command) error {
|
||||
func run(cmd *skillsCmd, opts *internal.ToolboxOptions) error {
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
defer cancel()
|
||||
|
||||
ctx, shutdown, err := c.rootCmd.Setup(ctx)
|
||||
ctx, shutdown, err := opts.Setup(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -93,39 +72,37 @@ func (c *Command) run(cmd *cobra.Command) error {
|
||||
_ = shutdown(ctx)
|
||||
}()
|
||||
|
||||
logger := c.rootCmd.Logger()
|
||||
|
||||
// Load and merge tool configurations
|
||||
if err := c.rootCmd.LoadConfig(ctx); err != nil {
|
||||
_, err = opts.LoadConfig(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(c.outputDir, 0755); err != nil {
|
||||
if err := os.MkdirAll(cmd.outputDir, 0755); err != nil {
|
||||
errMsg := fmt.Errorf("error creating output directory: %w", err)
|
||||
logger.ErrorContext(ctx, errMsg.Error())
|
||||
opts.Logger.ErrorContext(ctx, errMsg.Error())
|
||||
return errMsg
|
||||
}
|
||||
|
||||
logger.InfoContext(ctx, fmt.Sprintf("Generating skill '%s'...", c.name))
|
||||
opts.Logger.InfoContext(ctx, fmt.Sprintf("Generating skill '%s'...", cmd.name))
|
||||
|
||||
// Initialize toolbox and collect tools
|
||||
allTools, err := c.collectTools(ctx)
|
||||
allTools, err := cmd.collectTools(ctx, opts)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("error collecting tools: %w", err)
|
||||
logger.ErrorContext(ctx, errMsg.Error())
|
||||
opts.Logger.ErrorContext(ctx, errMsg.Error())
|
||||
return errMsg
|
||||
}
|
||||
|
||||
if len(allTools) == 0 {
|
||||
logger.InfoContext(ctx, "No tools found to generate.")
|
||||
opts.Logger.InfoContext(ctx, "No tools found to generate.")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Generate the combined skill directory
|
||||
skillPath := filepath.Join(c.outputDir, c.name)
|
||||
skillPath := filepath.Join(cmd.outputDir, cmd.name)
|
||||
if err := os.MkdirAll(skillPath, 0755); err != nil {
|
||||
errMsg := fmt.Errorf("error creating skill directory: %w", err)
|
||||
logger.ErrorContext(ctx, errMsg.Error())
|
||||
opts.Logger.ErrorContext(ctx, errMsg.Error())
|
||||
return errMsg
|
||||
}
|
||||
|
||||
@@ -133,7 +110,7 @@ func (c *Command) run(cmd *cobra.Command) error {
|
||||
assetsPath := filepath.Join(skillPath, "assets")
|
||||
if err := os.MkdirAll(assetsPath, 0755); err != nil {
|
||||
errMsg := fmt.Errorf("error creating assets dir: %w", err)
|
||||
logger.ErrorContext(ctx, errMsg.Error())
|
||||
opts.Logger.ErrorContext(ctx, errMsg.Error())
|
||||
return errMsg
|
||||
}
|
||||
|
||||
@@ -141,7 +118,7 @@ func (c *Command) run(cmd *cobra.Command) error {
|
||||
scriptsPath := filepath.Join(skillPath, "scripts")
|
||||
if err := os.MkdirAll(scriptsPath, 0755); err != nil {
|
||||
errMsg := fmt.Errorf("error creating scripts dir: %w", err)
|
||||
logger.ErrorContext(ctx, errMsg.Error())
|
||||
opts.Logger.ErrorContext(ctx, errMsg.Error())
|
||||
return errMsg
|
||||
}
|
||||
|
||||
@@ -154,10 +131,10 @@ func (c *Command) run(cmd *cobra.Command) error {
|
||||
|
||||
for _, toolName := range toolNames {
|
||||
// Generate YAML config in asset directory
|
||||
minimizedContent, err := generateToolConfigYAML(c.rootCmd.Config(), toolName)
|
||||
minimizedContent, err := generateToolConfigYAML(opts.Cfg, toolName)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("error generating filtered config for %s: %w", toolName, err)
|
||||
logger.ErrorContext(ctx, errMsg.Error())
|
||||
opts.Logger.ErrorContext(ctx, errMsg.Error())
|
||||
return errMsg
|
||||
}
|
||||
|
||||
@@ -166,7 +143,7 @@ func (c *Command) run(cmd *cobra.Command) error {
|
||||
destPath := filepath.Join(assetsPath, specificToolsFileName)
|
||||
if err := os.WriteFile(destPath, minimizedContent, 0644); err != nil {
|
||||
errMsg := fmt.Errorf("error writing filtered config for %s: %w", toolName, err)
|
||||
logger.ErrorContext(ctx, errMsg.Error())
|
||||
opts.Logger.ErrorContext(ctx, errMsg.Error())
|
||||
return errMsg
|
||||
}
|
||||
}
|
||||
@@ -175,40 +152,40 @@ func (c *Command) run(cmd *cobra.Command) error {
|
||||
scriptContent, err := generateScriptContent(toolName, specificToolsFileName)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("error generating script content for %s: %w", toolName, err)
|
||||
logger.ErrorContext(ctx, errMsg.Error())
|
||||
opts.Logger.ErrorContext(ctx, errMsg.Error())
|
||||
return errMsg
|
||||
}
|
||||
|
||||
scriptFilename := filepath.Join(scriptsPath, fmt.Sprintf("%s.js", toolName))
|
||||
if err := os.WriteFile(scriptFilename, []byte(scriptContent), 0755); err != nil {
|
||||
errMsg := fmt.Errorf("error writing script %s: %w", scriptFilename, err)
|
||||
logger.ErrorContext(ctx, errMsg.Error())
|
||||
opts.Logger.ErrorContext(ctx, errMsg.Error())
|
||||
return errMsg
|
||||
}
|
||||
}
|
||||
|
||||
// Generate SKILL.md
|
||||
skillContent, err := generateSkillMarkdown(c.name, c.description, allTools)
|
||||
skillContent, err := generateSkillMarkdown(cmd.name, cmd.description, allTools)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("error generating SKILL.md content: %w", err)
|
||||
logger.ErrorContext(ctx, errMsg.Error())
|
||||
opts.Logger.ErrorContext(ctx, errMsg.Error())
|
||||
return errMsg
|
||||
}
|
||||
skillMdPath := filepath.Join(skillPath, "SKILL.md")
|
||||
if err := os.WriteFile(skillMdPath, []byte(skillContent), 0644); err != nil {
|
||||
errMsg := fmt.Errorf("error writing SKILL.md: %w", err)
|
||||
logger.ErrorContext(ctx, errMsg.Error())
|
||||
opts.Logger.ErrorContext(ctx, errMsg.Error())
|
||||
return errMsg
|
||||
}
|
||||
|
||||
logger.InfoContext(ctx, fmt.Sprintf("Successfully generated skill '%s' with %d tools.", c.name, len(allTools)))
|
||||
opts.Logger.InfoContext(ctx, fmt.Sprintf("Successfully generated skill '%s' with %d tools.", cmd.name, len(allTools)))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Command) collectTools(ctx context.Context) (map[string]tools.Tool, error) {
|
||||
func (c *skillsCmd) collectTools(ctx context.Context, opts *internal.ToolboxOptions) (map[string]tools.Tool, error) {
|
||||
// Initialize Resources
|
||||
sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := server.InitializeConfigs(ctx, c.rootCmd.Config())
|
||||
sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := server.InitializeConfigs(ctx, opts.Cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize resources: %w", err)
|
||||
}
|
||||
@@ -12,17 +12,36 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cmd
|
||||
package skills
|
||||
|
||||
import (
|
||||
"context"
|
||||
"bytes"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/cmd/internal"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/sqlite"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/sqlite/sqlitesql"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func invokeCommand(args []string) (string, error) {
|
||||
parentCmd := &cobra.Command{Use: "toolbox"}
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
opts := internal.NewToolboxOptions(internal.WithIOStreams(buf, buf))
|
||||
internal.PersistentFlags(parentCmd, opts)
|
||||
|
||||
cmd := NewCommand(opts)
|
||||
parentCmd.AddCommand(cmd)
|
||||
parentCmd.SetArgs(args)
|
||||
|
||||
err := parentCmd.Execute()
|
||||
return buf.String(), err
|
||||
}
|
||||
|
||||
func TestGenerateSkill(t *testing.T) {
|
||||
// Create a temporary directory for tests
|
||||
tmpDir := t.TempDir()
|
||||
@@ -55,10 +74,7 @@ tools:
|
||||
"--description", "hello tool",
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
_, got, err := invokeCommandWithContext(ctx, args)
|
||||
got, err := invokeCommand(args)
|
||||
if err != nil {
|
||||
t.Fatalf("command failed: %v\nOutput: %s", err, got)
|
||||
}
|
||||
@@ -136,7 +152,7 @@ func TestGenerateSkill_NoConfig(t *testing.T) {
|
||||
"--description", "test",
|
||||
}
|
||||
|
||||
_, _, err := invokeCommandWithContext(context.Background(), args)
|
||||
_, err := invokeCommand(args)
|
||||
if err == nil {
|
||||
t.Fatal("expected command to fail when no configuration is provided and tools.yaml is missing")
|
||||
}
|
||||
@@ -170,7 +186,7 @@ func TestGenerateSkill_MissingArguments(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, got, err := invokeCommandWithContext(context.Background(), tt.args)
|
||||
got, err := invokeCommand(tt.args)
|
||||
if err == nil {
|
||||
t.Fatalf("expected command to fail due to missing arguments, but it succeeded\nOutput: %s", got)
|
||||
}
|
||||
349
cmd/internal/tools_file.go
Normal file
349
cmd/internal/tools_file.go
Normal file
@@ -0,0 +1,349 @@
|
||||
// Copyright 2026 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
)
|
||||
|
||||
type ToolsFile struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
AuthServices server.AuthServiceConfigs `yaml:"authServices"`
|
||||
EmbeddingModels server.EmbeddingModelConfigs `yaml:"embeddingModels"`
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
Toolsets server.ToolsetConfigs `yaml:"toolsets"`
|
||||
Prompts server.PromptConfigs `yaml:"prompts"`
|
||||
}
|
||||
|
||||
// parseEnv replaces environment variables ${ENV_NAME} with their values.
|
||||
// also support ${ENV_NAME:default_value}.
|
||||
func parseEnv(input string) (string, error) {
|
||||
re := regexp.MustCompile(`\$\{(\w+)(:([^}]*))?\}`)
|
||||
|
||||
var err error
|
||||
output := re.ReplaceAllStringFunc(input, func(match string) string {
|
||||
parts := re.FindStringSubmatch(match)
|
||||
|
||||
// extract the variable name
|
||||
variableName := parts[1]
|
||||
if value, found := os.LookupEnv(variableName); found {
|
||||
return value
|
||||
}
|
||||
if len(parts) >= 4 && parts[2] != "" {
|
||||
return parts[3]
|
||||
}
|
||||
err = fmt.Errorf("environment variable not found: %q", variableName)
|
||||
return ""
|
||||
})
|
||||
return output, err
|
||||
}
|
||||
|
||||
// parseToolsFile parses the provided yaml into appropriate configs.
|
||||
func parseToolsFile(ctx context.Context, raw []byte) (ToolsFile, error) {
|
||||
var toolsFile ToolsFile
|
||||
// Replace environment variables if found
|
||||
output, err := parseEnv(string(raw))
|
||||
if err != nil {
|
||||
return toolsFile, fmt.Errorf("error parsing environment variables: %s", err)
|
||||
}
|
||||
raw = []byte(output)
|
||||
|
||||
raw, err = convertToolsFile(raw)
|
||||
if err != nil {
|
||||
return toolsFile, fmt.Errorf("error converting tools file: %s", err)
|
||||
}
|
||||
|
||||
// Parse contents
|
||||
toolsFile.Sources, toolsFile.AuthServices, toolsFile.EmbeddingModels, toolsFile.Tools, toolsFile.Toolsets, toolsFile.Prompts, err = server.UnmarshalResourceConfig(ctx, raw)
|
||||
if err != nil {
|
||||
return toolsFile, err
|
||||
}
|
||||
return toolsFile, nil
|
||||
}
|
||||
|
||||
func convertToolsFile(raw []byte) ([]byte, error) {
|
||||
var input yaml.MapSlice
|
||||
decoder := yaml.NewDecoder(bytes.NewReader(raw), yaml.UseOrderedMap())
|
||||
|
||||
// convert to tools file v2
|
||||
var buf bytes.Buffer
|
||||
encoder := yaml.NewEncoder(&buf)
|
||||
|
||||
v1keys := []string{"sources", "authSources", "authServices", "embeddingModels", "tools", "toolsets", "prompts"}
|
||||
for {
|
||||
if err := decoder.Decode(&input); err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
for _, item := range input {
|
||||
key, ok := item.Key.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected non-string key in input: %v", item.Key)
|
||||
}
|
||||
// check if the key is config file v1's key
|
||||
if slices.Contains(v1keys, key) {
|
||||
// check if value conversion to yaml.MapSlice successfully
|
||||
// fields such as "tools" in toolsets might pass the first check but
|
||||
// fail to convert to MapSlice
|
||||
if slice, ok := item.Value.(yaml.MapSlice); ok {
|
||||
// Deprecated: convert authSources to authServices
|
||||
if key == "authSources" {
|
||||
key = "authServices"
|
||||
}
|
||||
transformed, err := transformDocs(key, slice)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// encode per-doc
|
||||
for _, doc := range transformed {
|
||||
if err := encoder.Encode(doc); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// invalid input will be ignored
|
||||
// we don't want to throw error here since the config could
|
||||
// be valid but with a different order such as:
|
||||
// ---
|
||||
// tools:
|
||||
// - tool_a
|
||||
// kind: toolsets
|
||||
// ---
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
// this doc is already v2, encode to buf
|
||||
if err := encoder.Encode(input); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// transformDocs transforms the configuration file from v1 format to v2
|
||||
// yaml.MapSlice will preserve the order in a map
|
||||
func transformDocs(kind string, input yaml.MapSlice) ([]yaml.MapSlice, error) {
|
||||
var transformed []yaml.MapSlice
|
||||
for _, entry := range input {
|
||||
entryName, ok := entry.Key.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected non-string key for entry in '%s': %v", kind, entry.Key)
|
||||
}
|
||||
entryBody := ProcessValue(entry.Value, kind == "toolsets")
|
||||
|
||||
currentTransformed := yaml.MapSlice{
|
||||
{Key: "kind", Value: kind},
|
||||
{Key: "name", Value: entryName},
|
||||
}
|
||||
|
||||
// Merge the transformed body into our result
|
||||
if bodySlice, ok := entryBody.(yaml.MapSlice); ok {
|
||||
currentTransformed = append(currentTransformed, bodySlice...)
|
||||
} else {
|
||||
return nil, fmt.Errorf("unable to convert entryBody to MapSlice")
|
||||
}
|
||||
transformed = append(transformed, currentTransformed)
|
||||
}
|
||||
return transformed, nil
|
||||
}
|
||||
|
||||
// ProcessValue recursively looks for MapSlices to rename 'kind' -> 'type'
|
||||
func ProcessValue(v any, isToolset bool) any {
|
||||
switch val := v.(type) {
|
||||
case yaml.MapSlice:
|
||||
// creating a new MapSlice is safer for recursive transformation
|
||||
newVal := make(yaml.MapSlice, len(val))
|
||||
for i, item := range val {
|
||||
// Perform renaming
|
||||
if item.Key == "kind" {
|
||||
item.Key = "type"
|
||||
}
|
||||
// Recursive call for nested values (e.g., nested objects or lists)
|
||||
item.Value = ProcessValue(item.Value, false)
|
||||
newVal[i] = item
|
||||
}
|
||||
return newVal
|
||||
case []any:
|
||||
// Process lists: If it's a toolset top-level list, wrap it.
|
||||
if isToolset {
|
||||
return yaml.MapSlice{{Key: "tools", Value: val}}
|
||||
}
|
||||
// Otherwise, recurse into list items (to catch nested objects)
|
||||
newVal := make([]any, len(val))
|
||||
for i := range val {
|
||||
newVal[i] = ProcessValue(val[i], false)
|
||||
}
|
||||
return newVal
|
||||
default:
|
||||
return val
|
||||
}
|
||||
}
|
||||
|
||||
// mergeToolsFiles merges multiple ToolsFile structs into one.
|
||||
// Detects and raises errors for resource conflicts in sources, authServices, tools, and toolsets.
|
||||
// All resource names (sources, authServices, tools, toolsets) must be unique across all files.
|
||||
func mergeToolsFiles(files ...ToolsFile) (ToolsFile, error) {
|
||||
merged := ToolsFile{
|
||||
Sources: make(server.SourceConfigs),
|
||||
AuthServices: make(server.AuthServiceConfigs),
|
||||
EmbeddingModels: make(server.EmbeddingModelConfigs),
|
||||
Tools: make(server.ToolConfigs),
|
||||
Toolsets: make(server.ToolsetConfigs),
|
||||
Prompts: make(server.PromptConfigs),
|
||||
}
|
||||
|
||||
var conflicts []string
|
||||
|
||||
for fileIndex, file := range files {
|
||||
// Check for conflicts and merge sources
|
||||
for name, source := range file.Sources {
|
||||
if _, exists := merged.Sources[name]; exists {
|
||||
conflicts = append(conflicts, fmt.Sprintf("source '%s' (file #%d)", name, fileIndex+1))
|
||||
} else {
|
||||
merged.Sources[name] = source
|
||||
}
|
||||
}
|
||||
|
||||
// Check for conflicts and merge authServices
|
||||
for name, authService := range file.AuthServices {
|
||||
if _, exists := merged.AuthServices[name]; exists {
|
||||
conflicts = append(conflicts, fmt.Sprintf("authService '%s' (file #%d)", name, fileIndex+1))
|
||||
} else {
|
||||
merged.AuthServices[name] = authService
|
||||
}
|
||||
}
|
||||
|
||||
// Check for conflicts and merge embeddingModels
|
||||
for name, em := range file.EmbeddingModels {
|
||||
if _, exists := merged.EmbeddingModels[name]; exists {
|
||||
conflicts = append(conflicts, fmt.Sprintf("embedding model '%s' (file #%d)", name, fileIndex+1))
|
||||
} else {
|
||||
merged.EmbeddingModels[name] = em
|
||||
}
|
||||
}
|
||||
|
||||
// Check for conflicts and merge tools
|
||||
for name, tool := range file.Tools {
|
||||
if _, exists := merged.Tools[name]; exists {
|
||||
conflicts = append(conflicts, fmt.Sprintf("tool '%s' (file #%d)", name, fileIndex+1))
|
||||
} else {
|
||||
merged.Tools[name] = tool
|
||||
}
|
||||
}
|
||||
|
||||
// Check for conflicts and merge toolsets
|
||||
for name, toolset := range file.Toolsets {
|
||||
if _, exists := merged.Toolsets[name]; exists {
|
||||
conflicts = append(conflicts, fmt.Sprintf("toolset '%s' (file #%d)", name, fileIndex+1))
|
||||
} else {
|
||||
merged.Toolsets[name] = toolset
|
||||
}
|
||||
}
|
||||
|
||||
// Check for conflicts and merge prompts
|
||||
for name, prompt := range file.Prompts {
|
||||
if _, exists := merged.Prompts[name]; exists {
|
||||
conflicts = append(conflicts, fmt.Sprintf("prompt '%s' (file #%d)", name, fileIndex+1))
|
||||
} else {
|
||||
merged.Prompts[name] = prompt
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If conflicts were detected, return an error
|
||||
if len(conflicts) > 0 {
|
||||
return ToolsFile{}, fmt.Errorf("resource conflicts detected:\n - %s\n\nPlease ensure each source, authService, tool, toolset and prompt has a unique name across all files", strings.Join(conflicts, "\n - "))
|
||||
}
|
||||
|
||||
return merged, nil
|
||||
}
|
||||
|
||||
// LoadAndMergeToolsFiles loads multiple YAML files and merges them
|
||||
func LoadAndMergeToolsFiles(ctx context.Context, filePaths []string) (ToolsFile, error) {
|
||||
var toolsFiles []ToolsFile
|
||||
|
||||
for _, filePath := range filePaths {
|
||||
buf, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return ToolsFile{}, fmt.Errorf("unable to read tool file at %q: %w", filePath, err)
|
||||
}
|
||||
|
||||
toolsFile, err := parseToolsFile(ctx, buf)
|
||||
if err != nil {
|
||||
return ToolsFile{}, fmt.Errorf("unable to parse tool file at %q: %w", filePath, err)
|
||||
}
|
||||
|
||||
toolsFiles = append(toolsFiles, toolsFile)
|
||||
}
|
||||
|
||||
mergedFile, err := mergeToolsFiles(toolsFiles...)
|
||||
if err != nil {
|
||||
return ToolsFile{}, fmt.Errorf("unable to merge tools files: %w", err)
|
||||
}
|
||||
|
||||
return mergedFile, nil
|
||||
}
|
||||
|
||||
// LoadAndMergeToolsFolder loads all YAML files from a directory and merges them
|
||||
func LoadAndMergeToolsFolder(ctx context.Context, folderPath string) (ToolsFile, error) {
|
||||
// Check if directory exists
|
||||
info, err := os.Stat(folderPath)
|
||||
if err != nil {
|
||||
return ToolsFile{}, fmt.Errorf("unable to access tools folder at %q: %w", folderPath, err)
|
||||
}
|
||||
if !info.IsDir() {
|
||||
return ToolsFile{}, fmt.Errorf("path %q is not a directory", folderPath)
|
||||
}
|
||||
|
||||
// Find all YAML files in the directory
|
||||
pattern := filepath.Join(folderPath, "*.yaml")
|
||||
yamlFiles, err := filepath.Glob(pattern)
|
||||
if err != nil {
|
||||
return ToolsFile{}, fmt.Errorf("error finding YAML files in %q: %w", folderPath, err)
|
||||
}
|
||||
|
||||
// Also find .yml files
|
||||
ymlPattern := filepath.Join(folderPath, "*.yml")
|
||||
ymlFiles, err := filepath.Glob(ymlPattern)
|
||||
if err != nil {
|
||||
return ToolsFile{}, fmt.Errorf("error finding YML files in %q: %w", folderPath, err)
|
||||
}
|
||||
|
||||
// Combine both file lists
|
||||
allFiles := append(yamlFiles, ymlFiles...)
|
||||
|
||||
if len(allFiles) == 0 {
|
||||
return ToolsFile{}, fmt.Errorf("no YAML files found in directory %q", folderPath)
|
||||
}
|
||||
|
||||
// Use existing LoadAndMergeToolsFiles function
|
||||
return LoadAndMergeToolsFiles(ctx, allFiles)
|
||||
}
|
||||
2141
cmd/internal/tools_file_test.go
Normal file
2141
cmd/internal/tools_file_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,30 +0,0 @@
|
||||
// Copyright 2024 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"io"
|
||||
)
|
||||
|
||||
// Option is a function that configures a Command.
|
||||
type Option func(*Command)
|
||||
|
||||
// WithStreams overrides the default writer.
|
||||
func WithStreams(out, err io.Writer) Option {
|
||||
return func(c *Command) {
|
||||
c.outStream = out
|
||||
c.errStream = err
|
||||
}
|
||||
}
|
||||
886
cmd/root.go
886
cmd/root.go
File diff suppressed because it is too large
Load Diff
1743
cmd/root_test.go
1743
cmd/root_test.go
File diff suppressed because it is too large
Load Diff
@@ -1 +1 @@
|
||||
0.26.0
|
||||
0.27.0
|
||||
|
||||
@@ -234,7 +234,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"version = \"0.26.0\" # x-release-please-version\n",
|
||||
"version = \"0.27.0\" # x-release-please-version\n",
|
||||
"! curl -O https://storage.googleapis.com/genai-toolbox/v{version}/linux/amd64/toolbox\n",
|
||||
"\n",
|
||||
"# Make the binary executable\n",
|
||||
|
||||
@@ -109,7 +109,7 @@ To install Toolbox as a binary on Linux (AMD64):
|
||||
|
||||
```sh
|
||||
# see releases page for other versions
|
||||
export VERSION=0.26.0
|
||||
export VERSION=0.27.0
|
||||
curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/linux/amd64/toolbox
|
||||
chmod +x toolbox
|
||||
```
|
||||
@@ -120,7 +120,7 @@ To install Toolbox as a binary on macOS (Apple Silicon):
|
||||
|
||||
```sh
|
||||
# see releases page for other versions
|
||||
export VERSION=0.26.0
|
||||
export VERSION=0.27.0
|
||||
curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/darwin/arm64/toolbox
|
||||
chmod +x toolbox
|
||||
```
|
||||
@@ -131,7 +131,7 @@ To install Toolbox as a binary on macOS (Intel):
|
||||
|
||||
```sh
|
||||
# see releases page for other versions
|
||||
export VERSION=0.26.0
|
||||
export VERSION=0.27.0
|
||||
curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/darwin/amd64/toolbox
|
||||
chmod +x toolbox
|
||||
```
|
||||
@@ -142,7 +142,7 @@ To install Toolbox as a binary on Windows (Command Prompt):
|
||||
|
||||
```cmd
|
||||
:: see releases page for other versions
|
||||
set VERSION=0.26.0
|
||||
set VERSION=0.27.0
|
||||
curl -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v%VERSION%/windows/amd64/toolbox.exe"
|
||||
```
|
||||
|
||||
@@ -152,7 +152,7 @@ To install Toolbox as a binary on Windows (PowerShell):
|
||||
|
||||
```powershell
|
||||
# see releases page for other versions
|
||||
$VERSION = "0.26.0"
|
||||
$VERSION = "0.27.0"
|
||||
curl.exe -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v$VERSION/windows/amd64/toolbox.exe"
|
||||
```
|
||||
|
||||
@@ -164,7 +164,7 @@ You can also install Toolbox as a container:
|
||||
|
||||
```sh
|
||||
# see releases page for other versions
|
||||
export VERSION=0.26.0
|
||||
export VERSION=0.27.0
|
||||
docker pull us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:$VERSION
|
||||
```
|
||||
|
||||
@@ -183,7 +183,7 @@ To install from source, ensure you have the latest version of
|
||||
[Go installed](https://go.dev/doc/install), and then run the following command:
|
||||
|
||||
```sh
|
||||
go install github.com/googleapis/genai-toolbox@v0.26.0
|
||||
go install github.com/googleapis/genai-toolbox@v0.27.0
|
||||
```
|
||||
|
||||
{{% /tab %}}
|
||||
|
||||
@@ -105,7 +105,7 @@ In this section, we will download Toolbox, configure our tools in a
|
||||
<!-- {x-release-please-start-version} -->
|
||||
```bash
|
||||
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/$OS/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/$OS/toolbox
|
||||
```
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ In this section, we will download Toolbox, configure our tools in a
|
||||
<!-- {x-release-please-start-version} -->
|
||||
```bash
|
||||
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/$OS/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/$OS/toolbox
|
||||
```
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
|
||||
@@ -100,19 +100,19 @@ After you install Looker in the MCP Store, resources and tools from the server a
|
||||
|
||||
{{< tabpane persist=header >}}
|
||||
{{< tab header="linux/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/linux/amd64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/linux/amd64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="darwin/arm64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/darwin/arm64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/darwin/arm64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="darwin/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/darwin/amd64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/darwin/amd64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="windows/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/windows/amd64/toolbox.exe
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/windows/amd64/toolbox.exe
|
||||
{{< /tab >}}
|
||||
{{< /tabpane >}}
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
@@ -45,19 +45,19 @@ instance:
|
||||
<!-- {x-release-please-start-version} -->
|
||||
{{< tabpane persist=header >}}
|
||||
{{< tab header="linux/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/linux/amd64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/linux/amd64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="darwin/arm64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/darwin/arm64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/darwin/arm64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="darwin/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/darwin/amd64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/darwin/amd64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="windows/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/windows/amd64/toolbox.exe
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/windows/amd64/toolbox.exe
|
||||
{{< /tab >}}
|
||||
{{< /tabpane >}}
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
@@ -43,19 +43,19 @@ expose your developer assistant tools to a MySQL instance:
|
||||
<!-- {x-release-please-start-version} -->
|
||||
{{< tabpane persist=header >}}
|
||||
{{< tab header="linux/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/linux/amd64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/linux/amd64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="darwin/arm64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/darwin/arm64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/darwin/arm64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="darwin/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/darwin/amd64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/darwin/amd64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="windows/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/windows/amd64/toolbox.exe
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/windows/amd64/toolbox.exe
|
||||
{{< /tab >}}
|
||||
{{< /tabpane >}}
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
@@ -44,19 +44,19 @@ expose your developer assistant tools to a Neo4j instance:
|
||||
<!-- {x-release-please-start-version} -->
|
||||
{{< tabpane persist=header >}}
|
||||
{{< tab header="linux/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/linux/amd64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/linux/amd64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="darwin/arm64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/darwin/arm64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/darwin/arm64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="darwin/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/darwin/amd64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/darwin/amd64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="windows/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/windows/amd64/toolbox.exe
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/windows/amd64/toolbox.exe
|
||||
{{< /tab >}}
|
||||
{{< /tabpane >}}
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
335
docs/en/how-to/connect-ide/oracle_mcp.md
Normal file
335
docs/en/how-to/connect-ide/oracle_mcp.md
Normal file
@@ -0,0 +1,335 @@
|
||||
---
|
||||
title: "Oracle using MCP"
|
||||
type: docs
|
||||
weight: 2
|
||||
description: >
|
||||
Connect your IDE to Oracle using Toolbox.
|
||||
---
|
||||
|
||||
[Model Context Protocol (MCP)](https://modelcontextprotocol.io/introduction) is
|
||||
an open protocol for connecting Large Language Models (LLMs) to data sources
|
||||
like Oracle. This guide covers how to use [MCP Toolbox for Databases][toolbox]
|
||||
to expose your developer assistant tools to an Oracle instance:
|
||||
|
||||
* [Cursor][cursor]
|
||||
* [Windsurf][windsurf] (Codium)
|
||||
* [Visual Studio Code][vscode] (Copilot)
|
||||
* [Cline][cline] (VS Code extension)
|
||||
* [Claude desktop][claudedesktop]
|
||||
* [Claude code][claudecode]
|
||||
* [Gemini CLI][geminicli]
|
||||
* [Gemini Code Assist][geminicodeassist]
|
||||
|
||||
[toolbox]: https://github.com/googleapis/genai-toolbox
|
||||
[cursor]: #configure-your-mcp-client
|
||||
[windsurf]: #configure-your-mcp-client
|
||||
[vscode]: #configure-your-mcp-client
|
||||
[cline]: #configure-your-mcp-client
|
||||
[claudedesktop]: #configure-your-mcp-client
|
||||
[claudecode]: #configure-your-mcp-client
|
||||
[geminicli]: #configure-your-mcp-client
|
||||
[geminicodeassist]: #configure-your-mcp-client
|
||||
|
||||
## Set up the database
|
||||
|
||||
1. Create or select an Oracle instance.
|
||||
|
||||
1. Create or reuse a database user and have the username and password ready.
|
||||
|
||||
## Install MCP Toolbox
|
||||
|
||||
1. Download the latest version of Toolbox as a binary. Select the [correct
|
||||
binary](https://github.com/googleapis/genai-toolbox/releases) corresponding
|
||||
to your OS and CPU architecture. You are required to use Toolbox version
|
||||
V0.26.0+:
|
||||
|
||||
<!-- {x-release-please-start-version} -->
|
||||
{{< tabpane persist=header >}}
|
||||
{{< tab header="linux/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/linux/amd64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="darwin/arm64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/darwin/arm64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="darwin/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/darwin/amd64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="windows/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/windows/amd64/toolbox.exe
|
||||
{{< /tab >}}
|
||||
{{< /tabpane >}}
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
1. Make the binary executable:
|
||||
|
||||
```bash
|
||||
chmod +x toolbox
|
||||
```
|
||||
|
||||
1. Verify the installation:
|
||||
|
||||
```bash
|
||||
./toolbox --version
|
||||
```
|
||||
|
||||
## Configure your MCP Client
|
||||
|
||||
{{< tabpane text=true >}}
|
||||
{{% tab header="Claude code" lang="en" %}}
|
||||
|
||||
1. Install [Claude
|
||||
Code](https://docs.anthropic.com/en/docs/agents-and-tools/claude-code/overview).
|
||||
1. Create a `.mcp.json` file in your project root if it doesn't exist.
|
||||
1. Add the following configuration, replace the environment variables with your
|
||||
values, and save:
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"oracle": {
|
||||
"command": "./PATH/TO/toolbox",
|
||||
"args": ["--prebuilt","oracledb","--stdio"],
|
||||
"env": {
|
||||
"ORACLE_HOST": "",
|
||||
"ORACLE_PORT": "1521",
|
||||
"ORACLE_SERVICE": "",
|
||||
"ORACLE_USER": "",
|
||||
"ORACLE_PASSWORD": "",
|
||||
"ORACLE_WALLET_LOCATION": "",
|
||||
"ORACLE_USE_OCI": ""
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
1. Restart Claude code to apply the new configuration.
|
||||
{{% /tab %}}
|
||||
|
||||
{{% tab header="Claude desktop" lang="en" %}}
|
||||
|
||||
1. Open Claude desktop and navigate to Settings.
|
||||
1. Under the Developer tab, tap Edit Config to open the configuration file.
|
||||
1. Add the following configuration, replace the environment variables with your
|
||||
values, and save:
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"oracle": {
|
||||
"command": "./PATH/TO/toolbox",
|
||||
"args": ["--prebuilt","oracledb","--stdio"],
|
||||
"env": {
|
||||
"ORACLE_HOST": "",
|
||||
"ORACLE_PORT": "1521",
|
||||
"ORACLE_SERVICE": "",
|
||||
"ORACLE_USER": "",
|
||||
"ORACLE_PASSWORD": "",
|
||||
"ORACLE_WALLET_LOCATION": "",
|
||||
"ORACLE_USE_OCI": ""
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
1. Restart Claude desktop.
|
||||
1. From the new chat screen, you should see a hammer (MCP) icon appear with the
|
||||
new MCP server available.
|
||||
{{% /tab %}}
|
||||
|
||||
{{% tab header="Cline" lang="en" %}}
|
||||
|
||||
1. Open the Cline extension in VS Code and tap
|
||||
the **MCP Servers** icon.
|
||||
1. Tap Configure MCP Servers to open the configuration file.
|
||||
1. Add the following configuration, replace the environment variables with your
|
||||
values, and save:
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"oracle": {
|
||||
"command": "./PATH/TO/toolbox",
|
||||
"args": ["--prebuilt","oracledb","--stdio"],
|
||||
"env": {
|
||||
"ORACLE_HOST": "",
|
||||
"ORACLE_PORT": "1521",
|
||||
"ORACLE_SERVICE": "",
|
||||
"ORACLE_USER": "",
|
||||
"ORACLE_PASSWORD": "",
|
||||
"ORACLE_WALLET_LOCATION": "",
|
||||
"ORACLE_USE_OCI": ""
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
1. You should see a green active status after the server is successfully
|
||||
connected.
|
||||
{{% /tab %}}
|
||||
|
||||
{{% tab header="Cursor" lang="en" %}}
|
||||
|
||||
1. Create a `.cursor` directory in your project root if it doesn't exist.
|
||||
1. Create a `.cursor/mcp.json` file if it doesn't exist and open it.
|
||||
1. Add the following configuration, replace the environment variables with your
|
||||
values, and save:
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"oracle": {
|
||||
"command": "./PATH/TO/toolbox",
|
||||
"args": ["--prebuilt","oracledb","--stdio"],
|
||||
"env": {
|
||||
"ORACLE_HOST": "",
|
||||
"ORACLE_PORT": "1521",
|
||||
"ORACLE_SERVICE": "",
|
||||
"ORACLE_USER": "",
|
||||
"ORACLE_PASSWORD": "",
|
||||
"ORACLE_WALLET_LOCATION": "",
|
||||
"ORACLE_USE_OCI": ""
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
1. Cursor and navigate to **Settings > Cursor
|
||||
Settings > MCP**. You should see a green active status after the server is
|
||||
successfully connected.
|
||||
{{% /tab %}}
|
||||
|
||||
{{% tab header="Visual Studio Code (Copilot)" lang="en" %}}
|
||||
|
||||
1. Open VS Code and
|
||||
create a `.vscode` directory in your project root if it doesn't exist.
|
||||
1. Create a `.vscode/mcp.json` file if it doesn't exist and open it.
|
||||
1. Add the following configuration, replace the environment variables with your
|
||||
values, and save:
|
||||
|
||||
```json
|
||||
{
|
||||
"servers": {
|
||||
"oracle": {
|
||||
"command": "./PATH/TO/toolbox",
|
||||
"args": ["--prebuilt","oracle","--stdio"],
|
||||
"env": {
|
||||
"ORACLE_HOST": "",
|
||||
"ORACLE_PORT": "1521",
|
||||
"ORACLE_SERVICE": "",
|
||||
"ORACLE_USER": "",
|
||||
"ORACLE_PASSWORD": ""
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
{{% /tab %}}
|
||||
|
||||
{{% tab header="Windsurf" lang="en" %}}
|
||||
|
||||
1. Open Windsurf and navigate to the
|
||||
Cascade assistant.
|
||||
1. Tap on the hammer (MCP) icon, then Configure to open the configuration file.
|
||||
1. Add the following configuration, replace the environment variables with your
|
||||
values, and save:
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"oracle": {
|
||||
"command": "./PATH/TO/toolbox",
|
||||
"args": ["--prebuilt","oracledb","--stdio"],
|
||||
"env": {
|
||||
"ORACLE_HOST": "",
|
||||
"ORACLE_PORT": "1521",
|
||||
"ORACLE_SERVICE": "",
|
||||
"ORACLE_USER": "",
|
||||
"ORACLE_PASSWORD": "",
|
||||
"ORACLE_WALLET": "",
|
||||
"ORACLE_WALLET_PASSWORD": ""
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
{{% /tab %}}
|
||||
|
||||
{{% tab header="Gemini CLI" lang="en" %}}
|
||||
|
||||
1. Install the Gemini CLI.
|
||||
1. In your working directory, create a folder named `.gemini`. Within it, create a `settings.json` file.
|
||||
1. Add the following configuration, replace the environment variables with your values, and then save:
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"oracle": {
|
||||
"command": "./PATH/TO/toolbox",
|
||||
"args": ["--prebuilt","oracledb","--stdio"],
|
||||
"env": {
|
||||
"ORACLE_HOST": "",
|
||||
"ORACLE_PORT": "1521",
|
||||
"ORACLE_SERVICE": "",
|
||||
"ORACLE_USER": "",
|
||||
"ORACLE_PASSWORD": ""
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
{{% /tab %}}
|
||||
|
||||
{{% tab header="Gemini Code Assist" lang="en" %}}
|
||||
|
||||
1. Install the Gemini Code Assist extension in Visual Studio Code.
|
||||
1. Enable Agent Mode in Gemini Code Assist chat.
|
||||
1. In your working directory, create a folder named `.gemini`. Within it, create a `settings.json` file.
|
||||
1. Add the following configuration, replace the environment variables with your values, and then save:
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"oracle": {
|
||||
"command": "./PATH/TO/toolbox",
|
||||
"args": ["--prebuilt","oracledb","--stdio"],
|
||||
"env": {
|
||||
"ORACLE_HOST": "",
|
||||
"ORACLE_PORT": "1521",
|
||||
"ORACLE_SERVICE": "",
|
||||
"ORACLE_USER": "",
|
||||
"ORACLE_PASSWORD": ""
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
{{% /tab %}}
|
||||
{{< /tabpane >}}
|
||||
|
||||
## Use Tools
|
||||
|
||||
Your AI tool is now connected to Oracle using MCP. Try asking your AI
|
||||
assistant to list tables, create a table, or define and execute other SQL
|
||||
statements.
|
||||
|
||||
The following tools are available to the LLM:
|
||||
|
||||
1. **list_tables**: lists tables and descriptions
|
||||
1. **execute_sql**: execute any SQL statement
|
||||
|
||||
{{< notice note >}}
|
||||
Prebuilt tools are pre-1.0, so expect some tool changes between versions. LLMs
|
||||
will adapt to the tools available, so this shouldn't affect most users.
|
||||
{{< /notice >}}
|
||||
@@ -56,19 +56,19 @@ Omni](https://cloud.google.com/alloydb/omni/docs/overview).
|
||||
<!-- {x-release-please-start-version} -->
|
||||
{{< tabpane persist=header >}}
|
||||
{{< tab header="linux/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/linux/amd64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/linux/amd64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="darwin/arm64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/darwin/arm64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/darwin/arm64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="darwin/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/darwin/amd64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/darwin/amd64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="windows/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/windows/amd64/toolbox.exe
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/windows/amd64/toolbox.exe
|
||||
{{< /tab >}}
|
||||
{{< /tabpane >}}
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
@@ -43,19 +43,19 @@ to expose your developer assistant tools to a SQLite instance:
|
||||
<!-- {x-release-please-start-version} -->
|
||||
{{< tabpane persist=header >}}
|
||||
{{< tab header="linux/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/linux/amd64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/linux/amd64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="darwin/arm64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/darwin/arm64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/darwin/arm64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="darwin/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/darwin/amd64/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/darwin/amd64/toolbox
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="windows/amd64" lang="bash" >}}
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/windows/amd64/toolbox.exe
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/windows/amd64/toolbox.exe
|
||||
{{< /tab >}}
|
||||
{{< /tabpane >}}
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
@@ -692,6 +692,34 @@ See [Usage Examples](../reference/cli.md#examples).
|
||||
* `execute_cypher`: Executes a Cypher query.
|
||||
* `get_schema`: Retrieves the schema of the Neo4j database.
|
||||
|
||||
|
||||
## Oracle
|
||||
|
||||
* `--prebuilt` value: `oracle`
|
||||
* **Environment Variables:**
|
||||
* `ORACLE_HOST`: The hostname or IP address of the Oracle server.
|
||||
* `ORACLE_PORT`: The port number for the Oracle server (Default: 1521).
|
||||
* `ORACLE_CONNECTION_STRING`: The
|
||||
* `ORACLE_USER`: The username for the Oracle DB instance.
|
||||
* `ORACLE_PASSWORD`: The password for the Oracle DB instance
|
||||
* `ORACLE_WALLET`: The path to Oracle DB Wallet file for the databases that support this authentication type
|
||||
* `ORACLE_USE_OCI`: true or false, The flag if the Oracle Database is deployed in cloud deployment. Default is false.
|
||||
* **Permissions:**
|
||||
* Database-level permissions (e.g., `SELECT`, `INSERT`) are required to execute queries.
|
||||
* **Tools:**
|
||||
* `execute_sql`: Executes a SQL query.
|
||||
* `list_tables`: Lists tables in the database.
|
||||
* `list_active_sessions`: Lists active database sessions.
|
||||
* `get_query_plan`: Gets the execution plan for a SQL statement.
|
||||
* `list_top_sql_by_resource`: Lists top SQL statements by resource usage.
|
||||
* `list_tablespace_usage`: Lists tablespace usage.
|
||||
* `list_invalid_objects`: Lists invalid objects.
|
||||
* `list_active_sessions`: Lists active database sessions.
|
||||
* `get_query_plan`: Gets the execution plan for a SQL statement.
|
||||
* `list_top_sql_by_resource`: Lists top SQL statements by resource usage.
|
||||
* `list_tablespace_usage`: Lists tablespace usage.
|
||||
* `list_invalid_objects`: Lists invalid objects.
|
||||
|
||||
## Google Cloud Healthcare API
|
||||
* `--prebuilt` value: `cloud-healthcare`
|
||||
* **Environment Variables:**
|
||||
|
||||
242
docs/en/resources/sources/cockroachdb.md
Normal file
242
docs/en/resources/sources/cockroachdb.md
Normal file
@@ -0,0 +1,242 @@
|
||||
---
|
||||
title: "CockroachDB"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
CockroachDB is a distributed SQL database built for cloud applications.
|
||||
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
[CockroachDB][crdb-docs] is a distributed SQL database designed for cloud-native applications. It provides strong consistency, horizontal scalability, and built-in resilience with automatic failover and recovery. CockroachDB uses the PostgreSQL wire protocol, making it compatible with many PostgreSQL tools and drivers while providing unique features like multi-region deployments and distributed transactions.
|
||||
|
||||
**Minimum Version:** CockroachDB v25.1 or later is recommended for full tool compatibility.
|
||||
|
||||
[crdb-docs]: https://www.cockroachlabs.com/docs/
|
||||
|
||||
## Available Tools
|
||||
|
||||
- [`cockroachdb-sql`](../tools/cockroachdb/cockroachdb-sql.md)
|
||||
Execute SQL queries as prepared statements in CockroachDB (alias for execute-sql).
|
||||
|
||||
- [`cockroachdb-execute-sql`](../tools/cockroachdb/cockroachdb-execute-sql.md)
|
||||
Run parameterized SQL statements in CockroachDB.
|
||||
|
||||
- [`cockroachdb-list-schemas`](../tools/cockroachdb/cockroachdb-list-schemas.md)
|
||||
List schemas in a CockroachDB database.
|
||||
|
||||
- [`cockroachdb-list-tables`](../tools/cockroachdb/cockroachdb-list-tables.md)
|
||||
List tables in a CockroachDB database.
|
||||
|
||||
## Requirements
|
||||
|
||||
### Database User
|
||||
|
||||
This source uses standard authentication. You will need to [create a CockroachDB user][crdb-users] to login to the database with. For CockroachDB Cloud deployments, SSL/TLS is required.
|
||||
|
||||
[crdb-users]: https://www.cockroachlabs.com/docs/stable/create-user.html
|
||||
|
||||
### SSL/TLS Configuration
|
||||
|
||||
CockroachDB Cloud clusters require SSL/TLS connections. Use the `queryParams` section to configure SSL settings:
|
||||
|
||||
- **For CockroachDB Cloud**: Use `sslmode: require` at minimum
|
||||
- **For self-hosted with certificates**: Use `sslmode: verify-full` with certificate paths
|
||||
- **For local development only**: Use `sslmode: disable` (not recommended for production)
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
sources:
|
||||
my_cockroachdb:
|
||||
type: cockroachdb
|
||||
host: your-cluster.cockroachlabs.cloud
|
||||
port: "26257"
|
||||
user: myuser
|
||||
password: mypassword
|
||||
database: defaultdb
|
||||
maxRetries: 5
|
||||
retryBaseDelay: 500ms
|
||||
queryParams:
|
||||
sslmode: require
|
||||
application_name: my-app
|
||||
|
||||
# MCP Security Settings (recommended for production)
|
||||
readOnlyMode: true # Read-only by default (MCP best practice)
|
||||
enableWriteMode: false # Set to true to allow write operations
|
||||
maxRowLimit: 1000 # Limit query results
|
||||
queryTimeoutSec: 30 # Prevent long-running queries
|
||||
enableTelemetry: true # Enable observability
|
||||
telemetryVerbose: false # Set true for detailed logs
|
||||
clusterID: "my-cluster" # Optional identifier
|
||||
|
||||
tools:
|
||||
list_expenses:
|
||||
type: cockroachdb-sql
|
||||
source: my_cockroachdb
|
||||
description: List all expenses
|
||||
statement: SELECT id, description, amount, category FROM expenses WHERE user_id = $1
|
||||
parameters:
|
||||
- name: user_id
|
||||
type: string
|
||||
description: The user's ID
|
||||
|
||||
describe_expenses:
|
||||
type: cockroachdb-describe-table
|
||||
source: my_cockroachdb
|
||||
description: Describe the expenses table schema
|
||||
|
||||
list_expenses_indexes:
|
||||
type: cockroachdb-list-indexes
|
||||
source: my_cockroachdb
|
||||
description: List indexes on the expenses table
|
||||
```
|
||||
|
||||
## Configuration Parameters
|
||||
|
||||
### Required Parameters
|
||||
|
||||
| Parameter | Type | Description |
|
||||
|-----------|------|-------------|
|
||||
| `type` | string | Must be `cockroachdb` |
|
||||
| `host` | string | The hostname or IP address of the CockroachDB cluster |
|
||||
| `port` | string | The port number (typically "26257") |
|
||||
| `user` | string | The database user name |
|
||||
| `database` | string | The database name to connect to |
|
||||
|
||||
### Optional Parameters
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
|-----------|------|---------|-------------|
|
||||
| `password` | string | "" | The database password (can be empty for certificate-based auth) |
|
||||
| `maxRetries` | integer | 5 | Maximum number of connection retry attempts |
|
||||
| `retryBaseDelay` | string | "500ms" | Base delay between retry attempts (exponential backoff) |
|
||||
| `queryParams` | map | {} | Additional connection parameters (e.g., SSL configuration) |
|
||||
|
||||
### MCP Security Parameters
|
||||
|
||||
CockroachDB integration includes security features following the [Model Context Protocol (MCP)](https://modelcontextprotocol.io/) specification:
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
|-----------|------|---------|-------------|
|
||||
| `readOnlyMode` | boolean | true | Enables read-only mode by default (MCP requirement) |
|
||||
| `enableWriteMode` | boolean | false | Explicitly enable write operations (INSERT/UPDATE/DELETE/CREATE/DROP) |
|
||||
| `maxRowLimit` | integer | 1000 | Maximum rows returned per SELECT query (auto-adds LIMIT clause) |
|
||||
| `queryTimeoutSec` | integer | 30 | Query timeout in seconds to prevent long-running queries |
|
||||
| `enableTelemetry` | boolean | true | Enable structured logging of tool invocations |
|
||||
| `telemetryVerbose` | boolean | false | Enable detailed JSON telemetry output |
|
||||
| `clusterID` | string | "" | Optional cluster identifier for telemetry |
|
||||
|
||||
### Query Parameters
|
||||
|
||||
Common query parameters for CockroachDB connections:
|
||||
|
||||
| Parameter | Values | Description |
|
||||
|-----------|--------|-------------|
|
||||
| `sslmode` | `disable`, `require`, `verify-ca`, `verify-full` | SSL/TLS mode (CockroachDB Cloud requires `require` or higher) |
|
||||
| `sslrootcert` | file path | Path to root certificate for SSL verification |
|
||||
| `sslcert` | file path | Path to client certificate |
|
||||
| `sslkey` | file path | Path to client key |
|
||||
| `application_name` | string | Application name for connection tracking |
|
||||
|
||||
## Best Practices
|
||||
|
||||
### Security and MCP Compliance
|
||||
|
||||
**Read-Only by Default**: The integration follows MCP best practices by defaulting to read-only mode. This prevents accidental data modifications:
|
||||
|
||||
```yaml
|
||||
sources:
|
||||
my_cockroachdb:
|
||||
readOnlyMode: true # Default behavior
|
||||
enableWriteMode: false # Explicit write opt-in required
|
||||
```
|
||||
|
||||
To enable write operations:
|
||||
|
||||
```yaml
|
||||
sources:
|
||||
my_cockroachdb:
|
||||
readOnlyMode: false # Disable read-only protection
|
||||
enableWriteMode: true # Explicitly allow writes
|
||||
```
|
||||
|
||||
**Query Limits**: Automatic row limits prevent excessive data retrieval:
|
||||
- SELECT queries automatically get `LIMIT 1000` appended (configurable via `maxRowLimit`)
|
||||
- Queries are terminated after 30 seconds (configurable via `queryTimeoutSec`)
|
||||
|
||||
**Observability**: Structured telemetry provides visibility into tool usage:
|
||||
- Tool invocations are logged with status, latency, and row counts
|
||||
- SQL queries are redacted to protect sensitive values
|
||||
- Set `telemetryVerbose: true` for detailed JSON logs
|
||||
|
||||
### Use UUID Primary Keys
|
||||
|
||||
CockroachDB performs best with UUID primary keys rather than sequential integers to avoid transaction hotspots:
|
||||
|
||||
```sql
|
||||
CREATE TABLE expenses (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
description TEXT,
|
||||
amount DECIMAL(10,2)
|
||||
);
|
||||
```
|
||||
|
||||
### Automatic Transaction Retry
|
||||
|
||||
This source uses the official `cockroach-go/v2` library which provides automatic transaction retry for serialization conflicts. For write operations requiring explicit transaction control, tools can use the `ExecuteTxWithRetry` method.
|
||||
|
||||
### Multi-Region Deployments
|
||||
|
||||
CockroachDB supports multi-region deployments with automatic data distribution. Configure your cluster's regions and survival goals separately from the Toolbox configuration. The source will connect to any node in the cluster.
|
||||
|
||||
### Connection Pooling
|
||||
|
||||
The source maintains a connection pool to the CockroachDB cluster. The pool automatically handles:
|
||||
- Load balancing across cluster nodes
|
||||
- Connection retry with exponential backoff
|
||||
- Health checking of connections
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### SSL/TLS Errors
|
||||
|
||||
If you encounter "server requires encryption" errors:
|
||||
|
||||
1. For CockroachDB Cloud, ensure `sslmode` is set to `require` or higher:
|
||||
```yaml
|
||||
queryParams:
|
||||
sslmode: require
|
||||
```
|
||||
|
||||
2. For certificate verification, download your cluster's root certificate and configure:
|
||||
```yaml
|
||||
queryParams:
|
||||
sslmode: verify-full
|
||||
sslrootcert: /path/to/ca.crt
|
||||
```
|
||||
|
||||
### Connection Timeouts
|
||||
|
||||
If experiencing connection timeouts:
|
||||
|
||||
1. Check network connectivity to the CockroachDB cluster
|
||||
2. Verify firewall rules allow connections on port 26257
|
||||
3. For CockroachDB Cloud, ensure IP allowlisting is configured
|
||||
4. Increase `maxRetries` or `retryBaseDelay` if needed
|
||||
|
||||
### Transaction Retry Errors
|
||||
|
||||
CockroachDB may encounter serializable transaction conflicts. The integration automatically handles these retries using the cockroach-go library. If you see retry-related errors, check:
|
||||
|
||||
1. Database load and contention
|
||||
2. Query patterns that might cause conflicts
|
||||
3. Consider using `SELECT FOR UPDATE` for explicit locking
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- [CockroachDB Documentation](https://www.cockroachlabs.com/docs/)
|
||||
- [CockroachDB Best Practices](https://www.cockroachlabs.com/docs/stable/performance-best-practices-overview.html)
|
||||
- [Multi-Region Capabilities](https://www.cockroachlabs.com/docs/stable/multiregion-overview.html)
|
||||
- [Connection Parameters](https://www.cockroachlabs.com/docs/stable/connection-parameters.html)
|
||||
@@ -87,28 +87,41 @@ using a TNS (Transparent Network Substrate) alias.
|
||||
|
||||
## Examples
|
||||
|
||||
This example demonstrates the four connection methods you could choose from:
|
||||
### 1. Basic Connection (Host, Port, and Service Name)
|
||||
|
||||
```yaml
|
||||
kind: sources
|
||||
name: my-oracle-source
|
||||
type: oracle
|
||||
|
||||
# --- Choose one connection method ---
|
||||
# 1. Host, Port, and Service Name
|
||||
sources:
|
||||
my-oracle-source:
|
||||
kind: oracle
|
||||
host: 127.0.0.1
|
||||
port: 1521
|
||||
serviceName: XEPDB1
|
||||
|
||||
# 2. Direct Connection String
|
||||
connectionString: "127.0.0.1:1521/XEPDB1"
|
||||
|
||||
# 3. TNS Alias (requires tnsnames.ora)
|
||||
tnsAlias: "MY_DB_ALIAS"
|
||||
tnsAdmin: "/opt/oracle/network/admin" # Optional: overrides TNS_ADMIN env var
|
||||
|
||||
user: ${USER_NAME}
|
||||
password: ${PASSWORD}
|
||||
```
|
||||
|
||||
### 2. Direct Connection String
|
||||
|
||||
```yaml
|
||||
sources:
|
||||
my-oracle-source:
|
||||
kind: oracle
|
||||
connectionString: "127.0.0.1:1521/XEPDB1"
|
||||
user: ${USER_NAME}
|
||||
password: ${PASSWORD}
|
||||
```
|
||||
|
||||
### 3. TNS Alias (requires tnsnames.ora)
|
||||
|
||||
```yaml
|
||||
sources:
|
||||
my-oracle-source:
|
||||
kind: oracle
|
||||
tnsAlias: "MY_DB_ALIAS"
|
||||
tnsAdmin: "/opt/oracle/network/admin" # Optional: overrides TNS_ADMIN env var
|
||||
user: ${USER_NAME}
|
||||
password: ${PASSWORD}
|
||||
useOCI: true # tnsAlias requires useOCI to be true
|
||||
|
||||
# Optional: Set to true to use the OCI-based driver for advanced features (Requires Oracle Instant Client)
|
||||
```
|
||||
@@ -168,3 +181,4 @@ instead of hardcoding your secrets into the configuration file.
|
||||
| tnsAlias | string | false | A TNS alias from a `tnsnames.ora` file. Use as an alternative to `host`/`port` or `connectionString`. |
|
||||
| tnsAdmin | string | false | Path to the directory containing the `tnsnames.ora` file. This overrides the `TNS_ADMIN` environment variable if it is set. |
|
||||
| useOCI | bool | false | If true, uses the OCI-based driver (godror) which supports Oracle Wallet/Kerberos but requires the Oracle Instant Client libraries to be installed. Defaults to false (pure Go driver). |
|
||||
| walletLocation | string | false | Path to the directory containing the wallet files for the pure Go driver (`useOCI: false`). |
|
||||
|
||||
273
docs/en/resources/tools/cockroachdb/cockroachdb-execute-sql.md
Normal file
273
docs/en/resources/tools/cockroachdb/cockroachdb-execute-sql.md
Normal file
@@ -0,0 +1,273 @@
|
||||
---
|
||||
title: "cockroachdb-execute-sql"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
Execute ad-hoc SQL statements against a CockroachDB database.
|
||||
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
A `cockroachdb-execute-sql` tool executes ad-hoc SQL statements against a CockroachDB database. This tool is designed for interactive workflows where the SQL query is provided dynamically at runtime, making it ideal for developer assistance and exploratory data analysis.
|
||||
|
||||
The tool takes a single `sql` parameter containing the SQL statement to execute and returns the query results.
|
||||
|
||||
> **Note:** This tool is intended for developer assistant workflows with human-in-the-loop and shouldn't be used for production agents. For production use cases with predefined queries, use [cockroachdb-sql](./cockroachdb-sql.md) instead.
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
sources:
|
||||
my_cockroachdb:
|
||||
type: cockroachdb
|
||||
host: your-cluster.cockroachlabs.cloud
|
||||
port: "26257"
|
||||
user: myuser
|
||||
password: mypassword
|
||||
database: defaultdb
|
||||
queryParams:
|
||||
sslmode: require
|
||||
|
||||
tools:
|
||||
execute_sql:
|
||||
type: cockroachdb-execute-sql
|
||||
source: my_cockroachdb
|
||||
description: Execute any SQL statement against the CockroachDB database
|
||||
```
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### Simple SELECT Query
|
||||
|
||||
```json
|
||||
{
|
||||
"sql": "SELECT * FROM users LIMIT 10"
|
||||
}
|
||||
```
|
||||
|
||||
### Query with Aggregations
|
||||
|
||||
```json
|
||||
{
|
||||
"sql": "SELECT category, COUNT(*) as count, SUM(amount) as total FROM expenses GROUP BY category ORDER BY total DESC"
|
||||
}
|
||||
```
|
||||
|
||||
### Database Introspection
|
||||
|
||||
```json
|
||||
{
|
||||
"sql": "SHOW TABLES"
|
||||
}
|
||||
```
|
||||
|
||||
```json
|
||||
{
|
||||
"sql": "SHOW COLUMNS FROM expenses"
|
||||
}
|
||||
```
|
||||
|
||||
### Multi-Region Information
|
||||
|
||||
```json
|
||||
{
|
||||
"sql": "SHOW REGIONS FROM DATABASE defaultdb"
|
||||
}
|
||||
```
|
||||
|
||||
```json
|
||||
{
|
||||
"sql": "SHOW ZONE CONFIGURATIONS"
|
||||
}
|
||||
```
|
||||
|
||||
## CockroachDB-Specific Features
|
||||
|
||||
### Check Cluster Version
|
||||
|
||||
```json
|
||||
{
|
||||
"sql": "SELECT version()"
|
||||
}
|
||||
```
|
||||
|
||||
### View Node Status
|
||||
|
||||
```json
|
||||
{
|
||||
"sql": "SELECT node_id, address, locality, is_live FROM crdb_internal.gossip_nodes"
|
||||
}
|
||||
```
|
||||
|
||||
### Check Replication Status
|
||||
|
||||
```json
|
||||
{
|
||||
"sql": "SELECT range_id, start_key, end_key, replicas, lease_holder FROM crdb_internal.ranges LIMIT 10"
|
||||
}
|
||||
```
|
||||
|
||||
### View Table Regions
|
||||
|
||||
```json
|
||||
{
|
||||
"sql": "SHOW REGIONS FROM TABLE expenses"
|
||||
}
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
### Required Fields
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `type` | string | Must be `cockroachdb-execute-sql` |
|
||||
| `source` | string | Name of the CockroachDB source to use |
|
||||
| `description` | string | Human-readable description for the LLM |
|
||||
|
||||
### Optional Fields
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `authRequired` | array | List of authentication services required |
|
||||
|
||||
## Parameters
|
||||
|
||||
The tool accepts a single runtime parameter:
|
||||
|
||||
| Parameter | Type | Description |
|
||||
|-----------|------|-------------|
|
||||
| `sql` | string | The SQL statement to execute |
|
||||
|
||||
## Best Practices
|
||||
|
||||
### Use for Exploration, Not Production
|
||||
|
||||
This tool is ideal for:
|
||||
- Interactive database exploration
|
||||
- Ad-hoc analysis and reporting
|
||||
- Debugging and troubleshooting
|
||||
- Schema inspection
|
||||
|
||||
For production use cases, use [cockroachdb-sql](./cockroachdb-sql.md) with parameterized queries.
|
||||
|
||||
### Be Cautious with Data Modification
|
||||
|
||||
While this tool can execute any SQL statement, be careful with:
|
||||
- `INSERT`, `UPDATE`, `DELETE` statements
|
||||
- `DROP` or `ALTER` statements
|
||||
- Schema changes in production
|
||||
|
||||
### Use LIMIT for Large Results
|
||||
|
||||
Always use `LIMIT` clauses when exploring data:
|
||||
|
||||
```sql
|
||||
SELECT * FROM large_table LIMIT 100
|
||||
```
|
||||
|
||||
### Leverage CockroachDB's SQL Extensions
|
||||
|
||||
CockroachDB supports PostgreSQL syntax plus extensions:
|
||||
|
||||
```sql
|
||||
-- Show database survival goal
|
||||
SHOW SURVIVAL GOAL FROM DATABASE defaultdb;
|
||||
|
||||
-- View zone configurations
|
||||
SHOW ZONE CONFIGURATION FOR TABLE expenses;
|
||||
|
||||
-- Check table localities
|
||||
SHOW CREATE TABLE expenses;
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
The tool will return descriptive errors for:
|
||||
- **Syntax errors**: Invalid SQL syntax
|
||||
- **Permission errors**: Insufficient user privileges
|
||||
- **Connection errors**: Network or authentication issues
|
||||
- **Runtime errors**: Constraint violations, type mismatches, etc.
|
||||
|
||||
## Security Considerations
|
||||
|
||||
### SQL Injection Risk
|
||||
|
||||
Since this tool executes arbitrary SQL, it should only be used with:
|
||||
- Trusted users in interactive sessions
|
||||
- Human-in-the-loop workflows
|
||||
- Development and testing environments
|
||||
|
||||
Never expose this tool directly to end users without proper authorization controls.
|
||||
|
||||
### Use Authentication
|
||||
|
||||
Configure the `authRequired` field to restrict access:
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
execute_sql:
|
||||
type: cockroachdb-execute-sql
|
||||
source: my_cockroachdb
|
||||
description: Execute SQL statements
|
||||
authRequired:
|
||||
- my-auth-service
|
||||
```
|
||||
|
||||
### Read-Only Users
|
||||
|
||||
For safer exploration, create read-only database users:
|
||||
|
||||
```sql
|
||||
CREATE USER readonly_user;
|
||||
GRANT SELECT ON DATABASE defaultdb TO readonly_user;
|
||||
```
|
||||
|
||||
## Common Use Cases
|
||||
|
||||
### Database Administration
|
||||
|
||||
```sql
|
||||
-- View database size
|
||||
SELECT
|
||||
table_name,
|
||||
pg_size_pretty(pg_total_relation_size(table_name::regclass)) AS size
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = 'public'
|
||||
ORDER BY pg_total_relation_size(table_name::regclass) DESC;
|
||||
```
|
||||
|
||||
### Performance Analysis
|
||||
|
||||
```sql
|
||||
-- Find slow queries
|
||||
SELECT query, count, mean_latency
|
||||
FROM crdb_internal.statement_statistics
|
||||
WHERE mean_latency > INTERVAL '1 second'
|
||||
ORDER BY mean_latency DESC
|
||||
LIMIT 10;
|
||||
```
|
||||
|
||||
### Data Quality Checks
|
||||
|
||||
```sql
|
||||
-- Find NULL values
|
||||
SELECT COUNT(*) as null_count
|
||||
FROM expenses
|
||||
WHERE description IS NULL OR amount IS NULL;
|
||||
|
||||
-- Find duplicates
|
||||
SELECT user_id, email, COUNT(*) as count
|
||||
FROM users
|
||||
GROUP BY user_id, email
|
||||
HAVING COUNT(*) > 1;
|
||||
```
|
||||
|
||||
## See Also
|
||||
|
||||
- [cockroachdb-sql](./cockroachdb-sql.md) - For parameterized, production-ready queries
|
||||
- [cockroachdb-list-tables](./cockroachdb-list-tables.md) - List tables in the database
|
||||
- [cockroachdb-list-schemas](./cockroachdb-list-schemas.md) - List database schemas
|
||||
- [CockroachDB Source](../../sources/cockroachdb.md) - Source configuration reference
|
||||
- [CockroachDB SQL Reference](https://www.cockroachlabs.com/docs/stable/sql-statements.html) - Official SQL documentation
|
||||
305
docs/en/resources/tools/cockroachdb/cockroachdb-list-schemas.md
Normal file
305
docs/en/resources/tools/cockroachdb/cockroachdb-list-schemas.md
Normal file
@@ -0,0 +1,305 @@
|
||||
---
|
||||
title: "cockroachdb-list-schemas"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
List schemas in a CockroachDB database.
|
||||
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
The `cockroachdb-list-schemas` tool retrieves a list of schemas (namespaces) in a CockroachDB database. Schemas are used to organize database objects such as tables, views, and functions into logical groups.
|
||||
|
||||
This tool is useful for:
|
||||
- Understanding database organization
|
||||
- Discovering available schemas
|
||||
- Multi-tenant application analysis
|
||||
- Schema-level access control planning
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
sources:
|
||||
my_cockroachdb:
|
||||
type: cockroachdb
|
||||
host: your-cluster.cockroachlabs.cloud
|
||||
port: "26257"
|
||||
user: myuser
|
||||
password: mypassword
|
||||
database: defaultdb
|
||||
queryParams:
|
||||
sslmode: require
|
||||
|
||||
tools:
|
||||
list_schemas:
|
||||
type: cockroachdb-list-schemas
|
||||
source: my_cockroachdb
|
||||
description: List all schemas in the database
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
### Required Fields
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `type` | string | Must be `cockroachdb-list-schemas` |
|
||||
| `source` | string | Name of the CockroachDB source to use |
|
||||
| `description` | string | Human-readable description for the LLM |
|
||||
|
||||
### Optional Fields
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `authRequired` | array | List of authentication services required |
|
||||
|
||||
## Output Structure
|
||||
|
||||
The tool returns a list of schemas with the following information:
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"catalog_name": "defaultdb",
|
||||
"schema_name": "public",
|
||||
"is_user_defined": true
|
||||
},
|
||||
{
|
||||
"catalog_name": "defaultdb",
|
||||
"schema_name": "analytics",
|
||||
"is_user_defined": true
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
### Fields
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `catalog_name` | string | The database (catalog) name |
|
||||
| `schema_name` | string | The schema name |
|
||||
| `is_user_defined` | boolean | Whether this is a user-created schema (excludes system schemas) |
|
||||
|
||||
## Usage Example
|
||||
|
||||
```json
|
||||
{}
|
||||
```
|
||||
|
||||
No parameters are required. The tool automatically lists all user-defined schemas.
|
||||
|
||||
## Default Schemas
|
||||
|
||||
CockroachDB includes several standard schemas:
|
||||
|
||||
- **`public`**: The default schema for user objects
|
||||
- **`pg_catalog`**: PostgreSQL system catalog (excluded from results)
|
||||
- **`information_schema`**: SQL standard metadata views (excluded from results)
|
||||
- **`crdb_internal`**: CockroachDB internal metadata (excluded from results)
|
||||
- **`pg_extension`**: PostgreSQL extension objects (excluded from results)
|
||||
|
||||
The tool filters out system schemas and only returns user-defined schemas.
|
||||
|
||||
## Schema Management in CockroachDB
|
||||
|
||||
### Creating Schemas
|
||||
|
||||
```sql
|
||||
CREATE SCHEMA analytics;
|
||||
```
|
||||
|
||||
### Using Schemas
|
||||
|
||||
```sql
|
||||
-- Create table in specific schema
|
||||
CREATE TABLE analytics.revenue (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
amount DECIMAL(10,2),
|
||||
date DATE
|
||||
);
|
||||
|
||||
-- Query from specific schema
|
||||
SELECT * FROM analytics.revenue;
|
||||
```
|
||||
|
||||
### Schema Search Path
|
||||
|
||||
The search path determines which schemas are searched for unqualified object names:
|
||||
|
||||
```sql
|
||||
-- Show current search path
|
||||
SHOW search_path;
|
||||
|
||||
-- Set search path
|
||||
SET search_path = analytics, public;
|
||||
```
|
||||
|
||||
## Multi-Tenant Applications
|
||||
|
||||
Schemas are commonly used for multi-tenant applications:
|
||||
|
||||
```sql
|
||||
-- Create schema per tenant
|
||||
CREATE SCHEMA tenant_acme;
|
||||
CREATE SCHEMA tenant_globex;
|
||||
|
||||
-- Create same table structure in each schema
|
||||
CREATE TABLE tenant_acme.orders (...);
|
||||
CREATE TABLE tenant_globex.orders (...);
|
||||
```
|
||||
|
||||
The `cockroachdb-list-schemas` tool helps discover all tenant schemas:
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
list_tenants:
|
||||
type: cockroachdb-list-schemas
|
||||
source: my_cockroachdb
|
||||
description: |
|
||||
List all tenant schemas in the database.
|
||||
Each schema represents a separate tenant's data namespace.
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### Use Schemas for Organization
|
||||
|
||||
Group related tables into schemas:
|
||||
|
||||
```sql
|
||||
CREATE SCHEMA sales;
|
||||
CREATE SCHEMA inventory;
|
||||
CREATE SCHEMA hr;
|
||||
|
||||
CREATE TABLE sales.orders (...);
|
||||
CREATE TABLE inventory.products (...);
|
||||
CREATE TABLE hr.employees (...);
|
||||
```
|
||||
|
||||
### Schema Naming Conventions
|
||||
|
||||
Use clear, descriptive schema names:
|
||||
- Lowercase names
|
||||
- Use underscores for multi-word names
|
||||
- Avoid reserved keywords
|
||||
- Use prefixes for grouped schemas (e.g., `tenant_`, `app_`)
|
||||
|
||||
### Schema-Level Permissions
|
||||
|
||||
Schemas enable fine-grained access control:
|
||||
|
||||
```sql
|
||||
-- Grant access to specific schema
|
||||
GRANT USAGE ON SCHEMA analytics TO analyst_role;
|
||||
GRANT SELECT ON ALL TABLES IN SCHEMA analytics TO analyst_role;
|
||||
|
||||
-- Revoke access
|
||||
REVOKE ALL ON SCHEMA hr FROM public;
|
||||
```
|
||||
|
||||
## Integration with Other Tools
|
||||
|
||||
### Combined with List Tables
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
list_schemas:
|
||||
type: cockroachdb-list-schemas
|
||||
source: my_cockroachdb
|
||||
description: List all schemas first
|
||||
|
||||
list_tables:
|
||||
type: cockroachdb-list-tables
|
||||
source: my_cockroachdb
|
||||
description: |
|
||||
List tables in the database.
|
||||
Use list_schemas first to understand schema organization.
|
||||
```
|
||||
|
||||
### Schema Discovery Workflow
|
||||
|
||||
1. Call `cockroachdb-list-schemas` to discover schemas
|
||||
2. Call `cockroachdb-list-tables` to see tables in each schema
|
||||
3. Generate queries using fully qualified names: `schema.table`
|
||||
|
||||
## Common Use Cases
|
||||
|
||||
### Discover Database Structure
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
discover_schemas:
|
||||
type: cockroachdb-list-schemas
|
||||
source: my_cockroachdb
|
||||
description: |
|
||||
Discover how the database is organized into schemas.
|
||||
Use this to understand the logical grouping of tables.
|
||||
```
|
||||
|
||||
### Multi-Tenant Analysis
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
list_tenant_schemas:
|
||||
type: cockroachdb-list-schemas
|
||||
source: my_cockroachdb
|
||||
description: |
|
||||
List all tenant schemas (each tenant has their own schema).
|
||||
Schema names follow the pattern: tenant_<company_name>
|
||||
```
|
||||
|
||||
### Schema Migration Planning
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
audit_schemas:
|
||||
type: cockroachdb-list-schemas
|
||||
source: my_cockroachdb
|
||||
description: |
|
||||
Audit existing schemas before migration.
|
||||
Identifies all schemas that need to be migrated.
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
The tool handles common errors:
|
||||
- **Connection errors**: Returns connection failure details
|
||||
- **Permission errors**: Returns error if user lacks USAGE privilege
|
||||
- **Empty results**: Returns empty array if no user schemas exist
|
||||
|
||||
## Permissions Required
|
||||
|
||||
To list schemas, the user needs:
|
||||
- `CONNECT` privilege on the database
|
||||
- No specific schema privileges required for listing
|
||||
|
||||
To query objects within schemas, the user needs:
|
||||
- `USAGE` privilege on the schema
|
||||
- Appropriate object privileges (SELECT, INSERT, etc.)
|
||||
|
||||
## CockroachDB-Specific Features
|
||||
|
||||
### System Schemas
|
||||
|
||||
CockroachDB includes PostgreSQL-compatible system schemas plus CockroachDB-specific ones:
|
||||
|
||||
- `crdb_internal.*`: CockroachDB internal metadata and statistics
|
||||
- `pg_catalog.*`: PostgreSQL system catalog
|
||||
- `information_schema.*`: SQL standard information schema
|
||||
|
||||
These are automatically filtered from the results.
|
||||
|
||||
### User-Defined Flag
|
||||
|
||||
The `is_user_defined` field helps distinguish:
|
||||
- `true`: User-created schemas
|
||||
- `false`: System schemas (already filtered out)
|
||||
|
||||
## See Also
|
||||
|
||||
- [cockroachdb-sql](./cockroachdb-sql.md) - Execute parameterized queries
|
||||
- [cockroachdb-execute-sql](./cockroachdb-execute-sql.md) - Execute ad-hoc SQL
|
||||
- [cockroachdb-list-tables](./cockroachdb-list-tables.md) - List tables in the database
|
||||
- [CockroachDB Source](../../sources/cockroachdb.md) - Source configuration reference
|
||||
- [CockroachDB Schema Design](https://www.cockroachlabs.com/docs/stable/schema-design-overview.html) - Official documentation
|
||||
344
docs/en/resources/tools/cockroachdb/cockroachdb-list-tables.md
Normal file
344
docs/en/resources/tools/cockroachdb/cockroachdb-list-tables.md
Normal file
@@ -0,0 +1,344 @@
|
||||
---
|
||||
title: "cockroachdb-list-tables"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
List tables in a CockroachDB database with schema details.
|
||||
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
The `cockroachdb-list-tables` tool retrieves a list of tables from a CockroachDB database. It provides detailed information about table structure, including columns, constraints, indexes, and foreign key relationships.
|
||||
|
||||
This tool is useful for:
|
||||
- Database schema discovery
|
||||
- Understanding table relationships
|
||||
- Generating context for AI-powered database queries
|
||||
- Documentation and analysis
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
sources:
|
||||
my_cockroachdb:
|
||||
type: cockroachdb
|
||||
host: your-cluster.cockroachlabs.cloud
|
||||
port: "26257"
|
||||
user: myuser
|
||||
password: mypassword
|
||||
database: defaultdb
|
||||
queryParams:
|
||||
sslmode: require
|
||||
|
||||
tools:
|
||||
list_all_tables:
|
||||
type: cockroachdb-list-tables
|
||||
source: my_cockroachdb
|
||||
description: List all user tables in the database with their structure
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
### Required Fields
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `type` | string | Must be `cockroachdb-list-tables` |
|
||||
| `source` | string | Name of the CockroachDB source to use |
|
||||
| `description` | string | Human-readable description for the LLM |
|
||||
|
||||
### Optional Fields
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `authRequired` | array | List of authentication services required |
|
||||
|
||||
## Parameters
|
||||
|
||||
The tool accepts optional runtime parameters:
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
|-----------|------|---------|-------------|
|
||||
| `table_names` | array | all tables | List of specific table names to retrieve |
|
||||
| `output_format` | string | "detailed" | Output format: "simple" or "detailed" |
|
||||
|
||||
## Output Formats
|
||||
|
||||
### Simple Format
|
||||
|
||||
Returns basic table information:
|
||||
- Table name
|
||||
- Row count estimate
|
||||
- Size information
|
||||
|
||||
```json
|
||||
{
|
||||
"table_names": ["users"],
|
||||
"output_format": "simple"
|
||||
}
|
||||
```
|
||||
|
||||
### Detailed Format (Default)
|
||||
|
||||
Returns comprehensive table information:
|
||||
- Table name and schema
|
||||
- All columns with types and constraints
|
||||
- Primary keys
|
||||
- Foreign keys and relationships
|
||||
- Indexes
|
||||
- Check constraints
|
||||
- Table size and row counts
|
||||
|
||||
```json
|
||||
{
|
||||
"table_names": ["users", "orders"],
|
||||
"output_format": "detailed"
|
||||
}
|
||||
```
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### List All Tables
|
||||
|
||||
```json
|
||||
{}
|
||||
```
|
||||
|
||||
### List Specific Tables
|
||||
|
||||
```json
|
||||
{
|
||||
"table_names": ["users", "orders", "expenses"]
|
||||
}
|
||||
```
|
||||
|
||||
### Simple Output
|
||||
|
||||
```json
|
||||
{
|
||||
"output_format": "simple"
|
||||
}
|
||||
```
|
||||
|
||||
## Output Structure
|
||||
|
||||
### Simple Format Output
|
||||
|
||||
```json
|
||||
{
|
||||
"table_name": "users",
|
||||
"estimated_rows": 1000,
|
||||
"size": "128 KB"
|
||||
}
|
||||
```
|
||||
|
||||
### Detailed Format Output
|
||||
|
||||
```json
|
||||
{
|
||||
"table_name": "users",
|
||||
"schema": "public",
|
||||
"columns": [
|
||||
{
|
||||
"name": "id",
|
||||
"type": "UUID",
|
||||
"nullable": false,
|
||||
"default": "gen_random_uuid()"
|
||||
},
|
||||
{
|
||||
"name": "email",
|
||||
"type": "STRING",
|
||||
"nullable": false,
|
||||
"default": null
|
||||
},
|
||||
{
|
||||
"name": "created_at",
|
||||
"type": "TIMESTAMP",
|
||||
"nullable": false,
|
||||
"default": "now()"
|
||||
}
|
||||
],
|
||||
"primary_key": ["id"],
|
||||
"indexes": [
|
||||
{
|
||||
"name": "users_pkey",
|
||||
"columns": ["id"],
|
||||
"unique": true,
|
||||
"primary": true
|
||||
},
|
||||
{
|
||||
"name": "users_email_idx",
|
||||
"columns": ["email"],
|
||||
"unique": true,
|
||||
"primary": false
|
||||
}
|
||||
],
|
||||
"foreign_keys": [],
|
||||
"constraints": [
|
||||
{
|
||||
"name": "users_email_check",
|
||||
"type": "CHECK",
|
||||
"definition": "email ~* '^[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\\.[A-Z|a-z]{2,}$'"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
## CockroachDB-Specific Information
|
||||
|
||||
### UUID Primary Keys
|
||||
|
||||
The tool recognizes CockroachDB's recommended UUID primary key pattern:
|
||||
|
||||
```sql
|
||||
CREATE TABLE users (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
...
|
||||
);
|
||||
```
|
||||
|
||||
### Multi-Region Tables
|
||||
|
||||
For multi-region tables, the output includes locality information:
|
||||
|
||||
```json
|
||||
{
|
||||
"table_name": "users",
|
||||
"locality": "REGIONAL BY ROW",
|
||||
"regions": ["us-east-1", "us-west-2", "eu-west-1"]
|
||||
}
|
||||
```
|
||||
|
||||
### Interleaved Tables
|
||||
|
||||
The tool shows parent-child relationships for interleaved tables (legacy feature):
|
||||
|
||||
```json
|
||||
{
|
||||
"table_name": "order_items",
|
||||
"interleaved_in": "orders"
|
||||
}
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### Use for Schema Discovery
|
||||
|
||||
The tool is ideal for helping AI assistants understand your database structure:
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
discover_schema:
|
||||
type: cockroachdb-list-tables
|
||||
source: my_cockroachdb
|
||||
description: |
|
||||
Use this tool first to understand the database schema before generating queries.
|
||||
It shows all tables, their columns, data types, and relationships.
|
||||
```
|
||||
|
||||
### Filter Large Schemas
|
||||
|
||||
For databases with many tables, specify relevant tables:
|
||||
|
||||
```json
|
||||
{
|
||||
"table_names": ["users", "orders", "products"],
|
||||
"output_format": "detailed"
|
||||
}
|
||||
```
|
||||
|
||||
### Use Simple Format for Overviews
|
||||
|
||||
When you need just table names and sizes:
|
||||
|
||||
```json
|
||||
{
|
||||
"output_format": "simple"
|
||||
}
|
||||
```
|
||||
|
||||
## Excluded Tables
|
||||
|
||||
The tool automatically excludes system tables and schemas:
|
||||
- `pg_catalog.*` - PostgreSQL system catalog
|
||||
- `information_schema.*` - SQL standard information schema
|
||||
- `crdb_internal.*` - CockroachDB internal tables
|
||||
- `pg_extension.*` - PostgreSQL extension tables
|
||||
|
||||
Only user-created tables in the public schema (and other user schemas) are returned.
|
||||
|
||||
## Error Handling
|
||||
|
||||
The tool handles common errors:
|
||||
- **Table not found**: Returns empty result for non-existent tables
|
||||
- **Permission errors**: Returns error if user lacks SELECT privileges
|
||||
- **Connection errors**: Returns connection failure details
|
||||
|
||||
## Integration with AI Assistants
|
||||
|
||||
### Prompt Example
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
list_tables:
|
||||
type: cockroachdb-list-tables
|
||||
source: my_cockroachdb
|
||||
description: |
|
||||
Lists all tables in the database with detailed schema information.
|
||||
Use this tool to understand:
|
||||
- What tables exist
|
||||
- What columns each table has
|
||||
- Data types and constraints
|
||||
- Relationships between tables (foreign keys)
|
||||
- Available indexes
|
||||
|
||||
Always call this tool before generating SQL queries to ensure
|
||||
you use correct table and column names.
|
||||
```
|
||||
|
||||
## Common Use Cases
|
||||
|
||||
### Generate Context for Queries
|
||||
|
||||
```json
|
||||
{}
|
||||
```
|
||||
|
||||
This provides comprehensive schema information that helps AI assistants generate accurate SQL queries.
|
||||
|
||||
### Analyze Table Structure
|
||||
|
||||
```json
|
||||
{
|
||||
"table_names": ["users"],
|
||||
"output_format": "detailed"
|
||||
}
|
||||
```
|
||||
|
||||
Perfect for understanding a specific table's structure, constraints, and relationships.
|
||||
|
||||
### Quick Schema Overview
|
||||
|
||||
```json
|
||||
{
|
||||
"output_format": "simple"
|
||||
}
|
||||
```
|
||||
|
||||
Gets a quick list of tables with basic statistics.
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
- **Simple format** is faster for large databases
|
||||
- **Detailed format** queries system tables extensively
|
||||
- Specifying `table_names` reduces query time
|
||||
- Results are fetched in a single query for efficiency
|
||||
|
||||
## See Also
|
||||
|
||||
- [cockroachdb-sql](./cockroachdb-sql.md) - Execute parameterized queries
|
||||
- [cockroachdb-execute-sql](./cockroachdb-execute-sql.md) - Execute ad-hoc SQL
|
||||
- [cockroachdb-list-schemas](./cockroachdb-list-schemas.md) - List database schemas
|
||||
- [CockroachDB Source](../../sources/cockroachdb.md) - Source configuration reference
|
||||
- [CockroachDB Schema Design](https://www.cockroachlabs.com/docs/stable/schema-design-overview.html) - Best practices
|
||||
291
docs/en/resources/tools/cockroachdb/cockroachdb-sql.md
Normal file
291
docs/en/resources/tools/cockroachdb/cockroachdb-sql.md
Normal file
@@ -0,0 +1,291 @@
|
||||
---
|
||||
title: "cockroachdb-sql"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
Execute parameterized SQL queries in CockroachDB.
|
||||
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
The `cockroachdb-sql` tool allows you to execute parameterized SQL queries against a CockroachDB database. This tool supports prepared statements with parameter binding, template parameters for dynamic query construction, and automatic transaction retry for resilience against serialization conflicts.
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
sources:
|
||||
my_cockroachdb:
|
||||
type: cockroachdb
|
||||
host: your-cluster.cockroachlabs.cloud
|
||||
port: "26257"
|
||||
user: myuser
|
||||
password: mypassword
|
||||
database: defaultdb
|
||||
queryParams:
|
||||
sslmode: require
|
||||
|
||||
tools:
|
||||
get_user_orders:
|
||||
type: cockroachdb-sql
|
||||
source: my_cockroachdb
|
||||
description: Get all orders for a specific user
|
||||
statement: |
|
||||
SELECT o.id, o.order_date, o.total_amount, o.status
|
||||
FROM orders o
|
||||
WHERE o.user_id = $1
|
||||
ORDER BY o.order_date DESC
|
||||
parameters:
|
||||
- name: user_id
|
||||
type: string
|
||||
description: The UUID of the user
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
### Required Fields
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `type` | string | Must be `cockroachdb-sql` |
|
||||
| `source` | string | Name of the CockroachDB source to use |
|
||||
| `description` | string | Human-readable description of what the tool does |
|
||||
| `statement` | string | The SQL query to execute |
|
||||
|
||||
### Optional Fields
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `parameters` | array | List of parameter definitions for the query |
|
||||
| `templateParameters` | array | List of template parameters for dynamic query construction |
|
||||
| `authRequired` | array | List of authentication services required |
|
||||
|
||||
## Parameters
|
||||
|
||||
Parameters allow you to safely pass values into your SQL queries using prepared statements. CockroachDB uses PostgreSQL-style parameter placeholders: `$1`, `$2`, etc.
|
||||
|
||||
### Parameter Types
|
||||
|
||||
- `string`: Text values
|
||||
- `number`: Numeric values (integers or decimals)
|
||||
- `boolean`: True/false values
|
||||
- `array`: Array of values
|
||||
|
||||
### Example with Multiple Parameters
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
filter_expenses:
|
||||
type: cockroachdb-sql
|
||||
source: my_cockroachdb
|
||||
description: Filter expenses by category and date range
|
||||
statement: |
|
||||
SELECT id, description, amount, category, expense_date
|
||||
FROM expenses
|
||||
WHERE user_id = $1
|
||||
AND category = $2
|
||||
AND expense_date >= $3
|
||||
AND expense_date <= $4
|
||||
ORDER BY expense_date DESC
|
||||
parameters:
|
||||
- name: user_id
|
||||
type: string
|
||||
description: The user's UUID
|
||||
- name: category
|
||||
type: string
|
||||
description: Expense category (e.g., "Food", "Transport")
|
||||
- name: start_date
|
||||
type: string
|
||||
description: Start date in YYYY-MM-DD format
|
||||
- name: end_date
|
||||
type: string
|
||||
description: End date in YYYY-MM-DD format
|
||||
```
|
||||
|
||||
## Template Parameters
|
||||
|
||||
Template parameters enable dynamic query construction by replacing placeholders in the SQL statement before parameter binding. This is useful for dynamic table names, column names, or query structure.
|
||||
|
||||
### Example with Template Parameters
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
get_column_data:
|
||||
type: cockroachdb-sql
|
||||
source: my_cockroachdb
|
||||
description: Get data from a specific column
|
||||
statement: |
|
||||
SELECT {{column_name}}
|
||||
FROM {{table_name}}
|
||||
WHERE user_id = $1
|
||||
LIMIT 100
|
||||
templateParameters:
|
||||
- name: table_name
|
||||
type: string
|
||||
description: The table to query
|
||||
- name: column_name
|
||||
type: string
|
||||
description: The column to retrieve
|
||||
parameters:
|
||||
- name: user_id
|
||||
type: string
|
||||
description: The user's UUID
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### Use UUID Primary Keys
|
||||
|
||||
CockroachDB performs best with UUID primary keys to avoid transaction hotspots:
|
||||
|
||||
```sql
|
||||
CREATE TABLE orders (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id UUID NOT NULL,
|
||||
order_date TIMESTAMP DEFAULT now(),
|
||||
total_amount DECIMAL(10,2)
|
||||
);
|
||||
```
|
||||
|
||||
### Use Indexes for Performance
|
||||
|
||||
Create indexes on frequently queried columns:
|
||||
|
||||
```sql
|
||||
CREATE INDEX idx_orders_user_id ON orders(user_id);
|
||||
CREATE INDEX idx_orders_date ON orders(order_date DESC);
|
||||
```
|
||||
|
||||
### Use JOINs Efficiently
|
||||
|
||||
CockroachDB supports standard SQL JOINs. Keep joins efficient by:
|
||||
- Adding appropriate indexes
|
||||
- Using UUIDs for foreign keys
|
||||
- Limiting result sets with WHERE clauses
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
get_user_with_orders:
|
||||
type: cockroachdb-sql
|
||||
source: my_cockroachdb
|
||||
description: Get user details with their recent orders
|
||||
statement: |
|
||||
SELECT u.name, u.email, o.id as order_id, o.order_date, o.total_amount
|
||||
FROM users u
|
||||
LEFT JOIN orders o ON u.id = o.user_id
|
||||
WHERE u.id = $1
|
||||
ORDER BY o.order_date DESC
|
||||
LIMIT 10
|
||||
parameters:
|
||||
- name: user_id
|
||||
type: string
|
||||
description: The user's UUID
|
||||
```
|
||||
|
||||
### Handle NULL Values
|
||||
|
||||
Use COALESCE or NULL checks when dealing with nullable columns:
|
||||
|
||||
```sql
|
||||
SELECT id, description, COALESCE(notes, 'No notes') as notes
|
||||
FROM expenses
|
||||
WHERE user_id = $1
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
The tool automatically handles:
|
||||
- **Connection errors**: Retried with exponential backoff
|
||||
- **Serialization conflicts**: Automatically retried using cockroach-go library
|
||||
- **Invalid parameters**: Returns descriptive error messages
|
||||
- **SQL syntax errors**: Returns database error details
|
||||
|
||||
## Advanced Usage
|
||||
|
||||
### Aggregations
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
expense_summary:
|
||||
type: cockroachdb-sql
|
||||
source: my_cockroachdb
|
||||
description: Get expense summary by category for a user
|
||||
statement: |
|
||||
SELECT
|
||||
category,
|
||||
COUNT(*) as count,
|
||||
SUM(amount) as total_amount,
|
||||
AVG(amount) as avg_amount
|
||||
FROM expenses
|
||||
WHERE user_id = $1
|
||||
AND expense_date >= $2
|
||||
GROUP BY category
|
||||
ORDER BY total_amount DESC
|
||||
parameters:
|
||||
- name: user_id
|
||||
type: string
|
||||
description: The user's UUID
|
||||
- name: start_date
|
||||
type: string
|
||||
description: Start date in YYYY-MM-DD format
|
||||
```
|
||||
|
||||
### Window Functions
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
running_total:
|
||||
type: cockroachdb-sql
|
||||
source: my_cockroachdb
|
||||
description: Get running total of expenses
|
||||
statement: |
|
||||
SELECT
|
||||
expense_date,
|
||||
amount,
|
||||
SUM(amount) OVER (ORDER BY expense_date) as running_total
|
||||
FROM expenses
|
||||
WHERE user_id = $1
|
||||
ORDER BY expense_date
|
||||
parameters:
|
||||
- name: user_id
|
||||
type: string
|
||||
description: The user's UUID
|
||||
```
|
||||
|
||||
### Common Table Expressions (CTEs)
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
top_spenders:
|
||||
type: cockroachdb-sql
|
||||
source: my_cockroachdb
|
||||
description: Find top spending users
|
||||
statement: |
|
||||
WITH user_totals AS (
|
||||
SELECT
|
||||
user_id,
|
||||
SUM(amount) as total_spent
|
||||
FROM expenses
|
||||
WHERE expense_date >= $1
|
||||
GROUP BY user_id
|
||||
)
|
||||
SELECT
|
||||
u.name,
|
||||
u.email,
|
||||
ut.total_spent
|
||||
FROM user_totals ut
|
||||
JOIN users u ON ut.user_id = u.id
|
||||
ORDER BY ut.total_spent DESC
|
||||
LIMIT 10
|
||||
parameters:
|
||||
- name: start_date
|
||||
type: string
|
||||
description: Start date in YYYY-MM-DD format
|
||||
```
|
||||
|
||||
## See Also
|
||||
|
||||
- [cockroachdb-execute-sql](./cockroachdb-execute-sql.md) - For ad-hoc SQL execution
|
||||
- [cockroachdb-list-tables](./cockroachdb-list-tables.md) - List tables in the database
|
||||
- [cockroachdb-list-schemas](./cockroachdb-list-schemas.md) - List database schemas
|
||||
- [CockroachDB Source](../../sources/cockroachdb.md) - Source configuration reference
|
||||
@@ -771,7 +771,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"version = \"0.26.0\" # x-release-please-version\n",
|
||||
"version = \"0.27.0\" # x-release-please-version\n",
|
||||
"! curl -L -o /content/toolbox https://storage.googleapis.com/genai-toolbox/v{version}/linux/amd64/toolbox\n",
|
||||
"\n",
|
||||
"# Make the binary executable\n",
|
||||
|
||||
@@ -123,7 +123,7 @@ In this section, we will download and install the Toolbox binary.
|
||||
<!-- {x-release-please-start-version} -->
|
||||
```bash
|
||||
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
||||
export VERSION="0.26.0"
|
||||
export VERSION="0.27.0"
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v$VERSION/$OS/toolbox
|
||||
```
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
@@ -220,7 +220,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"version = \"0.26.0\" # x-release-please-version\n",
|
||||
"version = \"0.27.0\" # x-release-please-version\n",
|
||||
"! curl -O https://storage.googleapis.com/genai-toolbox/v{version}/linux/amd64/toolbox\n",
|
||||
"\n",
|
||||
"# Make the binary executable\n",
|
||||
|
||||
@@ -179,7 +179,7 @@ to use BigQuery, and then run the Toolbox server.
|
||||
<!-- {x-release-please-start-version} -->
|
||||
```bash
|
||||
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/$OS/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/$OS/toolbox
|
||||
```
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
|
||||
@@ -98,7 +98,7 @@ In this section, we will download Toolbox, configure our tools in a
|
||||
<!-- {x-release-please-start-version} -->
|
||||
```bash
|
||||
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/$OS/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/$OS/toolbox
|
||||
```
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ In this section, we will download Toolbox and run the Toolbox server.
|
||||
<!-- {x-release-please-start-version} -->
|
||||
```bash
|
||||
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/$OS/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/$OS/toolbox
|
||||
```
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ In this section, we will download Toolbox and run the Toolbox server.
|
||||
<!-- {x-release-please-start-version} -->
|
||||
```bash
|
||||
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/$OS/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/$OS/toolbox
|
||||
```
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ In this section, we will download Toolbox and run the Toolbox server.
|
||||
<!-- {x-release-please-start-version} -->
|
||||
```bash
|
||||
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.26.0/$OS/toolbox
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.27.0/$OS/toolbox
|
||||
```
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
|
||||
7
docs/en/samples/oracle/_index.md
Normal file
7
docs/en/samples/oracle/_index.md
Normal file
@@ -0,0 +1,7 @@
|
||||
---
|
||||
title: "OracleDB"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
How to get started with Toolbox using Oracle Database.
|
||||
---
|
||||
@@ -16,14 +16,7 @@ This guide demonstrates how to implement these patterns in your Toolbox applicat
|
||||
|
||||
{{< tabpane persist=header >}}
|
||||
{{% tab header="ADK" text=true %}}
|
||||
The following example demonstrates how to use `ToolboxToolset` with ADK's pre and post processing hooks to implement pre and post processing for tool calls.
|
||||
|
||||
```py
|
||||
{{< include "python/adk/agent.py" >}}
|
||||
```
|
||||
You can also add model-level (`before_model_callback`, `after_model_callback`) and agent-level (`before_agent_callback`, `after_agent_callback`) hooks to intercept messages at different stages of the execution loop.
|
||||
|
||||
For more information, see the [ADK Callbacks documentation](https://google.github.io/adk-docs/callbacks/types-of-callbacks/).
|
||||
Coming soon.
|
||||
{{% /tab %}}
|
||||
{{% tab header="Langchain" text=true %}}
|
||||
The following example demonstrates how to use `ToolboxClient` with LangChain's middleware to implement pre- and post- processing for tool calls.
|
||||
|
||||
@@ -1,137 +0,0 @@
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Optional
|
||||
from copy import deepcopy
|
||||
|
||||
from google.adk import Agent
|
||||
from google.adk.apps import App
|
||||
from google.adk.runners import Runner
|
||||
from google.adk.sessions.in_memory_session_service import InMemorySessionService
|
||||
from google.adk.tools.tool_context import ToolContext
|
||||
from google.genai import types
|
||||
from toolbox_adk import CredentialStrategy, ToolboxToolset, ToolboxTool
|
||||
|
||||
SYSTEM_PROMPT = """
|
||||
You're a helpful hotel assistant. You handle hotel searching, booking and
|
||||
cancellations. When the user searches for a hotel, mention it's name, id,
|
||||
location and price tier. Always mention hotel ids while performing any
|
||||
searches. This is very important for any operations. For any bookings or
|
||||
cancellations, please provide the appropriate confirmation. Be sure to
|
||||
update checkin or checkout dates if mentioned by the user.
|
||||
Don't ask for confirmations from the user.
|
||||
"""
|
||||
|
||||
|
||||
# Pre processing
|
||||
async def before_tool_callback(
|
||||
tool: ToolboxTool, args: Dict[str, Any], tool_context: ToolContext
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Callback fired before a tool is executed.
|
||||
Enforces business logic: Max stay duration is 14 days.
|
||||
"""
|
||||
tool_name = tool.name
|
||||
print(f"POLICY CHECK: Intercepting '{tool_name}'")
|
||||
|
||||
if tool_name == "update-hotel" and "checkin_date" in args and "checkout_date" in args:
|
||||
start = datetime.fromisoformat(args["checkin_date"])
|
||||
end = datetime.fromisoformat(args["checkout_date"])
|
||||
duration = (end - start).days
|
||||
|
||||
if duration > 14:
|
||||
print("BLOCKED: Stay too long")
|
||||
return {"result": "Error: Maximum stay duration is 14 days."}
|
||||
return None
|
||||
|
||||
|
||||
# Post processing
|
||||
async def after_tool_callback(
|
||||
tool: ToolboxTool,
|
||||
args: Dict[str, Any],
|
||||
tool_context: ToolContext,
|
||||
tool_response: Any,
|
||||
) -> Optional[Any]:
|
||||
"""
|
||||
Callback fired after a tool execution.
|
||||
Enriches response for successful bookings.
|
||||
"""
|
||||
if isinstance(tool_response, dict):
|
||||
result = tool_response.get("result", "")
|
||||
elif isinstance(tool_response, str):
|
||||
result = tool_response
|
||||
else:
|
||||
return None
|
||||
|
||||
tool_name = tool.name
|
||||
if isinstance(result, str) and "Error" not in result:
|
||||
if tool_name == "book-hotel":
|
||||
loyalty_bonus = 500
|
||||
enriched_result = f"Booking Confirmed!\n You earned {loyalty_bonus} Loyalty Points with this stay.\n\nSystem Details: {result}"
|
||||
|
||||
if isinstance(tool_response, dict):
|
||||
modified_response = deepcopy(tool_response)
|
||||
modified_response["result"] = enriched_result
|
||||
return modified_response
|
||||
else:
|
||||
return enriched_result
|
||||
return None
|
||||
|
||||
|
||||
async def run_chat_turn(
|
||||
runner: Runner, session_id: str, user_id: str, message_text: str
|
||||
):
|
||||
"""Executes a single chat turn and prints the interaction."""
|
||||
print(f"\nUSER: '{message_text}'")
|
||||
response_text = ""
|
||||
async for event in runner.run_async(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
new_message=types.Content(role="user", parts=[types.Part(text=message_text)]),
|
||||
):
|
||||
if event.content and event.content.parts:
|
||||
for part in event.content.parts:
|
||||
if part.text:
|
||||
response_text += part.text
|
||||
|
||||
print(f"AI: {response_text}")
|
||||
|
||||
|
||||
async def main():
|
||||
toolset = ToolboxToolset(
|
||||
server_url="http://127.0.0.1:5000",
|
||||
toolset_name="my-toolset",
|
||||
credentials=CredentialStrategy.toolbox_identity(),
|
||||
)
|
||||
tools = await toolset.get_tools()
|
||||
root_agent = Agent(
|
||||
name="root_agent",
|
||||
model="gemini-2.5-flash",
|
||||
instruction=SYSTEM_PROMPT,
|
||||
tools=tools,
|
||||
# add any pre and post processing callbacks
|
||||
before_tool_callback=before_tool_callback,
|
||||
after_tool_callback=after_tool_callback,
|
||||
)
|
||||
app = App(root_agent=root_agent, name="my_agent")
|
||||
runner = Runner(app=app, session_service=InMemorySessionService())
|
||||
session_id = "test-session"
|
||||
user_id = "test-user"
|
||||
await runner.session_service.create_session(
|
||||
app_name=app.name, user_id=user_id, session_id=session_id
|
||||
)
|
||||
|
||||
# First turn: Successful booking
|
||||
await run_chat_turn(runner, session_id, user_id, "Book hotel with id 3.")
|
||||
print("-" * 50)
|
||||
# Second turn: Policy violation (stay > 14 days)
|
||||
await run_chat_turn(
|
||||
runner,
|
||||
session_id,
|
||||
user_id,
|
||||
"Book a hotel with id 5 with checkin date 2025-01-18 and checkout date 2025-02-10",
|
||||
)
|
||||
await toolset.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1,3 +0,0 @@
|
||||
google-adk[toolbox]==1.23.0
|
||||
toolbox-adk==0.5.8
|
||||
google-genai==1.62.0
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "mcp-toolbox-for-databases",
|
||||
"version": "0.26.0",
|
||||
"version": "0.27.0",
|
||||
"description": "MCP Toolbox for Databases is an open-source MCP server for more than 30 different datasources.",
|
||||
"contextFileName": "MCP-TOOLBOX-EXTENSION.md"
|
||||
}
|
||||
1
go.mod
1
go.mod
@@ -21,6 +21,7 @@ require (
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/trace v1.30.0
|
||||
github.com/apache/cassandra-gocql-driver/v2 v2.0.0
|
||||
github.com/cenkalti/backoff/v5 v5.0.3
|
||||
github.com/cockroachdb/cockroach-go/v2 v2.4.2
|
||||
github.com/couchbase/gocb/v2 v2.11.1
|
||||
github.com/couchbase/tools-common/http v1.0.9
|
||||
github.com/elastic/elastic-transport-go/v8 v8.8.0
|
||||
|
||||
4
go.sum
4
go.sum
@@ -800,6 +800,8 @@ github.com/cncf/xds/go v0.0.0-20230105202645-06c439db220b/go.mod h1:eXthEFrGJvWH
|
||||
github.com/cncf/xds/go v0.0.0-20230607035331-e9ce68804cb4/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
|
||||
github.com/cncf/xds/go v0.0.0-20250501225837-2ac532fd4443 h1:aQ3y1lwWyqYPiWZThqv1aFbZMiM9vblcSArJRf2Irls=
|
||||
github.com/cncf/xds/go v0.0.0-20250501225837-2ac532fd4443/go.mod h1:W+zGtBO5Y1IgJhy4+A9GOqVhqLpfZi+vwmdNXUehLA8=
|
||||
github.com/cockroachdb/cockroach-go/v2 v2.4.2 h1:QB0ozDWQUUJ0GP8Zw63X/qHefPTCpLvtfCs6TLrPgyE=
|
||||
github.com/cockroachdb/cockroach-go/v2 v2.4.2/go.mod h1:9U179XbCx4qFWtNhc7BiWLPfuyMVQ7qdAhfrwLz1vH0=
|
||||
github.com/containerd/continuity v0.4.5 h1:ZRoN1sXq9u7V6QoHMcVWGhOwDFqZ4B9i5H6un1Wh0x4=
|
||||
github.com/containerd/continuity v0.4.5/go.mod h1:/lNJvtJKUQStBzpVQ1+rasXO1LAWtUQssk28EZvJ3nE=
|
||||
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
|
||||
@@ -960,6 +962,8 @@ github.com/godror/godror v0.49.6 h1:ts4ZGw8uLJ42e1D7aXmVuSrld0/lzUzmIUjuUuQOgGM=
|
||||
github.com/godror/godror v0.49.6/go.mod h1:kTMcxZzRw73RT5kn9v3JkBK4kHI6dqowHotqV72ebU8=
|
||||
github.com/godror/knownpb v0.3.0 h1:+caUdy8hTtl7X05aPl3tdL540TvCcaQA6woZQroLZMw=
|
||||
github.com/godror/knownpb v0.3.0/go.mod h1:PpTyfJwiOEAzQl7NtVCM8kdPCnp3uhxsZYIzZ5PV4zU=
|
||||
github.com/gofrs/flock v0.12.1 h1:MTLVXXHf8ekldpJk3AKicLij9MdwOWkZ+a/jHHZby9E=
|
||||
github.com/gofrs/flock v0.12.1/go.mod h1:9zxTsyu5xtJ9DK+1tFZyibEV7y3uwDxPPfbxeeHCoD0=
|
||||
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||
|
||||
@@ -47,6 +47,7 @@ var expectedToolSources = []string{
|
||||
"mysql",
|
||||
"neo4j",
|
||||
"oceanbase",
|
||||
"oracledb",
|
||||
"postgres",
|
||||
"serverless-spark",
|
||||
"singlestore",
|
||||
@@ -131,6 +132,8 @@ func TestGetPrebuiltTool(t *testing.T) {
|
||||
neo4jconfig := getOrFatal(t, "neo4j")
|
||||
healthcare_config := getOrFatal(t, "cloud-healthcare")
|
||||
snowflake_config := getOrFatal(t, "snowflake")
|
||||
oracle_config := getOrFatal(t,"oracledb")
|
||||
|
||||
if len(alloydb_omni_config) <= 0 {
|
||||
t.Fatalf("unexpected error: could not fetch alloydb omni prebuilt tools yaml")
|
||||
}
|
||||
@@ -230,6 +233,10 @@ func TestGetPrebuiltTool(t *testing.T) {
|
||||
if len(snowflake_config) <= 0 {
|
||||
t.Fatalf("unexpected error: could not fetch snowflake prebuilt tools yaml")
|
||||
}
|
||||
|
||||
if len(oracle_config) <= 0 {
|
||||
t.Fatalf("unexpected error: could not fetch oracle prebuilt tools yaml")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailGetPrebuiltTool(t *testing.T) {
|
||||
|
||||
121
internal/prebuiltconfigs/tools/oracledb.yaml
Normal file
121
internal/prebuiltconfigs/tools/oracledb.yaml
Normal file
@@ -0,0 +1,121 @@
|
||||
# Copyright 2026 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
sources:
|
||||
oracle-source:
|
||||
kind: "oracle"
|
||||
connectionString: ${ORACLE_CONNECTION_STRING}
|
||||
walletLocation: ${ORACLE_WALLET:}
|
||||
user: ${ORACLE_USER}
|
||||
password: ${ORACLE_PASSWORD}
|
||||
useOCI: ${ORACLE_USE_OCI:false}
|
||||
|
||||
tools:
|
||||
|
||||
list_tables:
|
||||
kind: oracle-sql
|
||||
source: oracle-source
|
||||
description: "Lists all user tables in the connected schema, including segment size, row count, and last analyzed date. Filters by a comma-separated list of names. If names are omitted, lists all tables in the current user's schema."
|
||||
statement: SELECT table_name from user_tables;
|
||||
|
||||
list_active_sessions:
|
||||
kind: oracle-sql
|
||||
source: oracle-source
|
||||
description: "List the top N (default 50) currently running database sessions (STATUS='ACTIVE'), showing SID, OS User, Program, and the current SQL statement text."
|
||||
statement: SELECT
|
||||
s.sid,
|
||||
s.serial#,
|
||||
s.username,
|
||||
s.osuser,
|
||||
s.program,
|
||||
s.status,
|
||||
s.wait_class,
|
||||
s.event,
|
||||
sql.sql_text
|
||||
FROM
|
||||
v$session s,
|
||||
v$sql sql
|
||||
WHERE
|
||||
s.status = 'ACTIVE'
|
||||
AND s.sql_id = sql.sql_id (+)
|
||||
AND s.audsid != userenv('sessionid') -- Exclude current session
|
||||
ORDER BY s.last_call_et DESC
|
||||
FETCH FIRST COALESCE(10) ROWS ONLY;
|
||||
|
||||
get_query_plan:
|
||||
kind: oracle-sql
|
||||
source: oracle-source
|
||||
description: "Generate a full execution plan for a single SQL statement. This can be used to analyze query performance without execution. Requires the SQL statement as input. following is an example EXPLAIN PLAN FOR {{&query}};"
|
||||
statement: SELECT PLAN_TABLE_OUTPUT FROM TABLE(DBMS_XPLAN.DISPLAY());
|
||||
|
||||
list_top_sql_by_resource:
|
||||
kind: oracle-sql
|
||||
source: oracle-source
|
||||
description: "List the top N SQL statements from the library cache based on a chosen resource metric (CPU, I/O, or Elapsed Time), following is an example of the sql"
|
||||
statement: SELECT
|
||||
sql_id,
|
||||
executions,
|
||||
buffer_gets,
|
||||
disk_reads,
|
||||
cpu_time / 1000000 AS cpu_seconds,
|
||||
elapsed_time / 1000000 AS elapsed_seconds
|
||||
FROM
|
||||
v$sql
|
||||
FETCH FIRST 5 ROWS ONLY;
|
||||
|
||||
list_tablespace_usage:
|
||||
kind: oracle-sql
|
||||
source: oracle-source
|
||||
description: "List tablespace names, total size, free space, and used percentage to monitor storage utilization."
|
||||
statement: SELECT
|
||||
t.tablespace_name,
|
||||
TO_CHAR(t.total_bytes / 1024 / 1024, '99,999.00') AS total_mb,
|
||||
TO_CHAR(SUM(d.bytes) / 1024 / 1024, '99,999.00') AS free_mb,
|
||||
TO_CHAR((t.total_bytes - SUM(d.bytes)) / t.total_bytes * 100, '99.00') AS used_pct
|
||||
FROM
|
||||
(SELECT tablespace_name, SUM(bytes) AS total_bytes FROM dba_data_files GROUP BY tablespace_name) t,
|
||||
dba_free_space d
|
||||
WHERE
|
||||
t.tablespace_name = d.tablespace_name (+)
|
||||
GROUP BY
|
||||
t.tablespace_name, t.total_bytes
|
||||
ORDER BY
|
||||
used_pct DESC;
|
||||
|
||||
list_invalid_objects:
|
||||
kind: oracle-sql
|
||||
source: oracle-source
|
||||
description: "Lists all database objects that are in an invalid state, requiring recompilation (e.g., procedures, functions, views)."
|
||||
statement: SELECT
|
||||
owner,
|
||||
object_type,
|
||||
object_name,
|
||||
status
|
||||
FROM
|
||||
dba_objects
|
||||
WHERE
|
||||
status = 'INVALID'
|
||||
AND owner NOT IN ('SYS', 'SYSTEM') -- Exclude system schemas for clarity
|
||||
ORDER BY
|
||||
owner, object_type, object_name;
|
||||
|
||||
toolsets:
|
||||
oracle_database_tools:
|
||||
- execute_sql
|
||||
- list_tables
|
||||
- list_active_sessions
|
||||
- get_query_plan
|
||||
- list_top_sql_by_resource
|
||||
- list_tablespace_usage
|
||||
- list_invalid_objects
|
||||
@@ -19,7 +19,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
@@ -216,7 +215,7 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
// Check if any of the specified auth services is verified
|
||||
isAuthorized := tool.Authorized(verifiedAuthServices)
|
||||
if !isAuthorized {
|
||||
err = fmt.Errorf("tool invocation not authorized. Please make sure your specify correct auth headers")
|
||||
err = fmt.Errorf("tool invocation not authorized. Please make sure you specify correct auth headers")
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
_ = render.Render(w, r, newErrResponse(err, http.StatusUnauthorized))
|
||||
return
|
||||
@@ -234,15 +233,28 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
params, err := parameters.ParseParams(tool.GetParameters(), data, claimsFromAuth)
|
||||
if err != nil {
|
||||
// If auth error, return 401
|
||||
if errors.Is(err, util.ErrUnauthorized) {
|
||||
s.logger.DebugContext(ctx, fmt.Sprintf("error parsing authenticated parameters from ID token: %s", err))
|
||||
var clientServerErr *util.ClientServerError
|
||||
|
||||
// Return 401 Authentication errors
|
||||
if errors.As(err, &clientServerErr) && clientServerErr.Code == http.StatusUnauthorized {
|
||||
s.logger.DebugContext(ctx, fmt.Sprintf("auth error: %v", err))
|
||||
_ = render.Render(w, r, newErrResponse(err, http.StatusUnauthorized))
|
||||
return
|
||||
}
|
||||
err = fmt.Errorf("provided parameters were invalid: %w", err)
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
_ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest))
|
||||
|
||||
var agentErr *util.AgentError
|
||||
if errors.As(err, &agentErr) {
|
||||
s.logger.DebugContext(ctx, fmt.Sprintf("agent validation error: %v", err))
|
||||
errMap := map[string]string{"error": err.Error()}
|
||||
errMarshal, _ := json.Marshal(errMap)
|
||||
|
||||
_ = render.Render(w, r, &resultResponse{Result: string(errMarshal)})
|
||||
return
|
||||
}
|
||||
|
||||
// Return 500 if it's a specific ClientServerError that isn't a 401, or any other unexpected error
|
||||
s.logger.ErrorContext(ctx, fmt.Sprintf("internal server error: %v", err))
|
||||
_ = render.Render(w, r, newErrResponse(err, http.StatusInternalServerError))
|
||||
return
|
||||
}
|
||||
s.logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
|
||||
@@ -259,35 +271,51 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Determine what error to return to the users.
|
||||
if err != nil {
|
||||
errStr := err.Error()
|
||||
var statusCode int
|
||||
var tbErr util.ToolboxError
|
||||
|
||||
// Upstream API auth error propagation
|
||||
switch {
|
||||
case strings.Contains(errStr, "Error 401"):
|
||||
statusCode = http.StatusUnauthorized
|
||||
case strings.Contains(errStr, "Error 403"):
|
||||
statusCode = http.StatusForbidden
|
||||
if errors.As(err, &tbErr) {
|
||||
switch tbErr.Category() {
|
||||
case util.CategoryAgent:
|
||||
// Agent Errors -> 200 OK
|
||||
s.logger.DebugContext(ctx, fmt.Sprintf("Tool invocation agent error: %v", err))
|
||||
res = map[string]string{
|
||||
"error": err.Error(),
|
||||
}
|
||||
|
||||
case util.CategoryServer:
|
||||
// Server Errors -> Check the specific code inside
|
||||
var clientServerErr *util.ClientServerError
|
||||
statusCode := http.StatusInternalServerError // Default to 500
|
||||
|
||||
if errors.As(err, &clientServerErr) {
|
||||
if clientServerErr.Code != 0 {
|
||||
statusCode = clientServerErr.Code
|
||||
}
|
||||
}
|
||||
|
||||
// Process auth error
|
||||
if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden {
|
||||
if clientAuth {
|
||||
// Propagate the original 401/403 error.
|
||||
s.logger.DebugContext(ctx, fmt.Sprintf("error invoking tool. Client credentials lack authorization to the source: %v", err))
|
||||
// Token error, pass through 401/403
|
||||
s.logger.DebugContext(ctx, fmt.Sprintf("Client credentials lack authorization: %v", err))
|
||||
_ = render.Render(w, r, newErrResponse(err, statusCode))
|
||||
return
|
||||
}
|
||||
// ADC lacking permission or credentials configuration error.
|
||||
internalErr := fmt.Errorf("unexpected auth error occured during Tool invocation: %w", err)
|
||||
s.logger.ErrorContext(ctx, internalErr.Error())
|
||||
_ = render.Render(w, r, newErrResponse(internalErr, http.StatusInternalServerError))
|
||||
// ADC/Config error, return 500
|
||||
statusCode = http.StatusInternalServerError
|
||||
}
|
||||
|
||||
s.logger.ErrorContext(ctx, fmt.Sprintf("Tool invocation server error: %v", err))
|
||||
_ = render.Render(w, r, newErrResponse(err, statusCode))
|
||||
return
|
||||
}
|
||||
err = fmt.Errorf("error while invoking tool: %w", err)
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
_ = render.Render(w, r, newErrResponse(err, http.StatusBadRequest))
|
||||
} else {
|
||||
// Unknown error -> 500
|
||||
s.logger.ErrorContext(ctx, fmt.Sprintf("Tool invocation unknown error: %v", err))
|
||||
_ = render.Render(w, r, newErrResponse(err, http.StatusInternalServerError))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
resMarshal, err := json.Marshal(res)
|
||||
if err != nil {
|
||||
|
||||
@@ -23,7 +23,6 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -37,9 +36,11 @@ import (
|
||||
v20241105 "github.com/googleapis/genai-toolbox/internal/server/mcp/v20241105"
|
||||
v20250326 "github.com/googleapis/genai-toolbox/internal/server/mcp/v20250326"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/codes"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
type sseSession struct {
|
||||
@@ -117,6 +118,55 @@ type stdioSession struct {
|
||||
writer io.Writer
|
||||
}
|
||||
|
||||
// traceContextCarrier implements propagation.TextMapCarrier for extracting trace context from _meta
|
||||
type traceContextCarrier map[string]string
|
||||
|
||||
func (c traceContextCarrier) Get(key string) string {
|
||||
return c[key]
|
||||
}
|
||||
|
||||
func (c traceContextCarrier) Set(key, value string) {
|
||||
c[key] = value
|
||||
}
|
||||
|
||||
func (c traceContextCarrier) Keys() []string {
|
||||
keys := make([]string, 0, len(c))
|
||||
for k := range c {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
// extractTraceContext extracts W3C Trace Context from params._meta
|
||||
func extractTraceContext(ctx context.Context, body []byte) context.Context {
|
||||
// Try to parse the request to extract _meta
|
||||
var req struct {
|
||||
Params struct {
|
||||
Meta struct {
|
||||
Traceparent string `json:"traceparent,omitempty"`
|
||||
Tracestate string `json:"tracestate,omitempty"`
|
||||
} `json:"_meta,omitempty"`
|
||||
} `json:"params,omitempty"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return ctx
|
||||
}
|
||||
|
||||
// If traceparent is present, extract the context
|
||||
if req.Params.Meta.Traceparent != "" {
|
||||
carrier := traceContextCarrier{
|
||||
"traceparent": req.Params.Meta.Traceparent,
|
||||
}
|
||||
if req.Params.Meta.Tracestate != "" {
|
||||
carrier["tracestate"] = req.Params.Meta.Tracestate
|
||||
}
|
||||
return otel.GetTextMapPropagator().Extract(ctx, carrier)
|
||||
}
|
||||
|
||||
return ctx
|
||||
}
|
||||
|
||||
func NewStdioSession(s *Server, stdin io.Reader, stdout io.Writer) *stdioSession {
|
||||
stdioSession := &stdioSession{
|
||||
server: s,
|
||||
@@ -143,18 +193,29 @@ func (s *stdioSession) readInputStream(ctx context.Context) error {
|
||||
}
|
||||
return err
|
||||
}
|
||||
v, res, err := processMcpMessage(ctx, []byte(line), s.server, s.protocol, "", "", nil)
|
||||
// This ensures the transport span becomes a child of the client span
|
||||
msgCtx := extractTraceContext(ctx, []byte(line))
|
||||
|
||||
// Create span for STDIO transport
|
||||
msgCtx, span := s.server.instrumentation.Tracer.Start(msgCtx, "toolbox/server/mcp/stdio",
|
||||
trace.WithSpanKind(trace.SpanKindServer),
|
||||
)
|
||||
defer span.End()
|
||||
|
||||
v, res, err := processMcpMessage(msgCtx, []byte(line), s.server, s.protocol, "", "", nil, "")
|
||||
if err != nil {
|
||||
// errors during the processing of message will generate a valid MCP Error response.
|
||||
// server can continue to run.
|
||||
s.server.logger.ErrorContext(ctx, err.Error())
|
||||
s.server.logger.ErrorContext(msgCtx, err.Error())
|
||||
span.SetStatus(codes.Error, err.Error())
|
||||
}
|
||||
|
||||
if v != "" {
|
||||
s.protocol = v
|
||||
}
|
||||
// no responses for notifications
|
||||
if res != nil {
|
||||
if err = s.write(ctx, res); err != nil {
|
||||
if err = s.write(msgCtx, res); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -240,7 +301,9 @@ func mcpRouter(s *Server) (chi.Router, error) {
|
||||
|
||||
// sseHandler handles sse initialization and message.
|
||||
func sseHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/mcp/sse")
|
||||
ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/mcp/sse",
|
||||
trace.WithSpanKind(trace.SpanKindServer),
|
||||
)
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
sessionId := uuid.New().String()
|
||||
@@ -336,9 +399,27 @@ func methodNotAllowed(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
ctx, span := s.instrumentation.Tracer.Start(r.Context(), "toolbox/server/mcp")
|
||||
ctx := r.Context()
|
||||
ctx = util.WithLogger(ctx, s.logger)
|
||||
|
||||
// Read body first so we can extract trace context
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
// Generate a new uuid if unable to decode
|
||||
id := uuid.New().String()
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
render.JSON(w, r, jsonrpc.NewError(id, jsonrpc.PARSE_ERROR, err.Error(), nil))
|
||||
return
|
||||
}
|
||||
|
||||
// This ensures the transport span becomes a child of the client span
|
||||
ctx = extractTraceContext(ctx, body)
|
||||
|
||||
// Create span for HTTP transport
|
||||
ctx, span := s.instrumentation.Tracer.Start(ctx, "toolbox/server/mcp/http",
|
||||
trace.WithSpanKind(trace.SpanKindServer),
|
||||
)
|
||||
r = r.WithContext(ctx)
|
||||
ctx = util.WithLogger(r.Context(), s.logger)
|
||||
|
||||
var sessionId, protocolVersion string
|
||||
var session *sseSession
|
||||
@@ -380,7 +461,6 @@ func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
s.logger.DebugContext(ctx, fmt.Sprintf("toolset name: %s", toolsetName))
|
||||
span.SetAttributes(attribute.String("toolset_name", toolsetName))
|
||||
|
||||
var err error
|
||||
defer func() {
|
||||
if err != nil {
|
||||
span.SetStatus(codes.Error, err.Error())
|
||||
@@ -399,17 +479,9 @@ func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
)
|
||||
}()
|
||||
|
||||
// Read and returns a body from io.Reader
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
// Generate a new uuid if unable to decode
|
||||
id := uuid.New().String()
|
||||
s.logger.DebugContext(ctx, err.Error())
|
||||
render.JSON(w, r, jsonrpc.NewError(id, jsonrpc.PARSE_ERROR, err.Error(), nil))
|
||||
return
|
||||
}
|
||||
networkProtocolVersion := fmt.Sprintf("%d.%d", r.ProtoMajor, r.ProtoMinor)
|
||||
|
||||
v, res, err := processMcpMessage(ctx, body, s, protocolVersion, toolsetName, promptsetName, r.Header)
|
||||
v, res, err := processMcpMessage(ctx, body, s, protocolVersion, toolsetName, promptsetName, r.Header, networkProtocolVersion)
|
||||
if err != nil {
|
||||
s.logger.DebugContext(ctx, fmt.Errorf("error processing message: %w", err).Error())
|
||||
}
|
||||
@@ -444,15 +516,12 @@ func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
code := rpcResponse.Error.Code
|
||||
switch code {
|
||||
case jsonrpc.INTERNAL_ERROR:
|
||||
// Map Internal RPC Error (-32603) to HTTP 500
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
case jsonrpc.INVALID_REQUEST:
|
||||
errStr := err.Error()
|
||||
if errors.Is(err, util.ErrUnauthorized) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
} else if strings.Contains(errStr, "Error 401") {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
} else if strings.Contains(errStr, "Error 403") {
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
var clientServerErr *util.ClientServerError
|
||||
if errors.As(err, &clientServerErr) {
|
||||
w.WriteHeader(clientServerErr.Code)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -462,7 +531,7 @@ func httpHandler(s *Server, w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// processMcpMessage process the messages received from clients
|
||||
func processMcpMessage(ctx context.Context, body []byte, s *Server, protocolVersion string, toolsetName string, promptsetName string, header http.Header) (string, any, error) {
|
||||
func processMcpMessage(ctx context.Context, body []byte, s *Server, protocolVersion string, toolsetName string, promptsetName string, header http.Header, networkProtocolVersion string) (string, any, error) {
|
||||
logger, err := util.LoggerFromContext(ctx)
|
||||
if err != nil {
|
||||
return "", jsonrpc.NewError("", jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||
@@ -498,31 +567,95 @@ func processMcpMessage(ctx context.Context, body []byte, s *Server, protocolVers
|
||||
return "", jsonrpc.NewError(baseMessage.Id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
|
||||
// Create method-specific span with semantic conventions
|
||||
// Note: Trace context is already extracted and set in ctx by the caller
|
||||
ctx, span := s.instrumentation.Tracer.Start(ctx, baseMessage.Method,
|
||||
trace.WithSpanKind(trace.SpanKindServer),
|
||||
)
|
||||
defer span.End()
|
||||
|
||||
// Determine network transport and protocol based on header presence
|
||||
networkTransport := "pipe" // default for stdio
|
||||
networkProtocolName := "stdio"
|
||||
if header != nil {
|
||||
networkTransport = "tcp" // HTTP/SSE transport
|
||||
networkProtocolName = "http"
|
||||
}
|
||||
|
||||
// Set required semantic attributes for span according to OTEL MCP semcov
|
||||
// ref: https://opentelemetry.io/docs/specs/semconv/gen-ai/mcp/#server
|
||||
span.SetAttributes(
|
||||
attribute.String("mcp.method.name", baseMessage.Method),
|
||||
attribute.String("network.transport", networkTransport),
|
||||
attribute.String("network.protocol.name", networkProtocolName),
|
||||
)
|
||||
|
||||
// Set network protocol version if available
|
||||
if networkProtocolVersion != "" {
|
||||
span.SetAttributes(attribute.String("network.protocol.version", networkProtocolVersion))
|
||||
}
|
||||
|
||||
// Set MCP protocol version if available
|
||||
if protocolVersion != "" {
|
||||
span.SetAttributes(attribute.String("mcp.protocol.version", protocolVersion))
|
||||
}
|
||||
|
||||
// Set request ID
|
||||
if baseMessage.Id != nil {
|
||||
span.SetAttributes(attribute.String("jsonrpc.request.id", fmt.Sprintf("%v", baseMessage.Id)))
|
||||
}
|
||||
|
||||
// Set toolset name
|
||||
span.SetAttributes(attribute.String("toolset.name", toolsetName))
|
||||
|
||||
// Check if message is a notification
|
||||
if baseMessage.Id == nil {
|
||||
err := mcp.NotificationHandler(ctx, body)
|
||||
if err != nil {
|
||||
span.SetStatus(codes.Error, err.Error())
|
||||
}
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
// Process the method
|
||||
switch baseMessage.Method {
|
||||
case mcputil.INITIALIZE:
|
||||
res, v, err := mcp.InitializeResponse(ctx, baseMessage.Id, body, s.version)
|
||||
result, version, err := mcp.InitializeResponse(ctx, baseMessage.Id, body, s.version)
|
||||
if err != nil {
|
||||
return "", res, err
|
||||
span.SetStatus(codes.Error, err.Error())
|
||||
if rpcErr, ok := result.(jsonrpc.JSONRPCError); ok {
|
||||
span.SetAttributes(attribute.String("error.type", rpcErr.Error.String()))
|
||||
}
|
||||
return v, res, err
|
||||
return "", result, err
|
||||
}
|
||||
span.SetAttributes(attribute.String("mcp.protocol.version", version))
|
||||
return version, result, err
|
||||
default:
|
||||
toolset, ok := s.ResourceMgr.GetToolset(toolsetName)
|
||||
if !ok {
|
||||
err = fmt.Errorf("toolset does not exist")
|
||||
return "", jsonrpc.NewError(baseMessage.Id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
err := fmt.Errorf("toolset does not exist")
|
||||
rpcErr := jsonrpc.NewError(baseMessage.Id, jsonrpc.INVALID_REQUEST, err.Error(), nil)
|
||||
span.SetStatus(codes.Error, err.Error())
|
||||
span.SetAttributes(attribute.String("error.type", rpcErr.Error.String()))
|
||||
return "", rpcErr, err
|
||||
}
|
||||
promptset, ok := s.ResourceMgr.GetPromptset(promptsetName)
|
||||
if !ok {
|
||||
err = fmt.Errorf("promptset does not exist")
|
||||
return "", jsonrpc.NewError(baseMessage.Id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
err := fmt.Errorf("promptset does not exist")
|
||||
rpcErr := jsonrpc.NewError(baseMessage.Id, jsonrpc.INVALID_REQUEST, err.Error(), nil)
|
||||
span.SetStatus(codes.Error, err.Error())
|
||||
span.SetAttributes(attribute.String("error.type", rpcErr.Error.String()))
|
||||
return "", rpcErr, err
|
||||
}
|
||||
res, err := mcp.ProcessMethod(ctx, protocolVersion, baseMessage.Id, baseMessage.Method, toolset, promptset, s.ResourceMgr, body, header)
|
||||
return "", res, err
|
||||
result, err := mcp.ProcessMethod(ctx, protocolVersion, baseMessage.Id, baseMessage.Method, toolset, promptset, s.ResourceMgr, body, header)
|
||||
if err != nil {
|
||||
span.SetStatus(codes.Error, err.Error())
|
||||
// Set error.type based on JSON-RPC error code
|
||||
if rpcErr, ok := result.(jsonrpc.JSONRPCError); ok {
|
||||
span.SetAttributes(attribute.Int("jsonrpc.error.code", rpcErr.Error.Code))
|
||||
span.SetAttributes(attribute.String("error.type", rpcErr.Error.String()))
|
||||
}
|
||||
}
|
||||
return "", result, err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -45,6 +45,9 @@ type Request struct {
|
||||
// notifications. The receiver is not obligated to provide these
|
||||
// notifications.
|
||||
ProgressToken ProgressToken `json:"progressToken,omitempty"`
|
||||
// W3C Trace Context fields for distributed tracing
|
||||
Traceparent string `json:"traceparent,omitempty"`
|
||||
Tracestate string `json:"tracestate,omitempty"`
|
||||
} `json:"_meta,omitempty"`
|
||||
} `json:"params,omitempty"`
|
||||
}
|
||||
@@ -97,6 +100,24 @@ type Error struct {
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// String returns the error type as a string based on the error code.
|
||||
func (e Error) String() string {
|
||||
switch e.Code {
|
||||
case METHOD_NOT_FOUND:
|
||||
return "method_not_found"
|
||||
case INVALID_PARAMS:
|
||||
return "invalid_params"
|
||||
case INTERNAL_ERROR:
|
||||
return "internal_error"
|
||||
case PARSE_ERROR:
|
||||
return "parse_error"
|
||||
case INVALID_REQUEST:
|
||||
return "invalid_request"
|
||||
default:
|
||||
return "jsonrpc_error"
|
||||
}
|
||||
}
|
||||
|
||||
// JSONRPCError represents a non-successful (error) response to a request.
|
||||
type JSONRPCError struct {
|
||||
Jsonrpc string `json:"jsonrpc"`
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||
@@ -29,6 +28,8 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
// ProcessMethod returns a response for the request.
|
||||
@@ -102,6 +103,14 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
toolName := req.Params.Name
|
||||
toolArgument := req.Params.Arguments
|
||||
logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName))
|
||||
|
||||
// Update span name and set gen_ai attributes
|
||||
span := trace.SpanFromContext(ctx)
|
||||
span.SetName(fmt.Sprintf("%s %s", TOOLS_CALL, toolName))
|
||||
span.SetAttributes(
|
||||
attribute.String("gen_ai.tool.name", toolName),
|
||||
attribute.String("gen_ai.operation.name", "execute_tool"),
|
||||
)
|
||||
tool, ok := resourceMgr.GetTool(toolName)
|
||||
if !ok {
|
||||
err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
|
||||
@@ -124,7 +133,12 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
}
|
||||
if clientAuth {
|
||||
if accessToken == "" {
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
|
||||
err := util.NewClientServerError(
|
||||
"missing access token in the 'Authorization' header",
|
||||
http.StatusUnauthorized,
|
||||
nil,
|
||||
)
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -172,7 +186,11 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
// Check if any of the specified auth services is verified
|
||||
isAuthorized := tool.Authorized(verifiedAuthServices)
|
||||
if !isAuthorized {
|
||||
err = fmt.Errorf("unauthorized Tool call: Please make sure your specify correct auth headers: %w", util.ErrUnauthorized)
|
||||
err = util.NewClientServerError(
|
||||
"unauthorized Tool call: Please make sure you specify correct auth headers",
|
||||
http.StatusUnauthorized,
|
||||
nil,
|
||||
)
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
logger.DebugContext(ctx, "tool invocation authorized")
|
||||
@@ -194,21 +212,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
// run tool invocation and generate response.
|
||||
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
||||
if err != nil {
|
||||
errStr := err.Error()
|
||||
// Missing authService tokens.
|
||||
if errors.Is(err, util.ErrUnauthorized) {
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
// Upstream auth error
|
||||
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
|
||||
if clientAuth {
|
||||
// Error with client credentials should pass down to the client
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
// Auth error with ADC should raise internal 500 error
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||
}
|
||||
var tbErr util.ToolboxError
|
||||
|
||||
if errors.As(err, &tbErr) {
|
||||
switch tbErr.Category() {
|
||||
case util.CategoryAgent:
|
||||
// MCP - Tool execution error
|
||||
// Return SUCCESS but with IsError: true
|
||||
text := TextContent{
|
||||
Type: "text",
|
||||
Text: err.Error(),
|
||||
@@ -218,6 +228,28 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
Id: id,
|
||||
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
||||
}, nil
|
||||
|
||||
case util.CategoryServer:
|
||||
// MCP Spec - Protocol error
|
||||
// Return JSON-RPC ERROR
|
||||
var clientServerErr *util.ClientServerError
|
||||
rpcCode := jsonrpc.INTERNAL_ERROR // Default to Internal Error (-32603)
|
||||
|
||||
if errors.As(err, &clientServerErr) {
|
||||
if clientServerErr.Code == http.StatusUnauthorized || clientServerErr.Code == http.StatusForbidden {
|
||||
if clientAuth {
|
||||
rpcCode = jsonrpc.INVALID_REQUEST
|
||||
} else {
|
||||
rpcCode = jsonrpc.INTERNAL_ERROR
|
||||
}
|
||||
}
|
||||
}
|
||||
return jsonrpc.NewError(id, rpcCode, err.Error(), nil), err
|
||||
}
|
||||
} else {
|
||||
// Unknown error -> 500
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||
}
|
||||
}
|
||||
|
||||
content := make([]TextContent, 0)
|
||||
@@ -288,6 +320,11 @@ func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *r
|
||||
|
||||
promptName := req.Params.Name
|
||||
logger.DebugContext(ctx, fmt.Sprintf("prompt name: %s", promptName))
|
||||
|
||||
// Update span name and set gen_ai attributes
|
||||
span := trace.SpanFromContext(ctx)
|
||||
span.SetName(fmt.Sprintf("%s %s", PROMPTS_GET, promptName))
|
||||
span.SetAttributes(attribute.String("gen_ai.prompt.name", promptName))
|
||||
prompt, ok := resourceMgr.GetPrompt(promptName)
|
||||
if !ok {
|
||||
err := fmt.Errorf("prompt with name %q does not exist", promptName)
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||
@@ -29,6 +28,8 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
// ProcessMethod returns a response for the request.
|
||||
@@ -102,6 +103,15 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
toolName := req.Params.Name
|
||||
toolArgument := req.Params.Arguments
|
||||
logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName))
|
||||
|
||||
// Update span name and set gen_ai attributes
|
||||
span := trace.SpanFromContext(ctx)
|
||||
span.SetName(fmt.Sprintf("%s %s", TOOLS_CALL, toolName))
|
||||
span.SetAttributes(
|
||||
attribute.String("gen_ai.tool.name", toolName),
|
||||
attribute.String("gen_ai.operation.name", "execute_tool"),
|
||||
)
|
||||
|
||||
tool, ok := resourceMgr.GetTool(toolName)
|
||||
if !ok {
|
||||
err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
|
||||
@@ -124,7 +134,12 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
}
|
||||
if clientAuth {
|
||||
if accessToken == "" {
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
|
||||
err := util.NewClientServerError(
|
||||
"missing access token in the 'Authorization' header",
|
||||
http.StatusUnauthorized,
|
||||
nil,
|
||||
)
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -172,7 +187,11 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
// Check if any of the specified auth services is verified
|
||||
isAuthorized := tool.Authorized(verifiedAuthServices)
|
||||
if !isAuthorized {
|
||||
err = fmt.Errorf("unauthorized Tool call: Please make sure your specify correct auth headers: %w", util.ErrUnauthorized)
|
||||
err = util.NewClientServerError(
|
||||
"unauthorized Tool call: Please make sure you specify correct auth headers",
|
||||
http.StatusUnauthorized,
|
||||
nil,
|
||||
)
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
logger.DebugContext(ctx, "tool invocation authorized")
|
||||
@@ -194,20 +213,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
// run tool invocation and generate response.
|
||||
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
||||
if err != nil {
|
||||
errStr := err.Error()
|
||||
// Missing authService tokens.
|
||||
if errors.Is(err, util.ErrUnauthorized) {
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
// Upstream auth error
|
||||
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
|
||||
if clientAuth {
|
||||
// Error with client credentials should pass down to the client
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
// Auth error with ADC should raise internal 500 error
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||
}
|
||||
var tbErr util.ToolboxError
|
||||
|
||||
if errors.As(err, &tbErr) {
|
||||
switch tbErr.Category() {
|
||||
case util.CategoryAgent:
|
||||
// MCP - Tool execution error
|
||||
// Return SUCCESS but with IsError: true
|
||||
text := TextContent{
|
||||
Type: "text",
|
||||
Text: err.Error(),
|
||||
@@ -217,8 +229,29 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
Id: id,
|
||||
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
||||
}, nil
|
||||
}
|
||||
|
||||
case util.CategoryServer:
|
||||
// MCP Spec - Protocol error
|
||||
// Return JSON-RPC ERROR
|
||||
var clientServerErr *util.ClientServerError
|
||||
rpcCode := jsonrpc.INTERNAL_ERROR // Default to Internal Error (-32603)
|
||||
|
||||
if errors.As(err, &clientServerErr) {
|
||||
if clientServerErr.Code == http.StatusUnauthorized || clientServerErr.Code == http.StatusForbidden {
|
||||
if clientAuth {
|
||||
rpcCode = jsonrpc.INVALID_REQUEST
|
||||
} else {
|
||||
rpcCode = jsonrpc.INTERNAL_ERROR
|
||||
}
|
||||
}
|
||||
}
|
||||
return jsonrpc.NewError(id, rpcCode, err.Error(), nil), err
|
||||
}
|
||||
} else {
|
||||
// Unknown error -> 500
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||
}
|
||||
}
|
||||
content := make([]TextContent, 0)
|
||||
|
||||
sliceRes, ok := results.([]any)
|
||||
@@ -287,6 +320,12 @@ func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *r
|
||||
|
||||
promptName := req.Params.Name
|
||||
logger.DebugContext(ctx, fmt.Sprintf("prompt name: %s", promptName))
|
||||
|
||||
// Update span name and set gen_ai attributes
|
||||
span := trace.SpanFromContext(ctx)
|
||||
span.SetName(fmt.Sprintf("%s %s", PROMPTS_GET, promptName))
|
||||
span.SetAttributes(attribute.String("gen_ai.prompt.name", promptName))
|
||||
|
||||
prompt, ok := resourceMgr.GetPrompt(promptName)
|
||||
if !ok {
|
||||
err := fmt.Errorf("prompt with name %q does not exist", promptName)
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||
@@ -29,6 +28,8 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
// ProcessMethod returns a response for the request.
|
||||
@@ -95,6 +96,15 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
toolName := req.Params.Name
|
||||
toolArgument := req.Params.Arguments
|
||||
logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName))
|
||||
|
||||
// Update span name and set gen_ai attributes
|
||||
span := trace.SpanFromContext(ctx)
|
||||
span.SetName(fmt.Sprintf("%s %s", TOOLS_CALL, toolName))
|
||||
span.SetAttributes(
|
||||
attribute.String("gen_ai.tool.name", toolName),
|
||||
attribute.String("gen_ai.operation.name", "execute_tool"),
|
||||
)
|
||||
|
||||
tool, ok := resourceMgr.GetTool(toolName)
|
||||
if !ok {
|
||||
err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
|
||||
@@ -117,7 +127,12 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
}
|
||||
if clientAuth {
|
||||
if accessToken == "" {
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
|
||||
err := util.NewClientServerError(
|
||||
"missing access token in the 'Authorization' header",
|
||||
http.StatusUnauthorized,
|
||||
nil,
|
||||
)
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -165,7 +180,11 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
// Check if any of the specified auth services is verified
|
||||
isAuthorized := tool.Authorized(verifiedAuthServices)
|
||||
if !isAuthorized {
|
||||
err = fmt.Errorf("unauthorized Tool call: Please make sure your specify correct auth headers: %w", util.ErrUnauthorized)
|
||||
err = util.NewClientServerError(
|
||||
"unauthorized Tool call: Please make sure you specify correct auth headers",
|
||||
http.StatusUnauthorized,
|
||||
nil,
|
||||
)
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
logger.DebugContext(ctx, "tool invocation authorized")
|
||||
@@ -187,20 +206,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
// run tool invocation and generate response.
|
||||
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
||||
if err != nil {
|
||||
errStr := err.Error()
|
||||
// Missing authService tokens.
|
||||
if errors.Is(err, util.ErrUnauthorized) {
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
// Upstream auth error
|
||||
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
|
||||
if clientAuth {
|
||||
// Error with client credentials should pass down to the client
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
// Auth error with ADC should raise internal 500 error
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||
}
|
||||
var tbErr util.ToolboxError
|
||||
|
||||
if errors.As(err, &tbErr) {
|
||||
switch tbErr.Category() {
|
||||
case util.CategoryAgent:
|
||||
// MCP - Tool execution error
|
||||
// Return SUCCESS but with IsError: true
|
||||
text := TextContent{
|
||||
Type: "text",
|
||||
Text: err.Error(),
|
||||
@@ -210,6 +222,28 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
Id: id,
|
||||
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
||||
}, nil
|
||||
|
||||
case util.CategoryServer:
|
||||
// MCP Spec - Protocol error
|
||||
// Return JSON-RPC ERROR
|
||||
var clientServerErr *util.ClientServerError
|
||||
rpcCode := jsonrpc.INTERNAL_ERROR // Default to Internal Error (-32603)
|
||||
|
||||
if errors.As(err, &clientServerErr) {
|
||||
if clientServerErr.Code == http.StatusUnauthorized || clientServerErr.Code == http.StatusForbidden {
|
||||
if clientAuth {
|
||||
rpcCode = jsonrpc.INVALID_REQUEST
|
||||
} else {
|
||||
rpcCode = jsonrpc.INTERNAL_ERROR
|
||||
}
|
||||
}
|
||||
}
|
||||
return jsonrpc.NewError(id, rpcCode, err.Error(), nil), err
|
||||
}
|
||||
} else {
|
||||
// Unknown error -> 500
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||
}
|
||||
}
|
||||
|
||||
content := make([]TextContent, 0)
|
||||
@@ -280,6 +314,12 @@ func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *r
|
||||
|
||||
promptName := req.Params.Name
|
||||
logger.DebugContext(ctx, fmt.Sprintf("prompt name: %s", promptName))
|
||||
|
||||
// Update span name and set gen_ai attributes
|
||||
span := trace.SpanFromContext(ctx)
|
||||
span.SetName(fmt.Sprintf("%s %s", PROMPTS_GET, promptName))
|
||||
span.SetAttributes(attribute.String("gen_ai.prompt.name", promptName))
|
||||
|
||||
prompt, ok := resourceMgr.GetPrompt(promptName)
|
||||
if !ok {
|
||||
err := fmt.Errorf("prompt with name %q does not exist", promptName)
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||
@@ -29,6 +28,8 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
// ProcessMethod returns a response for the request.
|
||||
@@ -95,6 +96,15 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
toolName := req.Params.Name
|
||||
toolArgument := req.Params.Arguments
|
||||
logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName))
|
||||
|
||||
// Update span name and set gen_ai attributes
|
||||
span := trace.SpanFromContext(ctx)
|
||||
span.SetName(fmt.Sprintf("%s %s", TOOLS_CALL, toolName))
|
||||
span.SetAttributes(
|
||||
attribute.String("gen_ai.tool.name", toolName),
|
||||
attribute.String("gen_ai.operation.name", "execute_tool"),
|
||||
)
|
||||
|
||||
tool, ok := resourceMgr.GetTool(toolName)
|
||||
if !ok {
|
||||
err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
|
||||
@@ -117,7 +127,12 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
}
|
||||
if clientAuth {
|
||||
if accessToken == "" {
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
|
||||
err := util.NewClientServerError(
|
||||
"missing access token in the 'Authorization' header",
|
||||
http.StatusUnauthorized,
|
||||
nil,
|
||||
)
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -165,7 +180,11 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
// Check if any of the specified auth services is verified
|
||||
isAuthorized := tool.Authorized(verifiedAuthServices)
|
||||
if !isAuthorized {
|
||||
err = fmt.Errorf("unauthorized Tool call: Please make sure your specify correct auth headers: %w", util.ErrUnauthorized)
|
||||
err = util.NewClientServerError(
|
||||
"unauthorized Tool call: Please make sure you specify correct auth headers",
|
||||
http.StatusUnauthorized,
|
||||
nil,
|
||||
)
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
logger.DebugContext(ctx, "tool invocation authorized")
|
||||
@@ -187,20 +206,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
// run tool invocation and generate response.
|
||||
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
||||
if err != nil {
|
||||
errStr := err.Error()
|
||||
// Missing authService tokens.
|
||||
if errors.Is(err, util.ErrUnauthorized) {
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
// Upstream auth error
|
||||
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
|
||||
if clientAuth {
|
||||
// Error with client credentials should pass down to the client
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
// Auth error with ADC should raise internal 500 error
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||
}
|
||||
var tbErr util.ToolboxError
|
||||
|
||||
if errors.As(err, &tbErr) {
|
||||
switch tbErr.Category() {
|
||||
case util.CategoryAgent:
|
||||
// MCP - Tool execution error
|
||||
// Return SUCCESS but with IsError: true
|
||||
text := TextContent{
|
||||
Type: "text",
|
||||
Text: err.Error(),
|
||||
@@ -210,6 +222,28 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
|
||||
Id: id,
|
||||
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
||||
}, nil
|
||||
|
||||
case util.CategoryServer:
|
||||
// MCP Spec - Protocol error
|
||||
// Return JSON-RPC ERROR
|
||||
var clientServerErr *util.ClientServerError
|
||||
rpcCode := jsonrpc.INTERNAL_ERROR // Default to Internal Error (-32603)
|
||||
|
||||
if errors.As(err, &clientServerErr) {
|
||||
if clientServerErr.Code == http.StatusUnauthorized || clientServerErr.Code == http.StatusForbidden {
|
||||
if clientAuth {
|
||||
rpcCode = jsonrpc.INVALID_REQUEST
|
||||
} else {
|
||||
rpcCode = jsonrpc.INTERNAL_ERROR
|
||||
}
|
||||
}
|
||||
}
|
||||
return jsonrpc.NewError(id, rpcCode, err.Error(), nil), err
|
||||
}
|
||||
} else {
|
||||
// Unknown error -> 500
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||
}
|
||||
}
|
||||
|
||||
content := make([]TextContent, 0)
|
||||
@@ -280,6 +314,12 @@ func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *r
|
||||
|
||||
promptName := req.Params.Name
|
||||
logger.DebugContext(ctx, fmt.Sprintf("prompt name: %s", promptName))
|
||||
|
||||
// Update span name and set gen_ai attributes
|
||||
span := trace.SpanFromContext(ctx)
|
||||
span.SetName(fmt.Sprintf("%s %s", PROMPTS_GET, promptName))
|
||||
span.SetAttributes(attribute.String("gen_ai.prompt.name", promptName))
|
||||
|
||||
prompt, ok := resourceMgr.GetPrompt(promptName)
|
||||
if !ok {
|
||||
err := fmt.Errorf("prompt with name %q does not exist", promptName)
|
||||
|
||||
@@ -231,7 +231,7 @@ func TestMcpEndpointWithoutInitialized(t *testing.T) {
|
||||
"id": "tools-call-tool4",
|
||||
"error": map[string]any{
|
||||
"code": -32600.0,
|
||||
"message": "unauthorized Tool call: Please make sure your specify correct auth headers: unauthorized",
|
||||
"message": "unauthorized Tool call: Please make sure you specify correct auth headers",
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -320,7 +320,7 @@ func TestMcpEndpointWithoutInitialized(t *testing.T) {
|
||||
Params: map[string]any{
|
||||
"name": "prompt2",
|
||||
"arguments": map[string]any{
|
||||
"arg1": 42, // prompt2 expects a string, we send a number
|
||||
"arg1": 42,
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -834,7 +834,7 @@ func TestMcpEndpoint(t *testing.T) {
|
||||
"id": "tools-call-tool4",
|
||||
"error": map[string]any{
|
||||
"code": -32600.0,
|
||||
"message": "unauthorized Tool call: Please make sure your specify correct auth headers: unauthorized",
|
||||
"message": "unauthorized Tool call: Please make sure you specify correct auth headers",
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
|
||||
@@ -34,7 +35,7 @@ type MockTool struct {
|
||||
requiresClientAuthrorization bool
|
||||
}
|
||||
|
||||
func (t MockTool) Invoke(context.Context, tools.SourceProvider, parameters.ParamValues, tools.AccessToken) (any, error) {
|
||||
func (t MockTool) Invoke(context.Context, tools.SourceProvider, parameters.ParamValues, tools.AccessToken) (any, util.ToolboxError) {
|
||||
mock := []any{t.Name}
|
||||
return mock, nil
|
||||
}
|
||||
|
||||
@@ -361,7 +361,11 @@ func (s *Source) GetOperations(ctx context.Context, project, location, operation
|
||||
}
|
||||
}
|
||||
|
||||
return string(opBytes), nil
|
||||
var result any
|
||||
if err := json.Unmarshal(opBytes, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal operation bytes: %w", err)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
logger.DebugContext(ctx, fmt.Sprintf("Operation not complete, retrying in %v\n", delay))
|
||||
}
|
||||
|
||||
421
internal/sources/cockroachdb/cockroachdb.go
Normal file
421
internal/sources/cockroachdb/cockroachdb.go
Normal file
@@ -0,0 +1,421 @@
|
||||
// Copyright 2026 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cockroachdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
crdbpgx "github.com/cockroachdb/cockroach-go/v2/crdb/crdbpgxv5"
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceType string = "cockroachdb"
|
||||
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceType, newConfig) {
|
||||
panic(fmt.Sprintf("source type %q already registered", SourceType))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources.SourceConfig, error) {
|
||||
// MCP compliance: Read-only by default, require explicit opt-in for writes
|
||||
actual := Config{
|
||||
Name: name,
|
||||
MaxRetries: 5,
|
||||
RetryBaseDelay: "500ms",
|
||||
ReadOnlyMode: true, // MCP requirement: read-only by default
|
||||
EnableWriteMode: false, // Must be explicitly enabled
|
||||
MaxRowLimit: 1000, // MCP requirement: limit query results
|
||||
QueryTimeoutSec: 30, // MCP requirement: prevent long-running queries
|
||||
EnableTelemetry: true, // MCP requirement: observability
|
||||
TelemetryVerbose: false,
|
||||
}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Security validation: If EnableWriteMode is true, ReadOnlyMode should be false
|
||||
if actual.EnableWriteMode {
|
||||
actual.ReadOnlyMode = false
|
||||
}
|
||||
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Type string `yaml:"type" validate:"required"`
|
||||
Host string `yaml:"host" validate:"required"`
|
||||
Port string `yaml:"port" validate:"required"`
|
||||
User string `yaml:"user" validate:"required"`
|
||||
Password string `yaml:"password"`
|
||||
Database string `yaml:"database" validate:"required"`
|
||||
QueryParams map[string]string `yaml:"queryParams"`
|
||||
MaxRetries int `yaml:"maxRetries"`
|
||||
RetryBaseDelay string `yaml:"retryBaseDelay"`
|
||||
|
||||
// MCP Security Features
|
||||
ReadOnlyMode bool `yaml:"readOnlyMode"` // Default: true (enforced in Initialize)
|
||||
EnableWriteMode bool `yaml:"enableWriteMode"` // Explicit opt-in for write operations
|
||||
MaxRowLimit int `yaml:"maxRowLimit"` // Default: 1000
|
||||
QueryTimeoutSec int `yaml:"queryTimeoutSec"` // Default: 30
|
||||
|
||||
// Observability
|
||||
EnableTelemetry bool `yaml:"enableTelemetry"` // Default: true
|
||||
TelemetryVerbose bool `yaml:"telemetryVerbose"` // Default: false
|
||||
ClusterID string `yaml:"clusterID"` // Optional cluster identifier for telemetry
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
retryBaseDelay, err := time.ParseDuration(r.RetryBaseDelay)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid retryBaseDelay: %w", err)
|
||||
}
|
||||
|
||||
pool, err := initCockroachDBConnectionPoolWithRetry(ctx, tracer, r.Name, r.Host, r.Port, r.User, r.Password, r.Database, r.QueryParams, r.MaxRetries, retryBaseDelay)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create pool: %w", err)
|
||||
}
|
||||
|
||||
s := &Source{
|
||||
Config: r,
|
||||
Pool: pool,
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
var _ sources.Source = &Source{}
|
||||
|
||||
type Source struct {
|
||||
Config
|
||||
Pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
func (s *Source) SourceType() string {
|
||||
return SourceType
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
return s.Config
|
||||
}
|
||||
|
||||
func (s *Source) CockroachDBPool() *pgxpool.Pool {
|
||||
return s.Pool
|
||||
}
|
||||
|
||||
func (s *Source) PostgresPool() *pgxpool.Pool {
|
||||
return s.Pool
|
||||
}
|
||||
|
||||
// ExecuteTxWithRetry executes a function within a transaction with automatic retry logic
|
||||
// using the official CockroachDB retry mechanism from cockroach-go/v2
|
||||
func (s *Source) ExecuteTxWithRetry(ctx context.Context, fn func(pgx.Tx) error) error {
|
||||
return crdbpgx.ExecuteTx(ctx, s.Pool, pgx.TxOptions{}, fn)
|
||||
}
|
||||
|
||||
// Query executes a query using the connection pool with MCP security enforcement.
|
||||
// For read-only queries, connection-level retry is sufficient.
|
||||
// For write operations requiring transaction retry, use ExecuteTxWithRetry directly.
|
||||
// Note: Callers should manage context timeouts as needed.
|
||||
func (s *Source) Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) {
|
||||
// MCP Security Check 1: Enforce write operation restrictions
|
||||
if err := s.CanExecuteWrite(sql); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// MCP Security Check 2: Apply query limits (row limit)
|
||||
modifiedSQL, err := s.ApplyQueryLimits(sql)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s.Pool.Query(ctx, modifiedSQL, args...)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// MCP Security & Observability Features
|
||||
// ============================================================================
|
||||
|
||||
// TelemetryEvent represents a structured telemetry event for MCP tool calls
|
||||
type TelemetryEvent struct {
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
ToolName string `json:"tool_name"`
|
||||
ClusterID string `json:"cluster_id"`
|
||||
Database string `json:"database"`
|
||||
User string `json:"user"`
|
||||
SQLRedacted string `json:"sql_redacted"` // Query with values redacted
|
||||
Status string `json:"status"` // "success" | "failure"
|
||||
ErrorCode string `json:"error_code,omitempty"`
|
||||
ErrorMsg string `json:"error_msg,omitempty"`
|
||||
LatencyMs int64 `json:"latency_ms"`
|
||||
RowsAffected int64 `json:"rows_affected,omitempty"`
|
||||
Metadata map[string]string `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// StructuredError represents an MCP-compliant error with error codes
|
||||
type StructuredError struct {
|
||||
Code string `json:"error_code"`
|
||||
Message string `json:"message"`
|
||||
Details map[string]any `json:"details,omitempty"`
|
||||
}
|
||||
|
||||
func (e *StructuredError) Error() string {
|
||||
return fmt.Sprintf("[%s] %s", e.Code, e.Message)
|
||||
}
|
||||
|
||||
// MCP Error Codes
|
||||
const (
|
||||
ErrCodeUnauthorized = "CRDB_UNAUTHORIZED"
|
||||
ErrCodeReadOnlyViolation = "CRDB_READONLY_VIOLATION"
|
||||
ErrCodeQueryTimeout = "CRDB_QUERY_TIMEOUT"
|
||||
ErrCodeRowLimitExceeded = "CRDB_ROW_LIMIT_EXCEEDED"
|
||||
ErrCodeInvalidSQL = "CRDB_INVALID_SQL"
|
||||
ErrCodeConnectionFailed = "CRDB_CONNECTION_FAILED"
|
||||
ErrCodeWriteModeRequired = "CRDB_WRITE_MODE_REQUIRED"
|
||||
ErrCodeQueryExecutionFailed = "CRDB_QUERY_EXECUTION_FAILED"
|
||||
)
|
||||
|
||||
// SQLStatementType represents the type of SQL statement
|
||||
type SQLStatementType int
|
||||
|
||||
const (
|
||||
SQLTypeUnknown SQLStatementType = iota
|
||||
SQLTypeSelect
|
||||
SQLTypeInsert
|
||||
SQLTypeUpdate
|
||||
SQLTypeDelete
|
||||
SQLTypeDDL // CREATE, ALTER, DROP
|
||||
SQLTypeTruncate
|
||||
SQLTypeExplain
|
||||
SQLTypeShow
|
||||
SQLTypeSet
|
||||
)
|
||||
|
||||
// ClassifySQL analyzes a SQL statement and returns its type
|
||||
func ClassifySQL(sql string) SQLStatementType {
|
||||
// Normalize: trim and convert to uppercase for analysis
|
||||
normalized := strings.TrimSpace(strings.ToUpper(sql))
|
||||
|
||||
if normalized == "" {
|
||||
return SQLTypeUnknown
|
||||
}
|
||||
|
||||
// Remove comments
|
||||
normalized = regexp.MustCompile(`--.*`).ReplaceAllString(normalized, "")
|
||||
normalized = regexp.MustCompile(`/\*.*?\*/`).ReplaceAllString(normalized, "")
|
||||
normalized = strings.TrimSpace(normalized)
|
||||
|
||||
// Check statement type
|
||||
switch {
|
||||
case strings.HasPrefix(normalized, "SELECT"):
|
||||
return SQLTypeSelect
|
||||
case strings.HasPrefix(normalized, "INSERT"):
|
||||
return SQLTypeInsert
|
||||
case strings.HasPrefix(normalized, "UPDATE"):
|
||||
return SQLTypeUpdate
|
||||
case strings.HasPrefix(normalized, "DELETE"):
|
||||
return SQLTypeDelete
|
||||
case strings.HasPrefix(normalized, "TRUNCATE"):
|
||||
return SQLTypeTruncate
|
||||
case strings.HasPrefix(normalized, "CREATE"):
|
||||
return SQLTypeDDL
|
||||
case strings.HasPrefix(normalized, "ALTER"):
|
||||
return SQLTypeDDL
|
||||
case strings.HasPrefix(normalized, "DROP"):
|
||||
return SQLTypeDDL
|
||||
case strings.HasPrefix(normalized, "EXPLAIN"):
|
||||
return SQLTypeExplain
|
||||
case strings.HasPrefix(normalized, "SHOW"):
|
||||
return SQLTypeShow
|
||||
case strings.HasPrefix(normalized, "SET"):
|
||||
return SQLTypeSet
|
||||
default:
|
||||
return SQLTypeUnknown
|
||||
}
|
||||
}
|
||||
|
||||
// IsWriteOperation returns true if the SQL statement modifies data
|
||||
func IsWriteOperation(sqlType SQLStatementType) bool {
|
||||
switch sqlType {
|
||||
case SQLTypeInsert, SQLTypeUpdate, SQLTypeDelete, SQLTypeTruncate, SQLTypeDDL:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// IsReadOnlyMode returns whether the source is in read-only mode
|
||||
func (s *Source) IsReadOnlyMode() bool {
|
||||
return s.ReadOnlyMode && !s.EnableWriteMode
|
||||
}
|
||||
|
||||
// CanExecuteWrite checks if a write operation is allowed
|
||||
func (s *Source) CanExecuteWrite(sql string) error {
|
||||
sqlType := ClassifySQL(sql)
|
||||
|
||||
if IsWriteOperation(sqlType) && s.IsReadOnlyMode() {
|
||||
return &StructuredError{
|
||||
Code: ErrCodeReadOnlyViolation,
|
||||
Message: "Write operations are not allowed in read-only mode. Set enableWriteMode: true to allow writes.",
|
||||
Details: map[string]any{
|
||||
"sql_type": sqlType,
|
||||
"read_only_mode": s.ReadOnlyMode,
|
||||
"enable_write_mode": s.EnableWriteMode,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ApplyQueryLimits applies row limits to a SQL query for MCP security compliance.
|
||||
// Context timeout management is the responsibility of the caller (following Go best practices).
|
||||
// Returns potentially modified SQL with LIMIT clause for SELECT queries.
|
||||
func (s *Source) ApplyQueryLimits(sql string) (string, error) {
|
||||
sqlType := ClassifySQL(sql)
|
||||
|
||||
// Apply row limit only to SELECT queries
|
||||
if sqlType == SQLTypeSelect && s.MaxRowLimit > 0 {
|
||||
// Check if query already has LIMIT clause
|
||||
normalized := strings.ToUpper(sql)
|
||||
if !strings.Contains(normalized, " LIMIT ") {
|
||||
// Add LIMIT clause - trim trailing whitespace and semicolon
|
||||
sql = strings.TrimSpace(sql)
|
||||
sql = strings.TrimSuffix(sql, ";")
|
||||
sql = fmt.Sprintf("%s LIMIT %d", sql, s.MaxRowLimit)
|
||||
}
|
||||
}
|
||||
|
||||
return sql, nil
|
||||
}
|
||||
|
||||
// RedactSQL redacts sensitive values from SQL for telemetry
|
||||
func RedactSQL(sql string) string {
|
||||
// Redact string literals
|
||||
sql = regexp.MustCompile(`'[^']*'`).ReplaceAllString(sql, "'***'")
|
||||
|
||||
// Redact numbers that might be sensitive
|
||||
sql = regexp.MustCompile(`\b\d{10,}\b`).ReplaceAllString(sql, "***")
|
||||
|
||||
return sql
|
||||
}
|
||||
|
||||
// EmitTelemetry logs a telemetry event in structured JSON format
|
||||
func (s *Source) EmitTelemetry(ctx context.Context, event TelemetryEvent) {
|
||||
if !s.EnableTelemetry {
|
||||
return
|
||||
}
|
||||
|
||||
// Set cluster ID if not already set
|
||||
if event.ClusterID == "" {
|
||||
event.ClusterID = s.ClusterID
|
||||
if event.ClusterID == "" {
|
||||
event.ClusterID = s.Database // Fallback to database name
|
||||
}
|
||||
}
|
||||
|
||||
// Set database and user
|
||||
if event.Database == "" {
|
||||
event.Database = s.Database
|
||||
}
|
||||
if event.User == "" {
|
||||
event.User = s.User
|
||||
}
|
||||
|
||||
// Log as structured JSON
|
||||
if s.TelemetryVerbose {
|
||||
jsonBytes, _ := json.Marshal(event)
|
||||
slog.Info("CockroachDB MCP Telemetry", "event", string(jsonBytes))
|
||||
} else {
|
||||
// Minimal logging
|
||||
slog.Info("CockroachDB MCP",
|
||||
"tool", event.ToolName,
|
||||
"status", event.Status,
|
||||
"latency_ms", event.LatencyMs,
|
||||
"error_code", event.ErrorCode,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func initCockroachDBConnectionPoolWithRetry(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname string, queryParams map[string]string, maxRetries int, baseDelay time.Duration) (*pgxpool.Pool, error) {
|
||||
//nolint:all
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceType, name)
|
||||
defer span.End()
|
||||
|
||||
userAgent, err := util.UserAgentFromContext(ctx)
|
||||
if err != nil {
|
||||
userAgent = "genai-toolbox"
|
||||
}
|
||||
if queryParams == nil {
|
||||
queryParams = make(map[string]string)
|
||||
}
|
||||
if _, ok := queryParams["application_name"]; !ok {
|
||||
queryParams["application_name"] = userAgent
|
||||
}
|
||||
|
||||
connURL := &url.URL{
|
||||
Scheme: "postgres",
|
||||
User: url.UserPassword(user, pass),
|
||||
Host: fmt.Sprintf("%s:%s", host, port),
|
||||
Path: dbname,
|
||||
RawQuery: ConvertParamMapToRawQuery(queryParams),
|
||||
}
|
||||
|
||||
var pool *pgxpool.Pool
|
||||
for attempt := 0; attempt <= maxRetries; attempt++ {
|
||||
pool, err = pgxpool.New(ctx, connURL.String())
|
||||
if err == nil {
|
||||
err = pool.Ping(ctx)
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
if attempt < maxRetries {
|
||||
backoff := baseDelay * time.Duration(math.Pow(2, float64(attempt)))
|
||||
time.Sleep(backoff)
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("failed to connect to CockroachDB after %d retries: %w", maxRetries, err)
|
||||
}
|
||||
|
||||
func ConvertParamMapToRawQuery(queryParams map[string]string) string {
|
||||
values := url.Values{}
|
||||
for k, v := range queryParams {
|
||||
values.Add(k, v)
|
||||
}
|
||||
return values.Encode()
|
||||
}
|
||||
224
internal/sources/cockroachdb/cockroachdb_test.go
Normal file
224
internal/sources/cockroachdb/cockroachdb_test.go
Normal file
@@ -0,0 +1,224 @@
|
||||
// Copyright 2026 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cockroachdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
)
|
||||
|
||||
func TestCockroachDBSourceConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
yaml string
|
||||
}{
|
||||
{
|
||||
name: "valid config",
|
||||
yaml: `
|
||||
name: test-cockroachdb
|
||||
type: cockroachdb
|
||||
host: localhost
|
||||
port: "26257"
|
||||
user: root
|
||||
password: ""
|
||||
database: defaultdb
|
||||
maxRetries: 5
|
||||
retryBaseDelay: 500ms
|
||||
queryParams:
|
||||
sslmode: disable
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "with optional queryParams",
|
||||
yaml: `
|
||||
name: test-cockroachdb
|
||||
type: cockroachdb
|
||||
host: localhost
|
||||
port: "26257"
|
||||
user: root
|
||||
password: testpass
|
||||
database: testdb
|
||||
queryParams:
|
||||
sslmode: require
|
||||
sslcert: /path/to/cert
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "with custom retry settings",
|
||||
yaml: `
|
||||
name: test-cockroachdb
|
||||
type: cockroachdb
|
||||
host: localhost
|
||||
port: "26257"
|
||||
user: root
|
||||
password: ""
|
||||
database: defaultdb
|
||||
maxRetries: 10
|
||||
retryBaseDelay: 1s
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "without password (insecure mode)",
|
||||
yaml: `
|
||||
name: test-cockroachdb
|
||||
type: cockroachdb
|
||||
host: localhost
|
||||
port: "26257"
|
||||
user: root
|
||||
database: defaultdb
|
||||
`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
decoder := yaml.NewDecoder(strings.NewReader(tt.yaml))
|
||||
cfg, err := newConfig(context.Background(), "test", decoder)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if cfg == nil {
|
||||
t.Fatal("expected config but got nil")
|
||||
}
|
||||
|
||||
// Verify it's the right type
|
||||
cockroachCfg, ok := cfg.(Config)
|
||||
if !ok {
|
||||
t.Fatalf("expected Config type, got %T", cfg)
|
||||
}
|
||||
|
||||
// Verify SourceConfigType
|
||||
if cockroachCfg.SourceConfigType() != SourceType {
|
||||
t.Errorf("expected SourceConfigType %q, got %q", SourceType, cockroachCfg.SourceConfigType())
|
||||
}
|
||||
|
||||
t.Logf("✅ Config parsed successfully: %+v", cockroachCfg)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCockroachDBSourceType(t *testing.T) {
|
||||
yamlContent := `
|
||||
name: test-cockroachdb
|
||||
type: cockroachdb
|
||||
host: localhost
|
||||
port: "26257"
|
||||
user: root
|
||||
password: ""
|
||||
database: defaultdb
|
||||
`
|
||||
decoder := yaml.NewDecoder(strings.NewReader(yamlContent))
|
||||
cfg, err := newConfig(context.Background(), "test", decoder)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create config: %v", err)
|
||||
}
|
||||
|
||||
if cfg.SourceConfigType() != "cockroachdb" {
|
||||
t.Errorf("expected SourceConfigType 'cockroachdb', got %q", cfg.SourceConfigType())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCockroachDBDefaultValues(t *testing.T) {
|
||||
yamlContent := `
|
||||
name: test-cockroachdb
|
||||
type: cockroachdb
|
||||
host: localhost
|
||||
port: "26257"
|
||||
user: root
|
||||
password: ""
|
||||
database: defaultdb
|
||||
`
|
||||
decoder := yaml.NewDecoder(strings.NewReader(yamlContent))
|
||||
cfg, err := newConfig(context.Background(), "test", decoder)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create config: %v", err)
|
||||
}
|
||||
|
||||
cockroachCfg, ok := cfg.(Config)
|
||||
if !ok {
|
||||
t.Fatalf("expected Config type")
|
||||
}
|
||||
|
||||
// Check default values
|
||||
if cockroachCfg.MaxRetries != 5 {
|
||||
t.Errorf("expected default MaxRetries 5, got %d", cockroachCfg.MaxRetries)
|
||||
}
|
||||
|
||||
if cockroachCfg.RetryBaseDelay != "500ms" {
|
||||
t.Errorf("expected default RetryBaseDelay '500ms', got %q", cockroachCfg.RetryBaseDelay)
|
||||
}
|
||||
|
||||
t.Logf("✅ Default values set correctly")
|
||||
}
|
||||
|
||||
func TestConvertParamMapToRawQuery(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
params map[string]string
|
||||
want []string // Expected substrings in any order
|
||||
}{
|
||||
{
|
||||
name: "empty params",
|
||||
params: map[string]string{},
|
||||
want: []string{},
|
||||
},
|
||||
{
|
||||
name: "single param",
|
||||
params: map[string]string{
|
||||
"sslmode": "disable",
|
||||
},
|
||||
want: []string{"sslmode=disable"},
|
||||
},
|
||||
{
|
||||
name: "multiple params",
|
||||
params: map[string]string{
|
||||
"sslmode": "require",
|
||||
"application_name": "test-app",
|
||||
},
|
||||
want: []string{"sslmode=require", "application_name=test-app"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ConvertParamMapToRawQuery(tt.params)
|
||||
|
||||
if len(tt.want) == 0 {
|
||||
if result != "" {
|
||||
t.Errorf("expected empty string, got %q", result)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Check that all expected substrings are in the result
|
||||
for _, want := range tt.want {
|
||||
if !contains(result, want) {
|
||||
t.Errorf("expected result to contain %q, got %q", want, result)
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("✅ Query string: %s", result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
return strings.Contains(s, substr)
|
||||
}
|
||||
455
internal/sources/cockroachdb/security_test.go
Normal file
455
internal/sources/cockroachdb/security_test.go
Normal file
@@ -0,0 +1,455 @@
|
||||
// Copyright 2026 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cockroachdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
)
|
||||
|
||||
// TestClassifySQL tests SQL statement classification
|
||||
func TestClassifySQL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sql string
|
||||
expected SQLStatementType
|
||||
}{
|
||||
{"SELECT", "SELECT * FROM users", SQLTypeSelect},
|
||||
{"SELECT with spaces", " SELECT * FROM users ", SQLTypeSelect},
|
||||
{"SELECT with comment", "-- comment\nSELECT * FROM users", SQLTypeSelect},
|
||||
{"INSERT", "INSERT INTO users (name) VALUES ('alice')", SQLTypeInsert},
|
||||
{"UPDATE", "UPDATE users SET name='bob' WHERE id=1", SQLTypeUpdate},
|
||||
{"DELETE", "DELETE FROM users WHERE id=1", SQLTypeDelete},
|
||||
{"CREATE TABLE", "CREATE TABLE users (id UUID PRIMARY KEY)", SQLTypeDDL},
|
||||
{"ALTER TABLE", "ALTER TABLE users ADD COLUMN email STRING", SQLTypeDDL},
|
||||
{"DROP TABLE", "DROP TABLE users", SQLTypeDDL},
|
||||
{"TRUNCATE", "TRUNCATE TABLE users", SQLTypeTruncate},
|
||||
{"EXPLAIN", "EXPLAIN SELECT * FROM users", SQLTypeExplain},
|
||||
{"SHOW", "SHOW TABLES", SQLTypeShow},
|
||||
{"SET", "SET application_name = 'myapp'", SQLTypeSet},
|
||||
{"Empty", "", SQLTypeUnknown},
|
||||
{"Lowercase select", "select * from users", SQLTypeSelect},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ClassifySQL(tt.sql)
|
||||
if result != tt.expected {
|
||||
t.Errorf("ClassifySQL(%q) = %v, want %v", tt.sql, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsWriteOperation tests write operation detection
|
||||
func TestIsWriteOperation(t *testing.T) {
|
||||
tests := []struct {
|
||||
sqlType SQLStatementType
|
||||
expected bool
|
||||
}{
|
||||
{SQLTypeSelect, false},
|
||||
{SQLTypeInsert, true},
|
||||
{SQLTypeUpdate, true},
|
||||
{SQLTypeDelete, true},
|
||||
{SQLTypeTruncate, true},
|
||||
{SQLTypeDDL, true},
|
||||
{SQLTypeExplain, false},
|
||||
{SQLTypeShow, false},
|
||||
{SQLTypeSet, false},
|
||||
{SQLTypeUnknown, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.sqlType.String(), func(t *testing.T) {
|
||||
result := IsWriteOperation(tt.sqlType)
|
||||
if result != tt.expected {
|
||||
t.Errorf("IsWriteOperation(%v) = %v, want %v", tt.sqlType, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper for SQLStatementType to string
|
||||
func (s SQLStatementType) String() string {
|
||||
switch s {
|
||||
case SQLTypeSelect:
|
||||
return "SELECT"
|
||||
case SQLTypeInsert:
|
||||
return "INSERT"
|
||||
case SQLTypeUpdate:
|
||||
return "UPDATE"
|
||||
case SQLTypeDelete:
|
||||
return "DELETE"
|
||||
case SQLTypeDDL:
|
||||
return "DDL"
|
||||
case SQLTypeTruncate:
|
||||
return "TRUNCATE"
|
||||
case SQLTypeExplain:
|
||||
return "EXPLAIN"
|
||||
case SQLTypeShow:
|
||||
return "SHOW"
|
||||
case SQLTypeSet:
|
||||
return "SET"
|
||||
default:
|
||||
return "UNKNOWN"
|
||||
}
|
||||
}
|
||||
|
||||
// TestCanExecuteWrite tests write operation enforcement
|
||||
func TestCanExecuteWrite(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
readOnlyMode bool
|
||||
enableWriteMode bool
|
||||
sql string
|
||||
expectError bool
|
||||
errorCode string
|
||||
}{
|
||||
{
|
||||
name: "SELECT in read-only mode",
|
||||
readOnlyMode: true,
|
||||
enableWriteMode: false,
|
||||
sql: "SELECT * FROM users",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "INSERT in read-only mode",
|
||||
readOnlyMode: true,
|
||||
enableWriteMode: false,
|
||||
sql: "INSERT INTO users (name) VALUES ('alice')",
|
||||
expectError: true,
|
||||
errorCode: ErrCodeReadOnlyViolation,
|
||||
},
|
||||
{
|
||||
name: "INSERT with write mode enabled",
|
||||
readOnlyMode: false,
|
||||
enableWriteMode: true,
|
||||
sql: "INSERT INTO users (name) VALUES ('alice')",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "CREATE TABLE in read-only mode",
|
||||
readOnlyMode: true,
|
||||
enableWriteMode: false,
|
||||
sql: "CREATE TABLE test (id UUID PRIMARY KEY)",
|
||||
expectError: true,
|
||||
errorCode: ErrCodeReadOnlyViolation,
|
||||
},
|
||||
{
|
||||
name: "CREATE TABLE with write mode enabled",
|
||||
readOnlyMode: false,
|
||||
enableWriteMode: true,
|
||||
sql: "CREATE TABLE test (id UUID PRIMARY KEY)",
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
source := &Source{
|
||||
Config: Config{
|
||||
ReadOnlyMode: tt.readOnlyMode,
|
||||
EnableWriteMode: tt.enableWriteMode,
|
||||
},
|
||||
}
|
||||
|
||||
err := source.CanExecuteWrite(tt.sql)
|
||||
|
||||
if tt.expectError {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error but got nil")
|
||||
return
|
||||
}
|
||||
|
||||
structErr, ok := err.(*StructuredError)
|
||||
if !ok {
|
||||
t.Errorf("Expected StructuredError but got %T", err)
|
||||
return
|
||||
}
|
||||
|
||||
if structErr.Code != tt.errorCode {
|
||||
t.Errorf("Expected error code %s but got %s", tt.errorCode, structErr.Code)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error but got: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestApplyQueryLimits tests query limit application
|
||||
func TestApplyQueryLimits(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sql string
|
||||
maxRowLimit int
|
||||
expectedSQL string
|
||||
shouldAddLimit bool
|
||||
}{
|
||||
{
|
||||
name: "SELECT without LIMIT",
|
||||
sql: "SELECT * FROM users",
|
||||
maxRowLimit: 100,
|
||||
expectedSQL: "SELECT * FROM users LIMIT 100",
|
||||
shouldAddLimit: true,
|
||||
},
|
||||
{
|
||||
name: "SELECT with existing LIMIT",
|
||||
sql: "SELECT * FROM users LIMIT 50",
|
||||
maxRowLimit: 100,
|
||||
expectedSQL: "SELECT * FROM users LIMIT 50",
|
||||
shouldAddLimit: false,
|
||||
},
|
||||
{
|
||||
name: "SELECT without LIMIT and semicolon",
|
||||
sql: "SELECT * FROM users;",
|
||||
maxRowLimit: 100,
|
||||
expectedSQL: "SELECT * FROM users LIMIT 100",
|
||||
shouldAddLimit: true,
|
||||
},
|
||||
{
|
||||
name: "SELECT with trailing newline and semicolon",
|
||||
sql: "SELECT * FROM users;\n",
|
||||
maxRowLimit: 100,
|
||||
expectedSQL: "SELECT * FROM users LIMIT 100",
|
||||
shouldAddLimit: true,
|
||||
},
|
||||
{
|
||||
name: "SELECT with multiline and semicolon",
|
||||
sql: "\n\tSELECT *\n\tFROM users\n\tORDER BY id;\n",
|
||||
maxRowLimit: 100,
|
||||
expectedSQL: "SELECT *\n\tFROM users\n\tORDER BY id LIMIT 100",
|
||||
shouldAddLimit: true,
|
||||
},
|
||||
{
|
||||
name: "INSERT should not have LIMIT added",
|
||||
sql: "INSERT INTO users (name) VALUES ('alice')",
|
||||
maxRowLimit: 100,
|
||||
expectedSQL: "INSERT INTO users (name) VALUES ('alice')",
|
||||
shouldAddLimit: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
source := &Source{
|
||||
Config: Config{
|
||||
MaxRowLimit: tt.maxRowLimit,
|
||||
QueryTimeoutSec: 0, // Timeout now managed by caller
|
||||
},
|
||||
}
|
||||
|
||||
modifiedSQL, err := source.ApplyQueryLimits(tt.sql)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if modifiedSQL != tt.expectedSQL {
|
||||
t.Errorf("Expected SQL:\n%s\nGot:\n%s", tt.expectedSQL, modifiedSQL)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestApplyQueryTimeout tests that timeout is managed by caller (not source)
|
||||
func TestApplyQueryTimeout(t *testing.T) {
|
||||
source := &Source{
|
||||
Config: Config{
|
||||
QueryTimeoutSec: 5, // Documented recommended timeout
|
||||
MaxRowLimit: 0, // Don't add LIMIT
|
||||
},
|
||||
}
|
||||
|
||||
// Caller creates timeout context (following Go best practices)
|
||||
ctx := context.Background()
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Duration(source.QueryTimeoutSec)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Apply query limits (doesn't modify context anymore)
|
||||
modifiedSQL, err := source.ApplyQueryLimits("SELECT * FROM users")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify context has deadline (managed by caller)
|
||||
deadline, ok := ctx.Deadline()
|
||||
if !ok {
|
||||
t.Error("Expected deadline to be set but it wasn't")
|
||||
return
|
||||
}
|
||||
|
||||
// Verify deadline is approximately 5 seconds from now
|
||||
expectedDeadline := time.Now().Add(5 * time.Second)
|
||||
diff := deadline.Sub(expectedDeadline)
|
||||
if diff < 0 {
|
||||
diff = -diff
|
||||
}
|
||||
|
||||
// Allow 1 second tolerance
|
||||
if diff > time.Second {
|
||||
t.Errorf("Deadline diff too large: %v", diff)
|
||||
}
|
||||
|
||||
// Verify SQL is unchanged (LIMIT not added since MaxRowLimit=0)
|
||||
if modifiedSQL != "SELECT * FROM users" {
|
||||
t.Errorf("Expected SQL unchanged, got: %s", modifiedSQL)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRedactSQL tests SQL redaction for telemetry
|
||||
func TestRedactSQL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sql string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "String literal redaction",
|
||||
sql: "SELECT * FROM users WHERE name='alice' AND email='alice@example.com'",
|
||||
expected: "SELECT * FROM users WHERE name='***' AND email='***'",
|
||||
},
|
||||
{
|
||||
name: "Long number redaction",
|
||||
sql: "SELECT * FROM users WHERE ssn=1234567890123",
|
||||
expected: "SELECT * FROM users WHERE ssn=***",
|
||||
},
|
||||
{
|
||||
name: "Short numbers not redacted",
|
||||
sql: "SELECT * FROM users WHERE age=25",
|
||||
expected: "SELECT * FROM users WHERE age=25",
|
||||
},
|
||||
{
|
||||
name: "Multiple sensitive values",
|
||||
sql: "INSERT INTO users (name, email, phone) VALUES ('bob', 'bob@example.com', '5551234567')",
|
||||
expected: "INSERT INTO users (name, email, phone) VALUES ('***', '***', '***')",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := RedactSQL(tt.sql)
|
||||
if result != tt.expected {
|
||||
t.Errorf("RedactSQL:\nGot: %s\nExpected: %s", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsReadOnlyMode tests read-only mode detection
|
||||
func TestIsReadOnlyMode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
readOnlyMode bool
|
||||
enableWriteMode bool
|
||||
expected bool
|
||||
}{
|
||||
{"Read-only by default", true, false, true},
|
||||
{"Write mode enabled", false, true, false},
|
||||
{"Both false", false, false, false},
|
||||
{"Read-only overridden by write mode", true, true, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
source := &Source{
|
||||
Config: Config{
|
||||
ReadOnlyMode: tt.readOnlyMode,
|
||||
EnableWriteMode: tt.enableWriteMode,
|
||||
},
|
||||
}
|
||||
|
||||
result := source.IsReadOnlyMode()
|
||||
if result != tt.expected {
|
||||
t.Errorf("IsReadOnlyMode() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestStructuredError tests error formatting
|
||||
func TestStructuredError(t *testing.T) {
|
||||
err := &StructuredError{
|
||||
Code: ErrCodeReadOnlyViolation,
|
||||
Message: "Write operations not allowed",
|
||||
Details: map[string]any{
|
||||
"sql_type": "INSERT",
|
||||
},
|
||||
}
|
||||
|
||||
errorStr := err.Error()
|
||||
if !strings.Contains(errorStr, ErrCodeReadOnlyViolation) {
|
||||
t.Errorf("Error string should contain error code: %s", errorStr)
|
||||
}
|
||||
if !strings.Contains(errorStr, "Write operations not allowed") {
|
||||
t.Errorf("Error string should contain message: %s", errorStr)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultSecuritySettings tests that security defaults are correct
|
||||
func TestDefaultSecuritySettings(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a minimal YAML config
|
||||
yamlData := `name: test
|
||||
type: cockroachdb
|
||||
host: localhost
|
||||
port: "26257"
|
||||
user: root
|
||||
database: defaultdb
|
||||
`
|
||||
|
||||
var cfg Config
|
||||
if err := yaml.Unmarshal([]byte(yamlData), &cfg); err != nil {
|
||||
t.Fatalf("Failed to unmarshal YAML: %v", err)
|
||||
}
|
||||
|
||||
// Apply defaults through newConfig logic manually
|
||||
cfg.MaxRetries = 5
|
||||
cfg.RetryBaseDelay = "500ms"
|
||||
cfg.ReadOnlyMode = true
|
||||
cfg.EnableWriteMode = false
|
||||
cfg.MaxRowLimit = 1000
|
||||
cfg.QueryTimeoutSec = 30
|
||||
cfg.EnableTelemetry = true
|
||||
cfg.TelemetryVerbose = false
|
||||
|
||||
_ = ctx // prevent unused
|
||||
|
||||
// Verify MCP security defaults
|
||||
if !cfg.ReadOnlyMode {
|
||||
t.Error("ReadOnlyMode should be true by default")
|
||||
}
|
||||
if cfg.EnableWriteMode {
|
||||
t.Error("EnableWriteMode should be false by default")
|
||||
}
|
||||
if cfg.MaxRowLimit != 1000 {
|
||||
t.Errorf("MaxRowLimit should be 1000, got %d", cfg.MaxRowLimit)
|
||||
}
|
||||
if cfg.QueryTimeoutSec != 30 {
|
||||
t.Errorf("QueryTimeoutSec should be 30, got %d", cfg.QueryTimeoutSec)
|
||||
}
|
||||
if !cfg.EnableTelemetry {
|
||||
t.Error("EnableTelemetry should be true by default")
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,5 @@
|
||||
// Copyright © 2025, Oracle and/or its affiliates.
|
||||
|
||||
package oracle
|
||||
|
||||
import (
|
||||
|
||||
@@ -17,11 +17,13 @@ package alloydbcreatecluster
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
|
||||
@@ -122,44 +124,49 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
project, ok := paramsMap["project"].(string)
|
||||
if !ok || project == "" {
|
||||
return nil, fmt.Errorf("invalid or missing 'project' parameter; expected a non-empty string")
|
||||
return nil, util.NewAgentError("invalid or missing 'project' parameter; expected a non-empty string", nil)
|
||||
}
|
||||
|
||||
location, ok := paramsMap["location"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid 'location' parameter; expected a string")
|
||||
return nil, util.NewAgentError("invalid 'location' parameter; expected a string", nil)
|
||||
}
|
||||
|
||||
clusterID, ok := paramsMap["cluster"].(string)
|
||||
if !ok || clusterID == "" {
|
||||
return nil, fmt.Errorf("invalid or missing 'cluster' parameter; expected a non-empty string")
|
||||
return nil, util.NewAgentError("invalid or missing 'cluster' parameter; expected a non-empty string", nil)
|
||||
}
|
||||
|
||||
password, ok := paramsMap["password"].(string)
|
||||
if !ok || password == "" {
|
||||
return nil, fmt.Errorf("invalid or missing 'password' parameter; expected a non-empty string")
|
||||
return nil, util.NewAgentError("invalid or missing 'password' parameter; expected a non-empty string", nil)
|
||||
}
|
||||
|
||||
network, ok := paramsMap["network"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid 'network' parameter; expected a string")
|
||||
return nil, util.NewAgentError("invalid 'network' parameter; expected a string", nil)
|
||||
}
|
||||
|
||||
user, ok := paramsMap["user"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid 'user' parameter; expected a string")
|
||||
return nil, util.NewAgentError("invalid 'user' parameter; expected a string", nil)
|
||||
}
|
||||
resp, err := source.CreateCluster(ctx, project, location, network, user, password, clusterID, string(accessToken))
|
||||
|
||||
if err != nil {
|
||||
return nil, util.ProcessGcpError(err)
|
||||
}
|
||||
|
||||
return source.CreateCluster(ctx, project, location, network, user, password, clusterID, string(accessToken))
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -17,11 +17,13 @@ package alloydbcreateinstance
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
|
||||
@@ -123,36 +125,36 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
project, ok := paramsMap["project"].(string)
|
||||
if !ok || project == "" {
|
||||
return nil, fmt.Errorf("invalid or missing 'project' parameter; expected a non-empty string")
|
||||
return nil, util.NewAgentError("invalid or missing 'project' parameter; expected a non-empty string", nil)
|
||||
}
|
||||
|
||||
location, ok := paramsMap["location"].(string)
|
||||
if !ok || location == "" {
|
||||
return nil, fmt.Errorf("invalid or missing 'location' parameter; expected a non-empty string")
|
||||
return nil, util.NewAgentError("invalid or missing 'location' parameter; expected a non-empty string", nil)
|
||||
}
|
||||
|
||||
cluster, ok := paramsMap["cluster"].(string)
|
||||
if !ok || cluster == "" {
|
||||
return nil, fmt.Errorf("invalid or missing 'cluster' parameter; expected a non-empty string")
|
||||
return nil, util.NewAgentError("invalid or missing 'cluster' parameter; expected a non-empty string", nil)
|
||||
}
|
||||
|
||||
instanceID, ok := paramsMap["instance"].(string)
|
||||
if !ok || instanceID == "" {
|
||||
return nil, fmt.Errorf("invalid or missing 'instance' parameter; expected a non-empty string")
|
||||
return nil, util.NewAgentError("invalid or missing 'instance' parameter; expected a non-empty string", nil)
|
||||
}
|
||||
|
||||
instanceType, ok := paramsMap["instanceType"].(string)
|
||||
if !ok || (instanceType != "READ_POOL" && instanceType != "PRIMARY") {
|
||||
return nil, fmt.Errorf("invalid 'instanceType' parameter; expected 'PRIMARY' or 'READ_POOL'")
|
||||
return nil, util.NewAgentError("invalid 'instanceType' parameter; expected 'PRIMARY' or 'READ_POOL'", nil)
|
||||
}
|
||||
|
||||
displayName, _ := paramsMap["displayName"].(string)
|
||||
@@ -161,11 +163,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if instanceType == "READ_POOL" {
|
||||
nodeCount, ok = paramsMap["nodeCount"].(int)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid 'nodeCount' parameter; expected an integer for READ_POOL")
|
||||
return nil, util.NewAgentError("invalid 'nodeCount' parameter; expected an integer for READ_POOL", nil)
|
||||
}
|
||||
}
|
||||
|
||||
return source.CreateInstance(ctx, project, location, cluster, instanceID, instanceType, displayName, nodeCount, string(accessToken))
|
||||
resp, err := source.CreateInstance(ctx, project, location, cluster, instanceID, instanceType, displayName, nodeCount, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, util.ProcessGcpError(err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -17,11 +17,13 @@ package alloydbcreateuser
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
|
||||
@@ -122,43 +124,43 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
project, ok := paramsMap["project"].(string)
|
||||
if !ok || project == "" {
|
||||
return nil, fmt.Errorf("invalid or missing 'project' parameter; expected a non-empty string")
|
||||
return nil, util.NewAgentError("invalid or missing 'project' parameter; expected a non-empty string", nil)
|
||||
}
|
||||
|
||||
location, ok := paramsMap["location"].(string)
|
||||
if !ok || location == "" {
|
||||
return nil, fmt.Errorf("invalid or missing'location' parameter; expected a non-empty string")
|
||||
return nil, util.NewAgentError("invalid or missing'location' parameter; expected a non-empty string", nil)
|
||||
}
|
||||
|
||||
cluster, ok := paramsMap["cluster"].(string)
|
||||
if !ok || cluster == "" {
|
||||
return nil, fmt.Errorf("invalid or missing 'cluster' parameter; expected a non-empty string")
|
||||
return nil, util.NewAgentError("invalid or missing 'cluster' parameter; expected a non-empty string", nil)
|
||||
}
|
||||
|
||||
userID, ok := paramsMap["user"].(string)
|
||||
if !ok || userID == "" {
|
||||
return nil, fmt.Errorf("invalid or missing 'user' parameter; expected a non-empty string")
|
||||
return nil, util.NewAgentError("invalid or missing 'user' parameter; expected a non-empty string", nil)
|
||||
}
|
||||
|
||||
userType, ok := paramsMap["userType"].(string)
|
||||
if !ok || (userType != "ALLOYDB_BUILT_IN" && userType != "ALLOYDB_IAM_USER") {
|
||||
return nil, fmt.Errorf("invalid or missing 'userType' parameter; expected 'ALLOYDB_BUILT_IN' or 'ALLOYDB_IAM_USER'")
|
||||
return nil, util.NewAgentError("invalid or missing 'userType' parameter; expected 'ALLOYDB_BUILT_IN' or 'ALLOYDB_IAM_USER'", nil)
|
||||
}
|
||||
var password string
|
||||
|
||||
if userType == "ALLOYDB_BUILT_IN" {
|
||||
password, ok = paramsMap["password"].(string)
|
||||
if !ok || password == "" {
|
||||
return nil, fmt.Errorf("password is required when userType is ALLOYDB_BUILT_IN")
|
||||
return nil, util.NewAgentError("password is required when userType is ALLOYDB_BUILT_IN", nil)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -170,7 +172,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
}
|
||||
}
|
||||
return source.CreateUser(ctx, userType, password, roles, string(accessToken), project, location, cluster, userID)
|
||||
resp, err := source.CreateUser(ctx, userType, password, roles, string(accessToken), project, location, cluster, userID)
|
||||
if err != nil {
|
||||
return nil, util.ProcessGcpError(err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -17,11 +17,13 @@ package alloydbgetcluster
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
|
||||
@@ -120,28 +122,32 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid or missing 'project' parameter; expected a string")
|
||||
if !ok || project == "" {
|
||||
return nil, util.NewAgentError("invalid or missing 'project' parameter; expected a string", nil)
|
||||
}
|
||||
location, ok := paramsMap["location"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid 'location' parameter; expected a string")
|
||||
if !ok || location == "" {
|
||||
return nil, util.NewAgentError("invalid or missing 'location' parameter; expected a string", nil)
|
||||
}
|
||||
cluster, ok := paramsMap["cluster"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid 'cluster' parameter; expected a string")
|
||||
if !ok || cluster == "" {
|
||||
return nil, util.NewAgentError("invalid or missing 'cluster' parameter; expected a string", nil)
|
||||
}
|
||||
|
||||
return source.GetCluster(ctx, project, location, cluster, string(accessToken))
|
||||
resp, err := source.GetCluster(ctx, project, location, cluster, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, util.ProcessGcpError(err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -17,11 +17,13 @@ package alloydbgetinstance
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
|
||||
@@ -120,32 +122,36 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid or missing 'project' parameter; expected a string")
|
||||
if !ok || project == "" {
|
||||
return nil, util.NewAgentError("invalid or missing 'project' parameter; expected a string", nil)
|
||||
}
|
||||
location, ok := paramsMap["location"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid 'location' parameter; expected a string")
|
||||
if !ok || location == "" {
|
||||
return nil, util.NewAgentError("invalid or missing 'location' parameter; expected a string", nil)
|
||||
}
|
||||
cluster, ok := paramsMap["cluster"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid 'cluster' parameter; expected a string")
|
||||
if !ok || cluster == "" {
|
||||
return nil, util.NewAgentError("invalid or missing 'cluster' parameter; expected a string", nil)
|
||||
}
|
||||
instance, ok := paramsMap["instance"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid 'instance' parameter; expected a string")
|
||||
if !ok || instance == "" {
|
||||
return nil, util.NewAgentError("invalid or missing 'instance' parameter; expected a string", nil)
|
||||
}
|
||||
|
||||
return source.GetInstance(ctx, project, location, cluster, instance, string(accessToken))
|
||||
resp, err := source.GetInstance(ctx, project, location, cluster, instance, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, util.ProcessGcpError(err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -17,11 +17,13 @@ package alloydbgetuser
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
|
||||
@@ -120,32 +122,36 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid or missing 'project' parameter; expected a string")
|
||||
if !ok || project == "" {
|
||||
return nil, util.NewAgentError("invalid or missing 'project' parameter; expected a string", nil)
|
||||
}
|
||||
location, ok := paramsMap["location"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid 'location' parameter; expected a string")
|
||||
if !ok || location == "" {
|
||||
return nil, util.NewAgentError("invalid or missing 'location' parameter; expected a string", nil)
|
||||
}
|
||||
cluster, ok := paramsMap["cluster"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid 'cluster' parameter; expected a string")
|
||||
if !ok || cluster == "" {
|
||||
return nil, util.NewAgentError("invalid or missing 'cluster' parameter; expected a string", nil)
|
||||
}
|
||||
user, ok := paramsMap["user"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid 'user' parameter; expected a string")
|
||||
if !ok || user == "" {
|
||||
return nil, util.NewAgentError("invalid or missing 'user' parameter; expected a string", nil)
|
||||
}
|
||||
|
||||
return source.GetUsers(ctx, project, location, cluster, user, string(accessToken))
|
||||
resp, err := source.GetUsers(ctx, project, location, cluster, user, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, util.ProcessGcpError(err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -17,11 +17,13 @@ package alloydblistclusters
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
|
||||
@@ -118,24 +120,28 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid or missing 'project' parameter; expected a string")
|
||||
if !ok || project == "" {
|
||||
return nil, util.NewAgentError("invalid or missing 'project' parameter; expected a string", nil)
|
||||
}
|
||||
location, ok := paramsMap["location"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid 'location' parameter; expected a string")
|
||||
return nil, util.NewAgentError("invalid 'location' parameter; expected a string", nil)
|
||||
}
|
||||
|
||||
return source.ListCluster(ctx, project, location, string(accessToken))
|
||||
resp, err := source.ListCluster(ctx, project, location, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, util.ProcessGcpError(err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -17,11 +17,13 @@ package alloydblistinstances
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
|
||||
@@ -119,28 +121,32 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid or missing 'project' parameter; expected a string")
|
||||
if !ok || project == "" {
|
||||
return nil, util.NewAgentError("invalid or missing 'project' parameter; expected a string", nil)
|
||||
}
|
||||
location, ok := paramsMap["location"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid 'location' parameter; expected a string")
|
||||
return nil, util.NewAgentError("invalid 'location' parameter; expected a string", nil)
|
||||
}
|
||||
cluster, ok := paramsMap["cluster"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid 'cluster' parameter; expected a string")
|
||||
return nil, util.NewAgentError("invalid 'cluster' parameter; expected a string", nil)
|
||||
}
|
||||
|
||||
return source.ListInstance(ctx, project, location, cluster, string(accessToken))
|
||||
resp, err := source.ListInstance(ctx, project, location, cluster, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, util.ProcessGcpError(err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -17,11 +17,13 @@ package alloydblistusers
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
|
||||
@@ -119,28 +121,32 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid or missing 'project' parameter; expected a string")
|
||||
if !ok || project == "" {
|
||||
return nil, util.NewAgentError("invalid or missing 'project' parameter; expected a string", nil)
|
||||
}
|
||||
location, ok := paramsMap["location"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid 'location' parameter; expected a string")
|
||||
if !ok || location == "" {
|
||||
return nil, util.NewAgentError("invalid or missing 'location' parameter; expected a string", nil)
|
||||
}
|
||||
cluster, ok := paramsMap["cluster"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid 'cluster' parameter; expected a string")
|
||||
if !ok || cluster == "" {
|
||||
return nil, util.NewAgentError("invalid or missing 'cluster' parameter; expected a string", nil)
|
||||
}
|
||||
|
||||
return source.ListUsers(ctx, project, location, cluster, string(accessToken))
|
||||
resp, err := source.ListUsers(ctx, project, location, cluster, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, util.ProcessGcpError(err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
|
||||
@@ -213,25 +214,25 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("missing 'project' parameter")
|
||||
return nil, util.NewAgentError("missing 'project' parameter", nil)
|
||||
}
|
||||
location, ok := paramsMap["location"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("missing 'location' parameter")
|
||||
return nil, util.NewAgentError("missing 'location' parameter", nil)
|
||||
}
|
||||
operation, ok := paramsMap["operation"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("missing 'operation' parameter")
|
||||
return nil, util.NewAgentError("missing 'operation' parameter", nil)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Minute)
|
||||
@@ -246,14 +247,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
for retries < maxRetries {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, fmt.Errorf("timed out waiting for operation: %w", ctx.Err())
|
||||
return nil, util.NewAgentError("timed out waiting for operation", ctx.Err())
|
||||
default:
|
||||
}
|
||||
|
||||
op, err := source.GetOperations(ctx, project, location, operation, alloyDBConnectionMessageTemplate, delay, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if op != nil {
|
||||
return nil, util.ProcessGeneralError(err)
|
||||
}
|
||||
if op != nil {
|
||||
return op, nil
|
||||
}
|
||||
|
||||
@@ -264,7 +266,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
retries++
|
||||
}
|
||||
return nil, fmt.Errorf("exceeded max retries waiting for operation")
|
||||
return nil, util.NewAgentError("exceeded max retries waiting for operation", nil)
|
||||
}
|
||||
|
||||
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -17,12 +17,14 @@ package alloydbainl
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
@@ -127,10 +129,10 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
return t.Config
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
sliceParams := params.AsSlice()
|
||||
@@ -143,7 +145,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
|
||||
resp, err := source.RunSQL(ctx, t.Statement, allParamValues)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w. Query: %v , Values: %v. Toolbox v0.19.0+ is only compatible with AlloyDB AI NL v1.0.3+. Please ensure that you are using the latest AlloyDB AI NL extension", err, t.Statement, allParamValues)
|
||||
return nil, util.NewClientServerError(fmt.Sprintf("error running SQL query: %v. Query: %v , Values: %v. Toolbox v0.19.0+ is only compatible with AlloyDB AI NL v1.0.3+. Please ensure that you are using the latest AlloyDB AI NL extension", err, t.Statement, allParamValues), http.StatusBadRequest, err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ package bigqueryanalyzecontribution
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
bigqueryapi "cloud.google.com/go/bigquery"
|
||||
@@ -27,6 +28,7 @@ import (
|
||||
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
bigqueryrestapi "google.golang.org/api/bigquery/v2"
|
||||
)
|
||||
@@ -154,21 +156,21 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
}
|
||||
|
||||
// Invoke runs the contribution analysis.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
inputData, ok := paramsMap["input_data"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unable to cast input_data parameter %s", paramsMap["input_data"])
|
||||
return nil, util.NewAgentError(fmt.Sprintf("unable to cast input_data parameter %s", paramsMap["input_data"]), nil)
|
||||
}
|
||||
|
||||
bqClient, restService, err := source.RetrieveClientAndService(accessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, util.NewClientServerError("failed to retrieve BigQuery client", http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
modelID := fmt.Sprintf("contribution_analysis_model_%s", strings.ReplaceAll(uuid.New().String(), "-", ""))
|
||||
@@ -186,7 +188,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
options = append(options, fmt.Sprintf("DIMENSION_ID_COLS = [%s]", strings.Join(strCols, ", ")))
|
||||
} else {
|
||||
return nil, fmt.Errorf("unable to cast dimension_id_cols parameter %s", paramsMap["dimension_id_cols"])
|
||||
return nil, util.NewAgentError(fmt.Sprintf("unable to cast dimension_id_cols parameter %s", paramsMap["dimension_id_cols"]), nil)
|
||||
}
|
||||
}
|
||||
if val, ok := paramsMap["top_k_insights_by_apriori_support"]; ok {
|
||||
@@ -195,7 +197,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if val, ok := paramsMap["pruning_method"].(string); ok {
|
||||
upperVal := strings.ToUpper(val)
|
||||
if upperVal != "NO_PRUNING" && upperVal != "PRUNE_REDUNDANT_INSIGHTS" {
|
||||
return nil, fmt.Errorf("invalid pruning_method: %s", val)
|
||||
return nil, util.NewAgentError(fmt.Sprintf("invalid pruning_method: %s", val), nil)
|
||||
}
|
||||
options = append(options, fmt.Sprintf("PRUNING_METHOD = '%s'", upperVal))
|
||||
}
|
||||
@@ -207,7 +209,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
var connProps []*bigqueryapi.ConnectionProperty
|
||||
session, err := source.BigQuerySession()(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
||||
return nil, util.NewClientServerError("failed to get BigQuery session", http.StatusInternalServerError, err)
|
||||
}
|
||||
if session != nil {
|
||||
connProps = []*bigqueryapi.ConnectionProperty{
|
||||
@@ -216,22 +218,22 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, inputData, nil, connProps)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query validation failed: %w", err)
|
||||
return nil, util.ProcessGcpError(err)
|
||||
}
|
||||
statementType := dryRunJob.Statistics.Query.StatementType
|
||||
if statementType != "SELECT" {
|
||||
return nil, fmt.Errorf("the 'input_data' parameter only supports a table ID or a SELECT query. The provided query has statement type '%s'", statementType)
|
||||
return nil, util.NewAgentError(fmt.Sprintf("the 'input_data' parameter only supports a table ID or a SELECT query. The provided query has statement type '%s'", statementType), nil)
|
||||
}
|
||||
|
||||
queryStats := dryRunJob.Statistics.Query
|
||||
if queryStats != nil {
|
||||
for _, tableRef := range queryStats.ReferencedTables {
|
||||
if !source.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) {
|
||||
return nil, fmt.Errorf("query in input_data accesses dataset '%s.%s', which is not in the allowed list", tableRef.ProjectId, tableRef.DatasetId)
|
||||
return nil, util.NewAgentError(fmt.Sprintf("query in input_data accesses dataset '%s.%s', which is not in the allowed list", tableRef.ProjectId, tableRef.DatasetId), nil)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("could not analyze query in input_data to validate against allowed datasets")
|
||||
return nil, util.NewAgentError("could not analyze query in input_data to validate against allowed datasets", nil)
|
||||
}
|
||||
}
|
||||
inputDataSource = fmt.Sprintf("(%s)", inputData)
|
||||
@@ -245,10 +247,10 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
case 2: // dataset.table
|
||||
projectID, datasetID = source.BigQueryClient().Project(), parts[0]
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid table ID format for 'input_data': %q. Expected 'dataset.table' or 'project.dataset.table'", inputData)
|
||||
return nil, util.NewAgentError(fmt.Sprintf("invalid table ID format for 'input_data': %q. Expected 'dataset.table' or 'project.dataset.table'", inputData), nil)
|
||||
}
|
||||
if !source.IsDatasetAllowed(projectID, datasetID) {
|
||||
return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", projectID, datasetID, inputData)
|
||||
return nil, util.NewAgentError(fmt.Sprintf("access to dataset '%s.%s' (from table '%s') is not allowed", projectID, datasetID, inputData), nil)
|
||||
}
|
||||
}
|
||||
inputDataSource = fmt.Sprintf("SELECT * FROM `%s`", inputData)
|
||||
@@ -268,7 +270,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
// Otherwise, a new session will be created by the first query.
|
||||
session, err := source.BigQuerySession()(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
||||
return nil, util.NewClientServerError("failed to get BigQuery session", http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
if session != nil {
|
||||
@@ -281,15 +283,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
createModelJob, err := createModelQuery.Run(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to start create model job: %w", err)
|
||||
return nil, util.ProcessGcpError(err)
|
||||
}
|
||||
|
||||
status, err := createModelJob.Wait(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to wait for create model job: %w", err)
|
||||
return nil, util.ProcessGcpError(err)
|
||||
}
|
||||
if err := status.Err(); err != nil {
|
||||
return nil, fmt.Errorf("create model job failed: %w", err)
|
||||
return nil, util.ProcessGcpError(err)
|
||||
}
|
||||
|
||||
// Determine the session ID to use for subsequent queries.
|
||||
@@ -300,12 +302,17 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
} else if status.Statistics != nil && status.Statistics.SessionInfo != nil {
|
||||
sessionID = status.Statistics.SessionInfo.SessionID
|
||||
} else {
|
||||
return nil, fmt.Errorf("failed to get or create a BigQuery session ID")
|
||||
return nil, util.NewClientServerError("failed to get or create a BigQuery session ID", http.StatusInternalServerError, nil)
|
||||
}
|
||||
|
||||
getInsightsSQL := fmt.Sprintf("SELECT * FROM ML.GET_INSIGHTS(MODEL %s)", modelID)
|
||||
connProps := []*bigqueryapi.ConnectionProperty{{Key: "session_id", Value: sessionID}}
|
||||
return source.RunSQL(ctx, bqClient, getInsightsSQL, "SELECT", nil, connProps)
|
||||
|
||||
resp, err := source.RunSQL(ctx, bqClient, getInsightsSQL, "SELECT", nil, connProps)
|
||||
if err != nil {
|
||||
return nil, util.ProcessGcpError(err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -35,6 +35,8 @@ import (
|
||||
|
||||
const resourceType string = "bigquery-conversational-analytics"
|
||||
|
||||
const gdaURLFormat = "https://geminidataanalytics.googleapis.com/v1beta/projects/%s/locations/%s:chat"
|
||||
|
||||
const instructions = `**INSTRUCTIONS - FOLLOW THESE RULES:**
|
||||
1. **CONTENT:** Your answer should present the supporting data and then provide a conclusion based on that data.
|
||||
2. **OUTPUT FORMAT:** Your entire response MUST be in plain text format ONLY.
|
||||
@@ -172,10 +174,10 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
return t.Config
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
var tokenStr string
|
||||
@@ -184,26 +186,26 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if source.UseClientAuthorization() {
|
||||
// Use client-side access token
|
||||
if accessToken == "" {
|
||||
return nil, fmt.Errorf("tool is configured for client OAuth but no token was provided in the request header: %w", util.ErrUnauthorized)
|
||||
return nil, util.NewClientServerError("tool is configured for client OAuth but no token was provided in the request header", http.StatusUnauthorized, nil)
|
||||
}
|
||||
tokenStr, err = accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
return nil, util.NewClientServerError("error parsing access token", http.StatusUnauthorized, err)
|
||||
}
|
||||
} else {
|
||||
// Get a token source for the Gemini Data Analytics API.
|
||||
tokenSource, err := source.BigQueryTokenSourceWithScope(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get token source: %w", err)
|
||||
return nil, util.NewClientServerError("failed to get token source", http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
// Use cloud-platform token source for Gemini Data Analytics API
|
||||
if tokenSource == nil {
|
||||
return nil, fmt.Errorf("cloud-platform token source is missing")
|
||||
return nil, util.NewClientServerError("cloud-platform token source is missing", http.StatusInternalServerError, nil)
|
||||
}
|
||||
token, err := tokenSource.Token()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get token from cloud-platform token source: %w", err)
|
||||
return nil, util.NewClientServerError("failed to get token from cloud-platform token source", http.StatusInternalServerError, err)
|
||||
}
|
||||
tokenStr = token.AccessToken
|
||||
}
|
||||
@@ -218,14 +220,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
var tableRefs []BQTableReference
|
||||
if tableRefsJSON != "" {
|
||||
if err := json.Unmarshal([]byte(tableRefsJSON), &tableRefs); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse 'table_references' JSON string: %w", err)
|
||||
return nil, util.NewAgentError("failed to parse 'table_references' JSON string", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(source.BigQueryAllowedDatasets()) > 0 {
|
||||
for _, tableRef := range tableRefs {
|
||||
if !source.IsDatasetAllowed(tableRef.ProjectID, tableRef.DatasetID) {
|
||||
return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", tableRef.ProjectID, tableRef.DatasetID, tableRef.TableID)
|
||||
return nil, util.NewAgentError(fmt.Sprintf("access to dataset '%s.%s' (from table '%s') is not allowed", tableRef.ProjectID, tableRef.DatasetID, tableRef.TableID), nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -236,11 +238,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if location == "" {
|
||||
location = "us"
|
||||
}
|
||||
caURL := fmt.Sprintf("https://geminidataanalytics.googleapis.com/v1alpha/projects/%s/locations/%s:chat", projectID, location)
|
||||
caURL := fmt.Sprintf(gdaURLFormat, projectID, location)
|
||||
|
||||
headers := map[string]string{
|
||||
"Authorization": fmt.Sprintf("Bearer %s", tokenStr),
|
||||
"Content-Type": "application/json",
|
||||
"X-Goog-API-Client": util.GDAClientID,
|
||||
}
|
||||
|
||||
payload := CAPayload{
|
||||
@@ -252,13 +255,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
},
|
||||
Options: Options{Chart: ChartOptions{Image: ImageOptions{NoImage: map[string]any{}}}},
|
||||
},
|
||||
ClientIdEnum: "GENAI_TOOLBOX",
|
||||
ClientIdEnum: util.GDAClientID,
|
||||
}
|
||||
|
||||
// Call the streaming API
|
||||
response, err := getStream(caURL, payload, headers, source.GetMaxQueryResultRows())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get response from conversational analytics API: %w", err)
|
||||
// getStream wraps network errors or non-200 responses
|
||||
return nil, util.NewClientServerError("failed to get response from conversational analytics API", http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
return response, nil
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
bigqueryapi "cloud.google.com/go/bigquery"
|
||||
@@ -152,25 +153,25 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
return t.Config
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
sql, ok := paramsMap["sql"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unable to cast sql parameter %s", paramsMap["sql"])
|
||||
return nil, util.NewAgentError(fmt.Sprintf("unable to cast sql parameter %s", paramsMap["sql"]), nil)
|
||||
}
|
||||
dryRun, ok := paramsMap["dry_run"].(bool)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unable to cast dry_run parameter %s", paramsMap["dry_run"])
|
||||
return nil, util.NewAgentError(fmt.Sprintf("unable to cast dry_run parameter %s", paramsMap["dry_run"]), nil)
|
||||
}
|
||||
|
||||
bqClient, restService, err := source.RetrieveClientAndService(accessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, util.NewClientServerError("failed to retrieve BigQuery client", http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
var connProps []*bigqueryapi.ConnectionProperty
|
||||
@@ -178,7 +179,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if source.BigQueryWriteMode() == bigqueryds.WriteModeProtected {
|
||||
session, err = source.BigQuerySession()(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get BigQuery session for protected mode: %w", err)
|
||||
return nil, util.NewClientServerError("failed to get BigQuery session for protected mode", http.StatusInternalServerError, err)
|
||||
}
|
||||
connProps = []*bigqueryapi.ConnectionProperty{
|
||||
{Key: "session_id", Value: session.ID},
|
||||
@@ -187,7 +188,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
|
||||
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, sql, nil, connProps)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query validation failed: %w", err)
|
||||
return nil, util.NewClientServerError("query validation failed", http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
statementType := dryRunJob.Statistics.Query.StatementType
|
||||
@@ -195,13 +196,13 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
switch source.BigQueryWriteMode() {
|
||||
case bigqueryds.WriteModeBlocked:
|
||||
if statementType != "SELECT" {
|
||||
return nil, fmt.Errorf("write mode is 'blocked', only SELECT statements are allowed")
|
||||
return nil, util.NewAgentError("write mode is 'blocked', only SELECT statements are allowed", nil)
|
||||
}
|
||||
case bigqueryds.WriteModeProtected:
|
||||
if dryRunJob.Configuration != nil && dryRunJob.Configuration.Query != nil {
|
||||
if dest := dryRunJob.Configuration.Query.DestinationTable; dest != nil && dest.DatasetId != session.DatasetID {
|
||||
return nil, fmt.Errorf("protected write mode only supports SELECT statements, or write operations in the anonymous "+
|
||||
"dataset of a BigQuery session, but destination was %q", dest.DatasetId)
|
||||
return nil, util.NewAgentError(fmt.Sprintf("protected write mode only supports SELECT statements, or write operations in the anonymous "+
|
||||
"dataset of a BigQuery session, but destination was %q", dest.DatasetId), nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -209,11 +210,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if len(source.BigQueryAllowedDatasets()) > 0 {
|
||||
switch statementType {
|
||||
case "CREATE_SCHEMA", "DROP_SCHEMA", "ALTER_SCHEMA":
|
||||
return nil, fmt.Errorf("dataset-level operations like '%s' are not allowed when dataset restrictions are in place", statementType)
|
||||
return nil, util.NewAgentError(fmt.Sprintf("dataset-level operations like '%s' are not allowed when dataset restrictions are in place", statementType), nil)
|
||||
case "CREATE_FUNCTION", "CREATE_TABLE_FUNCTION", "CREATE_PROCEDURE":
|
||||
return nil, fmt.Errorf("creating stored routines ('%s') is not allowed when dataset restrictions are in place, as their contents cannot be safely analyzed", statementType)
|
||||
return nil, util.NewAgentError(fmt.Sprintf("creating stored routines ('%s') is not allowed when dataset restrictions are in place, as their contents cannot be safely analyzed", statementType), nil)
|
||||
case "CALL":
|
||||
return nil, fmt.Errorf("calling stored procedures ('%s') is not allowed when dataset restrictions are in place, as their contents cannot be safely analyzed", statementType)
|
||||
return nil, util.NewAgentError(fmt.Sprintf("calling stored procedures ('%s') is not allowed when dataset restrictions are in place, as their contents cannot be safely analyzed", statementType), nil)
|
||||
}
|
||||
|
||||
// Use a map to avoid duplicate table names.
|
||||
@@ -244,7 +245,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
parsedTables, parseErr := bqutil.TableParser(sql, source.BigQueryClient().Project())
|
||||
if parseErr != nil {
|
||||
// If parsing fails (e.g., EXECUTE IMMEDIATE), we cannot guarantee safety, so we must fail.
|
||||
return nil, fmt.Errorf("could not parse tables from query to validate against allowed datasets: %w", parseErr)
|
||||
return nil, util.NewAgentError("could not parse tables from query to validate against allowed datasets", parseErr)
|
||||
}
|
||||
tableNames = parsedTables
|
||||
}
|
||||
@@ -254,7 +255,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if len(parts) == 3 {
|
||||
projectID, datasetID := parts[0], parts[1]
|
||||
if !source.IsDatasetAllowed(projectID, datasetID) {
|
||||
return nil, fmt.Errorf("query accesses dataset '%s.%s', which is not in the allowed list", projectID, datasetID)
|
||||
return nil, util.NewAgentError(fmt.Sprintf("query accesses dataset '%s.%s', which is not in the allowed list", projectID, datasetID), nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -264,7 +265,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if dryRunJob != nil {
|
||||
jobJSON, err := json.MarshalIndent(dryRunJob, "", " ")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal dry run job to JSON: %w", err)
|
||||
return nil, util.NewClientServerError("failed to marshal dry run job to JSON", http.StatusInternalServerError, err)
|
||||
}
|
||||
return string(jobJSON), nil
|
||||
}
|
||||
@@ -275,10 +276,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
// Log the query executed for debugging.
|
||||
logger, err := util.LoggerFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting logger: %s", err)
|
||||
return nil, util.NewClientServerError("error getting logger", http.StatusInternalServerError, err)
|
||||
}
|
||||
logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", resourceType, sql))
|
||||
return source.RunSQL(ctx, bqClient, sql, statementType, nil, connProps)
|
||||
resp, err := source.RunSQL(ctx, bqClient, sql, statementType, nil, connProps)
|
||||
if err != nil {
|
||||
return nil, util.NewClientServerError("error running sql", http.StatusInternalServerError, err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -17,6 +17,7 @@ package bigqueryforecast
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
bigqueryapi "cloud.google.com/go/bigquery"
|
||||
@@ -133,34 +134,34 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
return t.Config
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
historyData, ok := paramsMap["history_data"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unable to cast history_data parameter %v", paramsMap["history_data"])
|
||||
return nil, util.NewAgentError(fmt.Sprintf("unable to cast history_data parameter %v", paramsMap["history_data"]), nil)
|
||||
}
|
||||
timestampCol, ok := paramsMap["timestamp_col"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unable to cast timestamp_col parameter %v", paramsMap["timestamp_col"])
|
||||
return nil, util.NewAgentError(fmt.Sprintf("unable to cast timestamp_col parameter %v", paramsMap["timestamp_col"]), nil)
|
||||
}
|
||||
dataCol, ok := paramsMap["data_col"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unable to cast data_col parameter %v", paramsMap["data_col"])
|
||||
return nil, util.NewAgentError(fmt.Sprintf("unable to cast data_col parameter %v", paramsMap["data_col"]), nil)
|
||||
}
|
||||
idColsRaw, ok := paramsMap["id_cols"].([]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unable to cast id_cols parameter %v", paramsMap["id_cols"])
|
||||
return nil, util.NewAgentError(fmt.Sprintf("unable to cast id_cols parameter %v", paramsMap["id_cols"]), nil)
|
||||
}
|
||||
var idCols []string
|
||||
for _, v := range idColsRaw {
|
||||
s, ok := v.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("id_cols contains non-string value: %v", v)
|
||||
return nil, util.NewAgentError(fmt.Sprintf("id_cols contains non-string value: %v", v), nil)
|
||||
}
|
||||
idCols = append(idCols, s)
|
||||
}
|
||||
@@ -169,13 +170,13 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if h, ok := paramsMap["horizon"].(float64); ok {
|
||||
horizon = int(h)
|
||||
} else {
|
||||
return nil, fmt.Errorf("unable to cast horizon parameter %v", paramsMap["horizon"])
|
||||
return nil, util.NewAgentError(fmt.Sprintf("unable to cast horizon parameter %v", paramsMap["horizon"]), nil)
|
||||
}
|
||||
}
|
||||
|
||||
bqClient, restService, err := source.RetrieveClientAndService(accessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, util.NewClientServerError("failed to retrieve BigQuery client", http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
var historyDataSource string
|
||||
@@ -185,7 +186,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
var connProps []*bigqueryapi.ConnectionProperty
|
||||
session, err := source.BigQuerySession()(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
||||
return nil, util.NewClientServerError("failed to get BigQuery session", http.StatusInternalServerError, err)
|
||||
}
|
||||
if session != nil {
|
||||
connProps = []*bigqueryapi.ConnectionProperty{
|
||||
@@ -194,22 +195,22 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, historyData, nil, connProps)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query validation failed: %w", err)
|
||||
return nil, util.ProcessGcpError(err)
|
||||
}
|
||||
statementType := dryRunJob.Statistics.Query.StatementType
|
||||
if statementType != "SELECT" {
|
||||
return nil, fmt.Errorf("the 'history_data' parameter only supports a table ID or a SELECT query. The provided query has statement type '%s'", statementType)
|
||||
return nil, util.NewAgentError(fmt.Sprintf("the 'history_data' parameter only supports a table ID or a SELECT query. The provided query has statement type '%s'", statementType), nil)
|
||||
}
|
||||
|
||||
queryStats := dryRunJob.Statistics.Query
|
||||
if queryStats != nil {
|
||||
for _, tableRef := range queryStats.ReferencedTables {
|
||||
if !source.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) {
|
||||
return nil, fmt.Errorf("query in history_data accesses dataset '%s.%s', which is not in the allowed list", tableRef.ProjectId, tableRef.DatasetId)
|
||||
return nil, util.NewAgentError(fmt.Sprintf("query in history_data accesses dataset '%s.%s', which is not in the allowed list", tableRef.ProjectId, tableRef.DatasetId), nil)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("could not analyze query in history_data to validate against allowed datasets")
|
||||
return nil, util.NewAgentError("could not analyze query in history_data to validate against allowed datasets", nil)
|
||||
}
|
||||
}
|
||||
historyDataSource = fmt.Sprintf("(%s)", historyData)
|
||||
@@ -226,11 +227,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
projectID = source.BigQueryClient().Project()
|
||||
datasetID = parts[0]
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid table ID format for 'history_data': %q. Expected 'dataset.table' or 'project.dataset.table'", historyData)
|
||||
return nil, util.NewAgentError(fmt.Sprintf("invalid table ID format for 'history_data': %q. Expected 'dataset.table' or 'project.dataset.table'", historyData), nil)
|
||||
}
|
||||
|
||||
if !source.IsDatasetAllowed(projectID, datasetID) {
|
||||
return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", projectID, datasetID, historyData)
|
||||
return nil, util.NewAgentError(fmt.Sprintf("access to dataset '%s.%s' (from table '%s') is not allowed", projectID, datasetID, historyData), nil)
|
||||
}
|
||||
}
|
||||
historyDataSource = fmt.Sprintf("TABLE `%s`", historyData)
|
||||
@@ -251,7 +252,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
|
||||
session, err := source.BigQuerySession()(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
||||
return nil, util.NewClientServerError("failed to get BigQuery session", http.StatusInternalServerError, err)
|
||||
}
|
||||
var connProps []*bigqueryapi.ConnectionProperty
|
||||
if session != nil {
|
||||
@@ -264,11 +265,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
// Log the query executed for debugging.
|
||||
logger, err := util.LoggerFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting logger: %s", err)
|
||||
return nil, util.NewClientServerError("error getting logger", http.StatusInternalServerError, err)
|
||||
}
|
||||
logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", resourceType, sql))
|
||||
|
||||
return source.RunSQL(ctx, bqClient, sql, "SELECT", nil, connProps)
|
||||
resp, err := source.RunSQL(ctx, bqClient, sql, "SELECT", nil, connProps)
|
||||
if err != nil {
|
||||
return nil, util.ProcessGcpError(err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -17,6 +17,7 @@ package bigquerygetdatasetinfo
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
bigqueryapi "cloud.google.com/go/bigquery"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
@@ -24,6 +25,7 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
bigqueryrestapi "google.golang.org/api/bigquery/v2"
|
||||
)
|
||||
@@ -120,38 +122,38 @@ type Tool struct {
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
return t.Config
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
mapParams := params.AsMap()
|
||||
projectId, ok := mapParams[projectKey].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", projectKey)
|
||||
// Updated: Use fmt.Sprintf for formatting, pass nil as cause
|
||||
return nil, util.NewAgentError(fmt.Sprintf("invalid or missing '%s' parameter; expected a string", projectKey), nil)
|
||||
}
|
||||
|
||||
datasetId, ok := mapParams[datasetKey].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", datasetKey)
|
||||
return nil, util.NewAgentError(fmt.Sprintf("invalid or missing '%s' parameter; expected a string", datasetKey), nil)
|
||||
}
|
||||
|
||||
bqClient, _, err := source.RetrieveClientAndService(accessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, util.NewClientServerError("failed to retrieve BigQuery client", http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
if !source.IsDatasetAllowed(projectId, datasetId) {
|
||||
return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId)
|
||||
return nil, util.NewAgentError(fmt.Sprintf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId), nil)
|
||||
}
|
||||
|
||||
dsHandle := bqClient.DatasetInProject(projectId, datasetId)
|
||||
|
||||
metadata, err := dsHandle.Metadata(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get metadata for dataset %s (in project %s): %w", datasetId, projectId, err)
|
||||
return nil, util.ProcessGcpError(err)
|
||||
}
|
||||
|
||||
return metadata, nil
|
||||
|
||||
@@ -17,6 +17,7 @@ package bigquerygettableinfo
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
bigqueryapi "cloud.google.com/go/bigquery"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
@@ -24,6 +25,7 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
bigqueryrestapi "google.golang.org/api/bigquery/v2"
|
||||
)
|
||||
@@ -125,35 +127,35 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
return t.Config
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
mapParams := params.AsMap()
|
||||
projectId, ok := mapParams[projectKey].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", projectKey)
|
||||
return nil, util.NewAgentError(fmt.Sprintf("invalid or missing '%s' parameter; expected a string", projectKey), nil)
|
||||
}
|
||||
|
||||
datasetId, ok := mapParams[datasetKey].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", datasetKey)
|
||||
return nil, util.NewAgentError(fmt.Sprintf("invalid or missing '%s' parameter; expected a string", datasetKey), nil)
|
||||
}
|
||||
|
||||
tableId, ok := mapParams[tableKey].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", tableKey)
|
||||
return nil, util.NewAgentError(fmt.Sprintf("invalid or missing '%s' parameter; expected a string", tableKey), nil)
|
||||
}
|
||||
|
||||
if !source.IsDatasetAllowed(projectId, datasetId) {
|
||||
return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId)
|
||||
return nil, util.NewAgentError(fmt.Sprintf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId), nil)
|
||||
}
|
||||
|
||||
bqClient, _, err := source.RetrieveClientAndService(accessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, util.NewClientServerError("failed to retrieve BigQuery client", http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
dsHandle := bqClient.DatasetInProject(projectId, datasetId)
|
||||
@@ -161,7 +163,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
|
||||
metadata, err := tableHandle.Metadata(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get metadata for table %s.%s.%s: %w", projectId, datasetId, tableId, err)
|
||||
return nil, util.ProcessGcpError(err)
|
||||
}
|
||||
|
||||
return metadata, nil
|
||||
|
||||
@@ -17,12 +17,14 @@ package bigquerylistdatasetids
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
bigqueryapi "cloud.google.com/go/bigquery"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
bigqueryrestapi "google.golang.org/api/bigquery/v2"
|
||||
"google.golang.org/api/iterator"
|
||||
@@ -120,10 +122,10 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
return t.Config
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
if len(source.BigQueryAllowedDatasets()) > 0 {
|
||||
@@ -132,12 +134,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
mapParams := params.AsMap()
|
||||
projectId, ok := mapParams[projectKey].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", projectKey)
|
||||
return nil, util.NewAgentError(fmt.Sprintf("invalid or missing '%s' parameter; expected a string", projectKey), nil)
|
||||
}
|
||||
|
||||
bqClient, _, err := source.RetrieveClientAndService(accessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, util.NewClientServerError("failed to retrieve BigQuery client", http.StatusInternalServerError, err)
|
||||
}
|
||||
datasetIterator := bqClient.Datasets(ctx)
|
||||
datasetIterator.ProjectID = projectId
|
||||
@@ -149,7 +151,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to iterate through datasets: %w", err)
|
||||
return nil, util.ProcessGcpError(err)
|
||||
}
|
||||
|
||||
// Remove leading and trailing quotes
|
||||
|
||||
@@ -17,6 +17,7 @@ package bigquerylisttableids
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
bigqueryapi "cloud.google.com/go/bigquery"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
@@ -24,6 +25,7 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
bigqueryrestapi "google.golang.org/api/bigquery/v2"
|
||||
"google.golang.org/api/iterator"
|
||||
@@ -123,31 +125,30 @@ type Tool struct {
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
return t.Config
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
mapParams := params.AsMap()
|
||||
projectId, ok := mapParams[projectKey].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", projectKey)
|
||||
return nil, util.NewAgentError(fmt.Sprintf("invalid or missing '%s' parameter; expected a string", projectKey), nil)
|
||||
}
|
||||
|
||||
datasetId, ok := mapParams[datasetKey].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", datasetKey)
|
||||
return nil, util.NewAgentError(fmt.Sprintf("invalid or missing '%s' parameter; expected a string", datasetKey), nil)
|
||||
}
|
||||
|
||||
if !source.IsDatasetAllowed(projectId, datasetId) {
|
||||
return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId)
|
||||
return nil, util.NewAgentError(fmt.Sprintf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId), nil)
|
||||
}
|
||||
|
||||
bqClient, _, err := source.RetrieveClientAndService(accessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, util.NewClientServerError("failed to retrieve BigQuery client", http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
dsHandle := bqClient.DatasetInProject(projectId, datasetId)
|
||||
@@ -160,7 +161,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to iterate through tables in dataset %s.%s: %w", projectId, datasetId, err)
|
||||
return nil, util.ProcessGcpError(err)
|
||||
}
|
||||
|
||||
// Remove leading and trailing quotes
|
||||
|
||||
@@ -17,6 +17,7 @@ package bigquerysearchcatalog
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
dataplexapi "cloud.google.com/go/dataplex/apiv1"
|
||||
@@ -26,6 +27,7 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/iterator"
|
||||
)
|
||||
@@ -186,28 +188,31 @@ func ExtractType(resourceString string) string {
|
||||
return typeMap[resourceString[lastIndex+1:]]
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
pageSize := int32(paramsMap["pageSize"].(int))
|
||||
prompt, _ := paramsMap["prompt"].(string)
|
||||
|
||||
projectIdSlice, err := parameters.ConvertAnySliceToTyped(paramsMap["projectIds"].([]any), "string")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("can't convert projectIds to array of strings: %s", err)
|
||||
return nil, util.NewAgentError(fmt.Sprintf("can't convert projectIds to array of strings: %s", err), err)
|
||||
}
|
||||
projectIds := projectIdSlice.([]string)
|
||||
|
||||
datasetIdSlice, err := parameters.ConvertAnySliceToTyped(paramsMap["datasetIds"].([]any), "string")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("can't convert datasetIds to array of strings: %s", err)
|
||||
return nil, util.NewAgentError(fmt.Sprintf("can't convert datasetIds to array of strings: %s", err), err)
|
||||
}
|
||||
datasetIds := datasetIdSlice.([]string)
|
||||
|
||||
typesSlice, err := parameters.ConvertAnySliceToTyped(paramsMap["types"].([]any), "string")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("can't convert types to array of strings: %s", err)
|
||||
return nil, util.NewAgentError(fmt.Sprintf("can't convert types to array of strings: %s", err), err)
|
||||
}
|
||||
types := typesSlice.([]string)
|
||||
|
||||
@@ -223,17 +228,17 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
return nil, util.NewClientServerError("error parsing access token", http.StatusUnauthorized, err)
|
||||
}
|
||||
catalogClient, err = dataplexClientCreator(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
|
||||
return nil, util.NewClientServerError("error creating client from OAuth access token", http.StatusInternalServerError, err)
|
||||
}
|
||||
}
|
||||
|
||||
it := catalogClient.SearchEntries(ctx, req)
|
||||
if it == nil {
|
||||
return nil, fmt.Errorf("failed to create search entries iterator for project %q", source.BigQueryProject())
|
||||
return nil, util.NewClientServerError(fmt.Sprintf("failed to create search entries iterator for project %q", source.BigQueryProject()), http.StatusInternalServerError, nil)
|
||||
}
|
||||
|
||||
var results []Response
|
||||
@@ -243,7 +248,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
break
|
||||
return nil, util.ProcessGcpError(err)
|
||||
}
|
||||
entrySource := entry.DataplexEntry.GetEntrySource()
|
||||
resp := Response{
|
||||
|
||||
@@ -17,6 +17,7 @@ package bigquerysql
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
@@ -27,6 +28,7 @@ import (
|
||||
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
bigqueryrestapi "google.golang.org/api/bigquery/v2"
|
||||
)
|
||||
@@ -103,11 +105,10 @@ type Tool struct {
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
return t.Config
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
highLevelParams := make([]bigqueryapi.QueryParameter, 0, len(t.Parameters))
|
||||
@@ -116,7 +117,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
paramsMap := params.AsMap()
|
||||
newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to extract template params %w", err)
|
||||
return nil, util.NewAgentError("unable to extract template params", err)
|
||||
}
|
||||
|
||||
for _, p := range t.Parameters {
|
||||
@@ -127,13 +128,13 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if arrayParam, ok := p.(*parameters.ArrayParameter); ok {
|
||||
arrayParamValue, ok := value.([]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unable to convert parameter `%s` to []any", name)
|
||||
return nil, util.NewAgentError(fmt.Sprintf("unable to convert parameter `%s` to []any", name), nil)
|
||||
}
|
||||
itemType := arrayParam.GetItems().GetType()
|
||||
var err error
|
||||
value, err = parameters.ConvertAnySliceToTyped(arrayParamValue, itemType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to convert parameter `%s` from []any to typed slice: %w", name, err)
|
||||
return nil, util.NewAgentError(fmt.Sprintf("unable to convert parameter `%s` from []any to typed slice", name), err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -161,7 +162,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
lowLevelParam.ParameterType.Type = "ARRAY"
|
||||
itemType, err := bqutil.BQTypeStringFromToolType(arrayParam.GetItems().GetType())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, util.NewAgentError("unable to get BigQuery type from tool parameter type", err)
|
||||
}
|
||||
lowLevelParam.ParameterType.ArrayType = &bigqueryrestapi.QueryParameterType{Type: itemType}
|
||||
|
||||
@@ -178,7 +179,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
// Handle scalar types based on their defined type.
|
||||
bqType, err := bqutil.BQTypeStringFromToolType(p.GetType())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, util.NewAgentError("unable to get BigQuery type from tool parameter type", err)
|
||||
}
|
||||
lowLevelParam.ParameterType.Type = bqType
|
||||
lowLevelParam.ParameterValue.Value = fmt.Sprintf("%v", value)
|
||||
@@ -190,7 +191,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if source.BigQuerySession() != nil {
|
||||
session, err := source.BigQuerySession()(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get BigQuery session: %w", err)
|
||||
return nil, util.NewClientServerError("failed to get BigQuery session", http.StatusInternalServerError, err)
|
||||
}
|
||||
if session != nil {
|
||||
// Add session ID to the connection properties for subsequent calls.
|
||||
@@ -200,17 +201,20 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
|
||||
bqClient, restService, err := source.RetrieveClientAndService(accessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, util.NewClientServerError("failed to retrieve BigQuery client", http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, newStatement, lowLevelParams, connProps)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query validation failed: %w", err)
|
||||
return nil, util.ProcessGcpError(err)
|
||||
}
|
||||
|
||||
statementType := dryRunJob.Statistics.Query.StatementType
|
||||
|
||||
return source.RunSQL(ctx, bqClient, newStatement, statementType, highLevelParams, connProps)
|
||||
resp, err := source.RunSQL(ctx, bqClient, newStatement, statementType, highLevelParams, connProps)
|
||||
if err != nil {
|
||||
return nil, util.ProcessGcpError(err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -17,12 +17,14 @@ package bigtable
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"cloud.google.com/go/bigtable"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
|
||||
@@ -96,24 +98,28 @@ type Tool struct {
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
return t.Config
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to extract template params %w", err)
|
||||
return nil, util.NewAgentError("unable to extract template params", err)
|
||||
}
|
||||
|
||||
newParams, err := parameters.GetParams(t.Parameters, paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to extract standard params %w", err)
|
||||
return nil, util.NewAgentError("unable to extract standard params", err)
|
||||
}
|
||||
return source.RunSQL(ctx, newStatement, t.Parameters, newParams)
|
||||
|
||||
resp, err := source.RunSQL(ctx, newStatement, t.Parameters, newParams)
|
||||
if err != nil {
|
||||
return nil, util.ProcessGcpError(err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -17,12 +17,14 @@ package cassandracql
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
gocql "github.com/apache/cassandra-gocql-driver/v2"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
|
||||
@@ -107,23 +109,27 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
}
|
||||
|
||||
// Invoke implements tools.Tool.
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to extract template params %w", err)
|
||||
return nil, util.NewAgentError("unable to extract template params", err)
|
||||
}
|
||||
|
||||
newParams, err := parameters.GetParams(t.Parameters, paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to extract standard params %w", err)
|
||||
return nil, util.NewAgentError("unable to extract standard params", err)
|
||||
}
|
||||
return source.RunSQL(ctx, newStatement, newParams)
|
||||
resp, err := source.RunSQL(ctx, newStatement, newParams)
|
||||
if err != nil {
|
||||
return nil, util.ProcessGeneralError(err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// Manifest implements tools.Tool.
|
||||
|
||||
@@ -17,11 +17,13 @@ package clickhouse
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
|
||||
@@ -87,18 +89,22 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
return t.Config
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, util.ToolboxError) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
sql, ok := paramsMap["sql"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unable to cast sql parameter %s", paramsMap["sql"])
|
||||
return nil, util.NewAgentError(fmt.Sprintf("unable to cast sql parameter %s", paramsMap["sql"]), nil)
|
||||
}
|
||||
return source.RunSQL(ctx, sql, nil)
|
||||
resp, err := source.RunSQL(ctx, sql, nil)
|
||||
if err != nil {
|
||||
return nil, util.ProcessGeneralError(err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -17,11 +17,13 @@ package clickhouse
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
|
||||
@@ -86,10 +88,10 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
return t.Config
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, util.ToolboxError) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
// Query to list all databases
|
||||
@@ -97,7 +99,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
|
||||
out, err := source.RunSQL(ctx, query, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, util.ProcessGeneralError(err)
|
||||
}
|
||||
|
||||
return out, nil
|
||||
|
||||
@@ -17,11 +17,13 @@ package clickhouse
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
|
||||
@@ -90,34 +92,37 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
return t.Config
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, util.ToolboxError) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
mapParams := params.AsMap()
|
||||
database, ok := mapParams[databaseKey].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", databaseKey)
|
||||
return nil, util.NewAgentError(fmt.Sprintf("invalid or missing '%s' parameter; expected a string", databaseKey), nil)
|
||||
}
|
||||
|
||||
// Query to list all tables in the specified database
|
||||
// Note: formatting identifier directly is risky if input is untrusted, but standard for this tool structure.
|
||||
query := fmt.Sprintf("SHOW TABLES FROM %s", database)
|
||||
|
||||
out, err := source.RunSQL(ctx, query, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, util.ProcessGeneralError(err)
|
||||
}
|
||||
|
||||
res, ok := out.([]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unable to convert result to list")
|
||||
return nil, util.NewClientServerError("unable to convert result to list", http.StatusInternalServerError, nil)
|
||||
}
|
||||
|
||||
var tables []map[string]any
|
||||
for _, item := range res {
|
||||
tableMap, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected type in result: got %T, want map[string]any", item)
|
||||
return nil, util.NewClientServerError(fmt.Sprintf("unexpected type in result: got %T, want map[string]any", item), http.StatusInternalServerError, nil)
|
||||
}
|
||||
tableMap["database"] = database
|
||||
tables = append(tables, tableMap)
|
||||
|
||||
@@ -17,11 +17,13 @@ package clickhouse
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
|
||||
@@ -88,24 +90,28 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
return t.Config
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, util.ToolboxError) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to extract template params: %w", err)
|
||||
return nil, util.NewAgentError("unable to extract template params", err)
|
||||
}
|
||||
|
||||
newParams, err := parameters.GetParams(t.Parameters, paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to extract standard params: %w", err)
|
||||
return nil, util.NewAgentError("unable to extract standard params", err)
|
||||
}
|
||||
|
||||
return source.RunSQL(ctx, newStatement, newParams)
|
||||
resp, err := source.RunSQL(ctx, newStatement, newParams)
|
||||
if err != nil {
|
||||
return nil, util.ProcessGeneralError(err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -18,11 +18,13 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
|
||||
@@ -119,17 +121,16 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
return t.Config
|
||||
}
|
||||
|
||||
// Invoke executes the tool logic
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
query, ok := paramsMap["query"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("query parameter not found or not a string")
|
||||
return nil, util.NewAgentError("query parameter not found or not a string", nil)
|
||||
}
|
||||
|
||||
// Parse the access token if provided
|
||||
@@ -138,7 +139,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
var err error
|
||||
tokenStr, err = accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
return nil, util.NewClientServerError("error parsing access token", http.StatusUnauthorized, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -154,9 +155,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
|
||||
bodyBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request payload: %w", err)
|
||||
return nil, util.NewClientServerError("failed to marshal request payload", http.StatusInternalServerError, err)
|
||||
}
|
||||
return source.RunQuery(ctx, tokenStr, bodyBytes)
|
||||
|
||||
resp, err := source.RunQuery(ctx, tokenStr, bodyBytes)
|
||||
if err != nil {
|
||||
return nil, util.ProcessGcpError(err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -17,11 +17,13 @@ package fhirfetchpage
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
|
||||
@@ -93,24 +95,31 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
return t.Config
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
url, ok := params.AsMap()[pageURLKey].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", pageURLKey)
|
||||
return nil, util.NewAgentError(fmt.Sprintf("invalid or missing '%s' parameter; expected a string", pageURLKey), nil)
|
||||
}
|
||||
|
||||
var tokenStr string
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err = accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
return nil, util.NewClientServerError("error parsing access token", http.StatusUnauthorized, err)
|
||||
}
|
||||
}
|
||||
return source.FHIRFetchPage(ctx, url, tokenStr)
|
||||
|
||||
resp, err := source.FHIRFetchPage(ctx, url, tokenStr)
|
||||
if err != nil {
|
||||
|
||||
return nil, util.ProcessGcpError(err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -17,6 +17,7 @@ package fhirpatienteverything
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
@@ -24,6 +25,7 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/googleapi"
|
||||
)
|
||||
@@ -116,26 +118,27 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
return t.Config
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// ValidateAndFetchStoreID usually returns input validation errors
|
||||
return nil, util.NewAgentError("failed to validate store ID", err)
|
||||
}
|
||||
patientID, ok := params.AsMap()[patientIDKey].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", patientIDKey)
|
||||
return nil, util.NewAgentError(fmt.Sprintf("invalid or missing '%s' parameter; expected a string", patientIDKey), nil)
|
||||
}
|
||||
|
||||
var tokenStr string
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err = accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
return nil, util.NewClientServerError("error parsing access token", http.StatusUnauthorized, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -143,11 +146,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if val, ok := params.AsMap()[typeFilterKey]; ok {
|
||||
types, ok := val.([]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid '%s' parameter; expected a string array", typeFilterKey)
|
||||
return nil, util.NewAgentError(fmt.Sprintf("invalid '%s' parameter; expected a string array", typeFilterKey), nil)
|
||||
}
|
||||
typeFilterSlice, err := parameters.ConvertAnySliceToTyped(types, "string")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("can't convert '%s' to array of strings: %s", typeFilterKey, err)
|
||||
return nil, util.NewAgentError(fmt.Sprintf("can't convert '%s' to array of strings: %s", typeFilterKey, err), err)
|
||||
}
|
||||
if len(typeFilterSlice.([]string)) != 0 {
|
||||
opts = append(opts, googleapi.QueryParameter("_type", strings.Join(typeFilterSlice.([]string), ",")))
|
||||
@@ -156,13 +159,18 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if since, ok := params.AsMap()[sinceFilterKey]; ok {
|
||||
sinceStr, ok := since.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid '%s' parameter; expected a string", sinceFilterKey)
|
||||
return nil, util.NewAgentError(fmt.Sprintf("invalid '%s' parameter; expected a string", sinceFilterKey), nil)
|
||||
}
|
||||
if sinceStr != "" {
|
||||
opts = append(opts, googleapi.QueryParameter("_since", sinceStr))
|
||||
}
|
||||
}
|
||||
return source.FHIRPatientEverything(storeID, patientID, tokenStr, opts)
|
||||
|
||||
resp, err := source.FHIRPatientEverything(storeID, patientID, tokenStr, opts)
|
||||
if err != nil {
|
||||
return nil, util.ProcessGcpError(err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user