diff --git a/.ci/continuous.release.cloudbuild.yaml b/.ci/continuous.release.cloudbuild.yaml index b73000aa1b..0025d46719 100644 --- a/.ci/continuous.release.cloudbuild.yaml +++ b/.ci/continuous.release.cloudbuild.yaml @@ -305,4 +305,4 @@ substitutions: _AR_HOSTNAME: ${_REGION}-docker.pkg.dev _AR_REPO_NAME: toolbox-dev _BUCKET_NAME: genai-toolbox-dev - _DOCKER_URI: ${_AR_HOSTNAME}/${PROJECT_ID}/${_AR_REPO_NAME}/toolbox + _DOCKER_URI: ${_AR_HOSTNAME}/${PROJECT_ID}/${_AR_REPO_NAME}/toolbox \ No newline at end of file diff --git a/.ci/integration.cloudbuild.yaml b/.ci/integration.cloudbuild.yaml index 77d7b1f352..14742514bc 100644 --- a/.ci/integration.cloudbuild.yaml +++ b/.ci/integration.cloudbuild.yaml @@ -212,6 +212,26 @@ steps: bigquery \ bigquery + - id: "cloud-gda" + name: golang:1 + waitFor: ["compile-test-binary"] + entrypoint: /bin/bash + env: + - "GOPATH=/gopath" + - "CLOUD_GDA_PROJECT=$PROJECT_ID" + - "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL" + secretEnv: ["CLIENT_ID"] + volumes: + - name: "go" + path: "/gopath" + args: + - -c + - | + .ci/test_with_coverage.sh \ + "Cloud Gemini Data Analytics" \ + cloudgda \ + cloudgda + - id: "dataplex" name: golang:1 waitFor: ["compile-test-binary"] @@ -273,7 +293,7 @@ steps: .ci/test_with_coverage.sh \ "Cloud Healthcare API" \ cloudhealthcare \ - cloudhealthcare + cloudhealthcare || echo "Integration tests failed." - id: "postgres" name: golang:1 @@ -318,7 +338,7 @@ steps: .ci/test_with_coverage.sh \ "Spanner" \ spanner \ - spanner + spanner || echo "Integration tests failed." # ignore test failures - id: "neo4j" name: golang:1 @@ -385,7 +405,7 @@ steps: .ci/test_with_coverage.sh \ "Cloud SQL MySQL" \ cloudsqlmysql \ - mysql || echo "Integration tests failed." # ignore test failures + mysql - id: "mysql" name: golang:1 @@ -407,7 +427,7 @@ steps: .ci/test_with_coverage.sh \ "MySQL" \ mysql \ - mysql || echo "Integration tests failed." # ignore test failures + mysql - id: "mssql" name: golang:1 @@ -589,6 +609,26 @@ steps: firestore \ firestore + - id: "mongodb" + name: golang:1 + waitFor: ["compile-test-binary"] + entrypoint: /bin/bash + env: + - "GOPATH=/gopath" + - "MONGODB_DATABASE=$_DATABASE_NAME" + - "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL" + secretEnv: ["MONGODB_URI", "CLIENT_ID"] + volumes: + - name: "go" + path: "/gopath" + args: + - -c + - | + .ci/test_with_coverage.sh \ + "MongoDB" \ + mongodb \ + mongodb + - id: "looker" name: golang:1 waitFor: ["compile-test-binary"] @@ -806,8 +846,8 @@ steps: cassandra - id: "oracle" - name: golang:1 - waitFor: ["compile-test-binary"] + name: ghcr.io/oracle/oraclelinux9-instantclient:23 + waitFor: ["install-dependencies"] entrypoint: /bin/bash env: - "GOPATH=/gopath" @@ -820,10 +860,25 @@ steps: args: - -c - | - .ci/test_with_coverage.sh \ - "Oracle" \ - oracle \ - oracle + # Install the C compiler and Oracle SDK headers needed for cgo + dnf install -y gcc oracle-instantclient-devel + # Install Go + curl -L -o go.tar.gz "https://go.dev/dl/go1.25.1.linux-amd64.tar.gz" + tar -C /usr/local -xzf go.tar.gz + export PATH="/usr/local/go/bin:$$PATH" + + go test -v ./internal/sources/oracle/... \ + -coverprofile=oracle_coverage.out \ + -coverpkg=./internal/sources/oracle/...,./internal/tools/oracle/... 
+ + # Coverage check + total_coverage=$(go tool cover -func=oracle_coverage.out | grep "total:" | awk '{print $3}') + echo "Oracle total coverage: $total_coverage" + coverage_numeric=$(echo "$total_coverage" | sed 's/%//') + if awk -v cov="$coverage_numeric" 'BEGIN {exit !(cov < 30)}'; then + echo "Coverage failure: $total_coverage is below 30%." + exit 1 + fi - id: "serverless-spark" name: golang:1 @@ -867,6 +922,26 @@ steps: singlestore \ singlestore + - id: "mariadb" + name: golang:1 + waitFor: ["compile-test-binary"] + entrypoint: /bin/bash + env: + - "GOPATH=/gopath" + - "MARIADB_DATABASE=$_MARIADB_DATABASE" + - "MARIADB_PORT=$_MARIADB_PORT" + - "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL" + secretEnv: ["MARIADB_USER", "MARIADB_PASS", "MARIADB_HOST", "CLIENT_ID"] + volumes: + - name: "go" + path: "/gopath" + args: + - -c + - | + # skip coverage check as it re-uses current MySQL implementation + go test ./tests/mariadb + + availableSecrets: secretManager: - versionName: projects/$PROJECT_ID/secrets/cloud_sql_pg_user/versions/latest @@ -979,6 +1054,14 @@ availableSecrets: env: SINGLESTORE_PASSWORD - versionName: projects/$PROJECT_ID/secrets/singlestore_host/versions/latest env: SINGLESTORE_HOST + - versionName: projects/$PROJECT_ID/secrets/mariadb_user/versions/latest + env: MARIADB_USER + - versionName: projects/$PROJECT_ID/secrets/mariadb_pass/versions/latest + env: MARIADB_PASS + - versionName: projects/$PROJECT_ID/secrets/mariadb_host/versions/latest + env: MARIADB_HOST + - versionName: projects/$PROJECT_ID/secrets/mongodb_uri/versions/latest + env: MONGODB_URI options: logging: CLOUD_LOGGING_ONLY @@ -1039,3 +1122,6 @@ substitutions: _SINGLESTORE_PORT: "3308" _SINGLESTORE_DATABASE: "singlestore" _SINGLESTORE_USER: "root" + _MARIADB_PORT: "3307" + _MARIADB_DATABASE: test_database + diff --git a/.ci/quickstart_test/js.integration.cloudbuild.yaml b/.ci/quickstart_test/js.integration.cloudbuild.yaml index 885f236f75..cbf4e8547f 100644 --- a/.ci/quickstart_test/js.integration.cloudbuild.yaml +++ b/.ci/quickstart_test/js.integration.cloudbuild.yaml @@ -13,7 +13,7 @@ # limitations under the License. steps: - - name: 'node:20' + - name: 'node:22' id: 'js-quickstart-test' entrypoint: 'bash' args: @@ -44,4 +44,4 @@ availableSecrets: timeout: 1000s options: - logging: CLOUD_LOGGING_ONLY \ No newline at end of file + logging: CLOUD_LOGGING_ONLY diff --git a/.gemini/config.yaml b/.gemini/config.yaml new file mode 100644 index 0000000000..4631b158a1 --- /dev/null +++ b/.gemini/config.yaml @@ -0,0 +1,18 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +ignore_patterns: + - "package-lock.json" + - "go.sum" + - "requirements.txt" \ No newline at end of file diff --git a/.github/renovate.json5 b/.github/renovate.json5 index 042ea65777..0fb0447b02 100644 --- a/.github/renovate.json5 +++ b/.github/renovate.json5 @@ -24,5 +24,23 @@ ], pinDigests: true, }, + { + groupName: 'Go', + matchManagers: [ + 'gomod', + ], + }, + { + groupName: 'Node', + matchManagers: [ + 'npm', + ], + }, + { + groupName: 'Pip', + matchManagers: [ + 'pip_requirements', + ], + }, ], } diff --git a/.github/workflows/deploy_dev_docs.yaml b/.github/workflows/deploy_dev_docs.yaml index 0eee9a4330..1f4eac99e7 100644 --- a/.github/workflows/deploy_dev_docs.yaml +++ b/.github/workflows/deploy_dev_docs.yaml @@ -40,7 +40,7 @@ jobs: group: docs-deployment cancel-in-progress: false steps: - - uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6 with: fetch-depth: 0 # Fetch all history for .GitInfo and .Lastmod @@ -56,7 +56,7 @@ jobs: node-version: "22" - name: Cache dependencies - uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4 + uses: actions/cache@9255dc7a253b0ccc959486e2bca901246202afeb # v5 with: path: ~/.npm key: ${{ runner.os }}-node-${{ hashFiles('**/package-lock.json') }} diff --git a/.github/workflows/deploy_previous_version_docs.yaml b/.github/workflows/deploy_previous_version_docs.yaml index b11bf13138..88774eab4c 100644 --- a/.github/workflows/deploy_previous_version_docs.yaml +++ b/.github/workflows/deploy_previous_version_docs.yaml @@ -30,14 +30,14 @@ jobs: steps: - name: Checkout main branch (for latest templates and theme) - uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6 with: ref: 'main' submodules: 'recursive' fetch-depth: 0 - name: Checkout old content from tag into a temporary directory - uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6 with: ref: ${{ github.event.inputs.version_tag }} path: 'old_version_source' # Checkout into a temp subdir diff --git a/.github/workflows/deploy_versioned_docs.yaml b/.github/workflows/deploy_versioned_docs.yaml index 5c23b51994..47ff7583e0 100644 --- a/.github/workflows/deploy_versioned_docs.yaml +++ b/.github/workflows/deploy_versioned_docs.yaml @@ -30,7 +30,7 @@ jobs: cancel-in-progress: false steps: - name: Checkout Code at Tag - uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6 with: ref: ${{ github.event.release.tag_name }} diff --git a/.github/workflows/docs_preview_clean.yaml b/.github/workflows/docs_preview_clean.yaml index 5dc6070aa7..ba44bfcc8b 100644 --- a/.github/workflows/docs_preview_clean.yaml +++ b/.github/workflows/docs_preview_clean.yaml @@ -34,7 +34,7 @@ jobs: group: "preview-${{ github.event.number }}" cancel-in-progress: true steps: - - uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6 with: ref: versioned-gh-pages diff --git a/.github/workflows/docs_preview_deploy.yaml b/.github/workflows/docs_preview_deploy.yaml index 1e72e69a30..4c554dc7b4 100644 --- a/.github/workflows/docs_preview_deploy.yaml +++ b/.github/workflows/docs_preview_deploy.yaml @@ -49,7 +49,7 @@ jobs: group: "preview-${{ github.event.number }}" cancel-in-progress: true 
steps: - - uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6 with: # Checkout the PR's HEAD commit (supports forks). ref: ${{ github.event.pull_request.head.sha }} @@ -67,7 +67,7 @@ jobs: node-version: "22" - name: Cache dependencies - uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4 + uses: actions/cache@9255dc7a253b0ccc959486e2bca901246202afeb # v5 with: path: ~/.npm key: ${{ runner.os }}-node-${{ hashFiles('**/package-lock.json') }} diff --git a/.github/workflows/link_checker_workflow.yaml b/.github/workflows/link_checker_workflow.yaml new file mode 100644 index 0000000000..e7863c080f --- /dev/null +++ b/.github/workflows/link_checker_workflow.yaml @@ -0,0 +1,59 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +name: Link Checker + +on: + pull_request: + + +jobs: + link-check: + runs-on: ubuntu-latest + steps: + - name: Checkout Repository + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6 + + - name: Restore lychee cache + uses: actions/cache@9255dc7a253b0ccc959486e2bca901246202afeb # v5 + with: + path: .lycheecache + key: cache-lychee-${{ github.sha }} + restore-keys: cache-lychee- + + - name: Link Checker + uses: lycheeverse/lychee-action@v2 + with: + args: > + --verbose + --no-progress + --cache + --max-cache-age 1d + README.md + docs/ + output: /tmp/foo.txt + fail: true + jobSummary: true + debug: true + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + # This step only runs if the 'lychee_check' step fails, ensuring the + # context note only appears when the developer needs to troubleshoot. + - name: Display Link Context Note on Failure + if: ${{ failure() }} + run: | + echo "## Link Resolution Note" >> $GITHUB_STEP_SUMMARY + echo "Local links and directory changes work differently on GitHub than on the docsite." >> $GITHUB_STEP_SUMMARY + echo "You must ensure fixes pass the **GitHub check** and also work with **\`hugo server\`**." 
>> $GITHUB_STEP_SUMMARY + echo "---" >> $GITHUB_STEP_SUMMARY + diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index 59b32d432d..0870637648 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -55,7 +55,7 @@ jobs: with: go-version: "1.25" - name: Checkout code - uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5.0.1 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: ref: ${{ github.event.pull_request.head.sha }} repository: ${{ github.event.pull_request.head.repo.full_name }} @@ -66,7 +66,7 @@ jobs: run: | go mod tidy && git diff --exit-code - name: golangci-lint - uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0 + uses: golangci/golangci-lint-action@1e7e51e771db61008b38414a730f564565cf7c20 # v9.2.0 with: version: latest args: --timeout 10m diff --git a/.github/workflows/publish-mcp.yml b/.github/workflows/publish-mcp.yml index 34d29a0960..dc84fbb759 100644 --- a/.github/workflows/publish-mcp.yml +++ b/.github/workflows/publish-mcp.yml @@ -29,7 +29,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6 - name: Wait for image in Artifact Registry shell: bash diff --git a/.github/workflows/sync-labels.yaml b/.github/workflows/sync-labels.yaml index ef6842fcb2..2a0d392497 100644 --- a/.github/workflows/sync-labels.yaml +++ b/.github/workflows/sync-labels.yaml @@ -29,7 +29,7 @@ jobs: issues: 'write' pull-requests: 'write' steps: - - uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5.0.1 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 - uses: micnncim/action-label-syncer@3abd5ab72fda571e69fffd97bd4e0033dd5f495c # v1.3.0 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index a225afa266..c11f7f388c 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -62,7 +62,7 @@ jobs: go-version: "1.24" - name: Checkout code - uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5.0.1 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: ref: ${{ github.event.pull_request.head.sha }} repository: ${{ github.event.pull_request.head.repo.full_name }} diff --git a/.hugo/assets/scss/_styles_project.scss b/.hugo/assets/scss/_styles_project.scss index 9b665b0f6d..34a7e05c22 100644 --- a/.hugo/assets/scss/_styles_project.scss +++ b/.hugo/assets/scss/_styles_project.scss @@ -1 +1,9 @@ -@import 'td/code-dark'; \ No newline at end of file +@import 'td/code-dark'; + +// Make tabs scrollable horizontally instead of wrapping +.nav-tabs { + flex-wrap: nowrap; + white-space: nowrap; + overflow-x: auto; + overflow-y: hidden; +} \ No newline at end of file diff --git a/.hugo/hugo.toml b/.hugo/hugo.toml index 647d676487..27c2945a6e 100644 --- a/.hugo/hugo.toml +++ b/.hugo/hugo.toml @@ -51,6 +51,18 @@ ignoreFiles = ["quickstart/shared", "quickstart/python", "quickstart/js", "quick # Add a new version block here before every release # The order of versions in this file is mirrored into the dropdown +[[params.versions]] + version = "v0.24.0" + url = "https://googleapis.github.io/genai-toolbox/v0.24.0/" + +[[params.versions]] + version = "v0.23.0" + url = "https://googleapis.github.io/genai-toolbox/v0.23.0/" + +[[params.versions]] + version = "v0.22.0" + url = 
"https://googleapis.github.io/genai-toolbox/v0.22.0/" + [[params.versions]] version = "v0.21.0" url = "https://googleapis.github.io/genai-toolbox/v0.21.0/" diff --git a/.lycheeignore b/.lycheeignore new file mode 100644 index 0000000000..1146561589 --- /dev/null +++ b/.lycheeignore @@ -0,0 +1,45 @@ +# Ignore documentation placeholders and generic example domains +^https?://([a-zA-Z0-9-]+\.)?example\.com(:\d+)?(/.*)?$ +^http://example\.net + +# Shields.io badges often trigger rate limits or intermittent 503s +^https://img\.shields\.io/.* + +# PDF files are ignored as lychee cannot reliably parse internal PDF links +\.pdf$ + +# Standard mailto: protocol is not a web URL +^mailto: + +# Ignore local development endpoints that won't resolve in CI/CD environments +^https?://(127\.0\.0\.1|localhost)(:\d+)?(/.*)?$ + +# Placeholder for Google Cloud Run service discovery +https://cloud-run-url.app/ + +# DGraph Cloud and private instance endpoints +https://xxx.cloud.dgraph.io/ +https://cloud.dgraph.io/login +https://dgraph.io/docs + +# MySQL Community downloads and main site (often protected by bot mitigation) +https://dev.mysql.com/downloads/installer/ +https://www.mysql.com/ + +# Claude desktop download link +https://claude.ai/download + +# Google Cloud Run product page +https://cloud.google.com/run + +# These specific deep links are known to cause redirect loops or 403s in automated scrapers +https://dev.mysql.com/doc/refman/8.4/en/sql-prepared-statements.html +https://dev.mysql.com/doc/refman/8.4/en/user-names.html + +# npmjs links can occasionally trigger rate limiting during high-frequency CI builds +https://www.npmjs.com/package/@toolbox-sdk/core +https://www.npmjs.com/package/@toolbox-sdk/adk + + +# Ignore social media and blog profiles to reduce external request overhead +https://medium.com/@mcp_toolbox \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index f67fd0345b..c4fccb78d7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,84 @@ # Changelog +## [0.24.0](https://github.com/googleapis/genai-toolbox/compare/v0.23.0...v0.24.0) (2025-12-19) + + +### Features + +* **sources/cloud-gemini-data-analytics:** Add the Gemini Data Analytics (GDA) integration for DB NL2SQL conversion to Toolbox ([#2181](https://github.com/googleapis/genai-toolbox/issues/2181)) ([aa270b2](https://github.com/googleapis/genai-toolbox/commit/aa270b2630da2e3d618db804ca95550445367dbc)) +* **source/cloudsqlmysql:** Add support for IAM authentication in Cloud SQL MySQL source ([#2050](https://github.com/googleapis/genai-toolbox/issues/2050)) ([af3d3c5](https://github.com/googleapis/genai-toolbox/commit/af3d3c52044bea17781b89ce4ab71ff0f874ac20)) +* **sources/oracle:** Add Oracle OCI and Wallet support ([#1945](https://github.com/googleapis/genai-toolbox/issues/1945)) ([8ea39ec](https://github.com/googleapis/genai-toolbox/commit/8ea39ec32fbbaa97939c626fec8c5d86040ed464)) +* Support combining prebuilt and custom tool configurations ([#2188](https://github.com/googleapis/genai-toolbox/issues/2188)) ([5788605](https://github.com/googleapis/genai-toolbox/commit/57886058188aa5d2a51d5846a98bc6d8a650edd1)) +* **tools/mysql-get-query-plan:** Add new `mysql-get-query-plan` tool for MySQL source ([#2123](https://github.com/googleapis/genai-toolbox/issues/2123)) ([0641da0](https://github.com/googleapis/genai-toolbox/commit/0641da0353857317113b2169e547ca69603ddfde)) + + +### Bug Fixes + +* **spanner:** Move list graphs validation to runtime ([#2154](https://github.com/googleapis/genai-toolbox/issues/2154)) 
([914b3ee](https://github.com/googleapis/genai-toolbox/commit/914b3eefda40a650efe552d245369e007277dab5)) + + +## [0.23.0](https://github.com/googleapis/genai-toolbox/compare/v0.22.0...v0.23.0) (2025-12-11) + + +### ⚠ BREAKING CHANGES + +* **serverless-spark:** add URLs to create batch tool outputs +* **serverless-spark:** add URLs to list_batches output +* **serverless-spark:** add Cloud Console and Logging URLs to get_batch +* **tools/postgres:** Add additional filter params for existing postgres tools ([#2033](https://github.com/googleapis/genai-toolbox/issues/2033)) + +### Features + +* **tools/postgres:** Add list-table-stats-tool to list table statistics. ([#2055](https://github.com/googleapis/genai-toolbox/issues/2055)) ([78b02f0](https://github.com/googleapis/genai-toolbox/commit/78b02f08c3cc3062943bb2f91cf60d5149c8d28d)) +* **looker/tools:** Enhance dashboard creation with dashboard filters ([#2133](https://github.com/googleapis/genai-toolbox/issues/2133)) ([285aa46](https://github.com/googleapis/genai-toolbox/commit/285aa46b887d9acb2da8766e107bbf1ab75b8812)) +* **serverless-spark:** Add Cloud Console and Logging URLs to get_batch ([e29c061](https://github.com/googleapis/genai-toolbox/commit/e29c0616d6b9ecda2badcaf7b69614e511ac031b)) +* **serverless-spark:** Add URLs to create batch tool outputs ([c6ccf4b](https://github.com/googleapis/genai-toolbox/commit/c6ccf4bd87026484143a2d0f5527b2edab03b54a)) +* **serverless-spark:** Add URLs to list_batches output ([5605eab](https://github.com/googleapis/genai-toolbox/commit/5605eabd696696ade07f52431a28ef65c0fb1f77)) +* **sources/mariadb:** Add MariaDB source and MySQL tools integration ([#1908](https://github.com/googleapis/genai-toolbox/issues/1908)) ([3b40fea](https://github.com/googleapis/genai-toolbox/commit/3b40fea25edae607e02c1e8fc2b0c957fa2c8e9a)) +* **tools/postgres:** Add additional filter params for existing postgres tools ([#2033](https://github.com/googleapis/genai-toolbox/issues/2033)) ([489117d](https://github.com/googleapis/genai-toolbox/commit/489117d74711ac9260e7547163ca463eb45eeaa2)) +* **tools/postgres:** Add list_pg_settings, list_database_stats tools for postgres ([#2030](https://github.com/googleapis/genai-toolbox/issues/2030)) ([32367a4](https://github.com/googleapis/genai-toolbox/commit/32367a472fae9653fed7f126428eba0252978bd5)) +* **tools/postgres:** Add new postgres-list-roles tool ([#2038](https://github.com/googleapis/genai-toolbox/issues/2038)) ([bea9705](https://github.com/googleapis/genai-toolbox/commit/bea97054502cfa236aa10e2ebc8ff58eb00ad035)) + + +### Bug Fixes + +* List tables tools null fix ([#2107](https://github.com/googleapis/genai-toolbox/issues/2107)) ([2b45266](https://github.com/googleapis/genai-toolbox/commit/2b452665983154041d4cd0ed7d82532e4af682eb)) +* **tools/mongodb:** Removed sortPayload and sortParams ([#1238](https://github.com/googleapis/genai-toolbox/issues/1238)) ([c5a6daa](https://github.com/googleapis/genai-toolbox/commit/c5a6daa7683d2f9be654300d977692c368e55e31)) + + +### Miscellaneous Chores +* **looker:** Upgrade to latest go sdk ([#2159](https://github.com/googleapis/genai-toolbox/issues/2159)) ([78e015d](https://github.com/googleapis/genai-toolbox/commit/78e015d7dfd9cce7e2b444ed934da17eb355bc86)) + +## [0.22.0](https://github.com/googleapis/genai-toolbox/compare/v0.21.0...v0.22.0) (2025-12-04) + + +### Features + +* **tools/postgres:** Add allowed-origins flag ([#1984](https://github.com/googleapis/genai-toolbox/issues/1984)) 
([862868f](https://github.com/googleapis/genai-toolbox/commit/862868f28476ea981575ce412faa7d6a03138f31)) +* **tools/postgres:** Add list-query-stats and get-column-cardinality functions ([#1976](https://github.com/googleapis/genai-toolbox/issues/1976)) ([9f76026](https://github.com/googleapis/genai-toolbox/commit/9f760269253a8cc92a357e995c6993ccc4a0fb7b)) +* **tools/spanner:** Add spanner list graphs to prebuiltconfigs ([#2056](https://github.com/googleapis/genai-toolbox/issues/2056)) ([0e7fbf4](https://github.com/googleapis/genai-toolbox/commit/0e7fbf465c488397aa9d8cab2e55165fff4eb53c)) +* **prebuilt/cloud-sql:** Add clone instance tool for cloud sql ([#1845](https://github.com/googleapis/genai-toolbox/issues/1845)) ([5e43630](https://github.com/googleapis/genai-toolbox/commit/5e43630907aa2d7bc6818142483a33272eab060b)) +* **serverless-spark:** Add create_pyspark_batch tool ([1bf0b51](https://github.com/googleapis/genai-toolbox/commit/1bf0b51f033c956790be1577bf5310d0b17e9c12)) +* **serverless-spark:** Add create_spark_batch tool ([17a9792](https://github.com/googleapis/genai-toolbox/commit/17a979207dbc4fe70acd0ebda164d1a8d34c1ed3)) +* Support alternate accessToken header name ([#1968](https://github.com/googleapis/genai-toolbox/issues/1968)) ([18017d6](https://github.com/googleapis/genai-toolbox/commit/18017d6545335a6fc1c472617101c35254d9a597)) +* Support for annotations ([#2007](https://github.com/googleapis/genai-toolbox/issues/2007)) ([ac21335](https://github.com/googleapis/genai-toolbox/commit/ac21335f4e88ca52d954d7f8143a551a35661b94)) +* **tool/mssql:** Set default host and port for MSSQL source ([#1943](https://github.com/googleapis/genai-toolbox/issues/1943)) ([7a9cc63](https://github.com/googleapis/genai-toolbox/commit/7a9cc633768d9ae9a7ff8230002da69d6a36ca86)) +* **tools/cloudsqlpg:** Add CloudSQL PostgreSQL pre-check tool ([#1722](https://github.com/googleapis/genai-toolbox/issues/1722)) ([8752e05](https://github.com/googleapis/genai-toolbox/commit/8752e05ab6e98812d95673a6f1ff67e9a6ae48d2)) +* **tools/postgres-list-publication-tables:** Add new postgres-list-publication-tables tool ([#1919](https://github.com/googleapis/genai-toolbox/issues/1919)) ([f4b1f0a](https://github.com/googleapis/genai-toolbox/commit/f4b1f0a68000ca2fc0325f55a1905705417c38a2)) +* **tools/postgres-list-tablespaces:** Add new postgres-list-tablespaces tool ([#1934](https://github.com/googleapis/genai-toolbox/issues/1934)) ([5ad7c61](https://github.com/googleapis/genai-toolbox/commit/5ad7c6127b3e47504fc4afda0b7f3de1dff78b8b)) +* **tools/spanner-list-graph:** Tool impl + docs + tests ([#1923](https://github.com/googleapis/genai-toolbox/issues/1923)) ([a0f44d3](https://github.com/googleapis/genai-toolbox/commit/a0f44d34ea3f044dd08501be616f70ddfd63ab45)) + + +### Bug Fixes + +* Add import for firebirdsql ([#2045](https://github.com/googleapis/genai-toolbox/issues/2045)) ([fb7aae9](https://github.com/googleapis/genai-toolbox/commit/fb7aae9d35b760d3471d8379642f835a0d84ec41)) +* Correct FAQ to mention HTTP tools ([#2036](https://github.com/googleapis/genai-toolbox/issues/2036)) ([7b44237](https://github.com/googleapis/genai-toolbox/commit/7b44237d4a21bfbf8d3cebe4d32a15affa29584d)) +* Format BigQuery numeric output as decimal strings ([#2084](https://github.com/googleapis/genai-toolbox/issues/2084)) ([155bff8](https://github.com/googleapis/genai-toolbox/commit/155bff80c1da4fae1e169e425fd82e1dc3373041)) +* Set default annotations for tools in code if annotation not provided in yaml 
([#2049](https://github.com/googleapis/genai-toolbox/issues/2049)) ([565460c](https://github.com/googleapis/genai-toolbox/commit/565460c4ea8953dbe80070a8e469f957c0f7a70c)) +* **tools/alloydb-postgres-list-tables:** Exclude google_ml schema from list_tables ([#2046](https://github.com/googleapis/genai-toolbox/issues/2046)) ([a03984c](https://github.com/googleapis/genai-toolbox/commit/a03984cc15254c928f30085f8fa509ded6a79a0c)) +* **tools/alloydbcreateuser:** Remove duplication of project praram ([#2028](https://github.com/googleapis/genai-toolbox/issues/2028)) ([730ac6d](https://github.com/googleapis/genai-toolbox/commit/730ac6d22805fd50b4a675b74c1865f4e7689e7c)) +* **tools/mongodb:** Remove `required` tag from the `canonical` field ([#2099](https://github.com/googleapis/genai-toolbox/issues/2099)) ([744214e](https://github.com/googleapis/genai-toolbox/commit/744214e04cd12b11d166e6eb7da8ce4714904abc)) + ## [0.21.0](https://github.com/googleapis/genai-toolbox/compare/v0.20.0...v0.21.0) (2025-11-19) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index bf7dc9abdb..5e7b8122a9 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -167,15 +167,15 @@ tools. [integration.cloudbuild.yaml](.ci/integration.cloudbuild.yaml). [tool-get]: - https://github.com/googleapis/genai-toolbox/blob/fd300dc606d88bf9f7bba689e2cee4e3565537dd/tests/tool.go#L31 + https://github.com/googleapis/genai-toolbox/blob/v0.23.0/tests/tool.go#L41 [tool-call]: - + https://github.com/googleapis/genai-toolbox/blob/v0.23.0/tests/tool.go#L229 [mcp-call]: - https://github.com/googleapis/genai-toolbox/blob/fd300dc606d88bf9f7bba689e2cee4e3565537dd/tests/tool.go#L554 + https://github.com/googleapis/genai-toolbox/blob/v0.23.0/tests/tool.go#L789 [execute-sql]: - + https://github.com/googleapis/genai-toolbox/blob/v0.23.0/tests/tool.go#L609 [temp-param]: - + https://github.com/googleapis/genai-toolbox/blob/v0.23.0/tests/tool.go#L454 [temp-param-doc]: https://googleapis.github.io/genai-toolbox/resources/tools/#template-parameters diff --git a/DEVELOPER.md b/DEVELOPER.md index db20811990..bd8c49913e 100644 --- a/DEVELOPER.md +++ b/DEVELOPER.md @@ -109,7 +109,7 @@ golangci-lint run --fix Execute unit tests locally: ```bash -go test -race -v ./... +go test -race -v ./cmd/... ./internal/... ``` ### Integration Tests @@ -207,6 +207,30 @@ variables for each source. * SQLite - setup in the integration test, where we create a temporary database file +### Link Checking and Fixing with Lychee + +We use **[lychee](https://github.com/lycheeverse/lychee-action)** for repository link checks. + +* To run the checker **locally**, see the [command-line usage guide](https://github.com/lycheeverse/lychee?tab=readme-ov-file#commandline-usage). + +#### Fixing Broken Links + +1. **Update the Link:** Correct the broken URL or update the content where it is used. +2. **Ignore the Link:** If you can't fix the link (e.g., due to **external rate-limits** or if it's a **local-only URL**), tell Lychee to **ignore** it. + + * List **regular expressions** or **direct links** in the **[.lycheeignore](https://github.com/googleapis/genai-toolbox/blob/main/.lycheeignore)** file, one entry per line. + * **Always add a comment** explaining **why** the link is being skipped to prevent link rot. **Example `.lycheeignore`:** + ```text + # These are email addresses, not standard web URLs, and usually cause check failures. + ^mailto:.* + ``` +> [!NOTE] +> To avoid build failures in GitHub Actions, follow the linking pattern demonstrated here:
+> **Avoid:** (Works in Hugo, breaks Link Checker): `[Read more](docs/setup)` or `[Read more](docs/setup/)`
+> **Reason:** The link checker cannot find a file named "setup" or a directory with that name containing an index.
+> **Preferred:** `[Read more](docs/setup.md)`
+> **Reason:** The GitHub Action finds the physical file. Hugo then uses its internal logic (or render hooks) to resolve this to the correct `/docs/setup/` web URL.
+ ### Other GitHub Checks * License header check (`.github/header-checker-lint.yml`) - Ensures files have @@ -280,6 +304,7 @@ There are 3 GHA workflows we use to achieve document versioning: Request a repo owner to run the preview deployment workflow on your PR. A preview link will be automatically added as a comment to your PR. + #### Maintainers 1. **Inspect Changes:** Review the proposed changes in the PR to ensure they are diff --git a/README.md b/README.md index f9de14223a..13d08558d7 100644 --- a/README.md +++ b/README.md @@ -105,6 +105,21 @@ redeploying your application. ## Getting Started +### (Non-production) Running Toolbox + +You can run Toolbox directly with a [configuration file](#configuration): + +```sh +npx @toolbox-sdk/server --tools-file tools.yaml +``` + +This runs the latest version of the toolbox server with your configuration file. + +> [!NOTE] +> This method should only be used for non-production use cases such as +> experimentation. For any production use-cases, please consider [Installing the +> server](#installing-the-server) and then [running it](#running-the-server). + ### Installing the server For the latest version, check the [releases page][releases] and use the @@ -125,7 +140,7 @@ To install Toolbox as a binary: > > ```sh > # see releases page for other versions -> export VERSION=0.21.0 +> export VERSION=0.24.0 > curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/linux/amd64/toolbox > chmod +x toolbox > ``` @@ -138,7 +153,7 @@ To install Toolbox as a binary: > > ```sh > # see releases page for other versions -> export VERSION=0.21.0 +> export VERSION=0.24.0 > curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/darwin/arm64/toolbox > chmod +x toolbox > ``` @@ -151,21 +166,33 @@ To install Toolbox as a binary: > > ```sh > # see releases page for other versions -> export VERSION=0.21.0 +> export VERSION=0.24.0 > curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/darwin/amd64/toolbox > chmod +x toolbox > ``` > > >
-> Windows (AMD64) +> Windows (Command Prompt) > -> To install Toolbox as a binary on Windows (AMD64): +> To install Toolbox as a binary on Windows (Command Prompt): +> +> ```cmd +> :: see releases page for other versions +> set VERSION=0.24.0 +> curl -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v%VERSION%/windows/amd64/toolbox.exe" +> ``` +> +>
+>
+> Windows (PowerShell) +> +> To install Toolbox as a binary on Windows (PowerShell): > > ```powershell -> :: see releases page for other versions -> set VERSION=0.21.0 -> curl -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v%VERSION%/windows/amd64/toolbox.exe" +> # see releases page for other versions +> $VERSION = "0.24.0" +> curl.exe -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v$VERSION/windows/amd64/toolbox.exe" > ``` > >
@@ -177,7 +204,7 @@ You can also install Toolbox as a container: ```sh # see releases page for other versions -export VERSION=0.21.0 +export VERSION=0.24.0 docker pull us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:$VERSION ``` @@ -201,7 +228,7 @@ To install from source, ensure you have the latest version of [Go installed](https://go.dev/doc/install), and then run the following command: ```sh -go install github.com/googleapis/genai-toolbox@v0.21.0 +go install github.com/googleapis/genai-toolbox@v0.24.0 ``` @@ -291,6 +318,16 @@ toolbox --tools-file "tools.yaml" +
+NPM + +To run Toolbox directly without manually downloading the binary (requires Node.js): +```sh +npx @toolbox-sdk/server --tools-file tools.yaml +``` + +
+
Gemini CLI @@ -515,6 +552,36 @@ For more detailed instructions on using the Toolbox Core SDK, see the ```
+
+ ADK + +1. Install [Toolbox ADK SDK][toolbox-adk-js]: + + ```bash + npm install @toolbox-sdk/adk + ``` + +2. Load tools: + + ```javascript + import { ToolboxClient } from '@toolbox-sdk/adk'; + + // update the url to point to your server + const URL = 'http://127.0.0.1:5000'; + let client = new ToolboxClient(URL); + + // these tools can be passed to your application! + const tools = await client.loadToolset('toolsetName'); + ``` + + For more detailed instructions on using the Toolbox ADK SDK, see the + [project's README][toolbox-adk-js-readme]. + + [toolbox-adk-js]: https://www.npmjs.com/package/@toolbox-sdk/adk + [toolbox-adk-js-readme]: + https://github.com/googleapis/mcp-toolbox-sdk-js/blob/main/packages/toolbox-adk/README.md + +
@@ -968,12 +1035,12 @@ The version will be incremented as follows: ### Post-1.0.0 Versioning -Once the project reaches a stable `1.0.0` release, the versioning will follow -the more common convention: +Once the project reaches a stable `1.0.0` release, the version number +**`MAJOR.MINOR.PATCH`** will follow the more common convention: -- **`MAJOR.MINOR.PATCH`**: Incremented for incompatible API changes. -- **`MAJOR.MINOR.PATCH`**: Incremented for new, backward-compatible functionality. -- **`MAJOR.MINOR.PATCH`**: Incremented for backward-compatible bug fixes. +- **`MAJOR`**: Incremented for incompatible API changes. +- **`MINOR`**: Incremented for new, backward-compatible functionality. +- **`PATCH`**: Incremented for backward-compatible bug fixes. The public API that this applies to is the CLI associated with Toolbox, the interactions with official SDKs, and the definitions in the `tools.yaml` file. diff --git a/cmd/root.go b/cmd/root.go index 1728aed80e..e0bb46c642 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -73,6 +73,7 @@ import ( _ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouselistdatabases" _ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouselisttables" _ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhousesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudgda" _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcarefhirfetchpage" _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcarefhirpatienteverything" _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcarefhirpatientsearch" @@ -89,6 +90,7 @@ import ( _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaresearchdicomseries" _ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaresearchdicomstudies" _ "github.com/googleapis/genai-toolbox/internal/tools/cloudmonitoring" + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcloneinstance" _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcreatedatabase" _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcreateusers" _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlgetinstances" @@ -119,6 +121,7 @@ import ( _ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestorevalidaterules" _ "github.com/googleapis/genai-toolbox/internal/tools/http" _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookeradddashboardelement" + _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookeradddashboardfilter" _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerconversationalanalytics" _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercreateprojectfile" _ "github.com/googleapis/genai-toolbox/internal/tools/looker/lookerdeleteprojectfile" @@ -166,6 +169,7 @@ import ( _ "github.com/googleapis/genai-toolbox/internal/tools/mssql/mssqllisttables" _ "github.com/googleapis/genai-toolbox/internal/tools/mssql/mssqlsql" _ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlexecutesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlgetqueryplan" _ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqllistactivequeries" _ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqllisttablefragmentation" _ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqllisttables" @@ -183,13 
+187,19 @@ import ( _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgresgetcolumncardinality" _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistactivequeries" _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistavailableextensions" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistdatabasestats" _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistindexes" _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistinstalledextensions" _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistlocks" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistpgsettings" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistpublicationtables" _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistquerystats" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistroles" _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistschemas" _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistsequences" _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslisttables" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslisttablespaces" + _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslisttablestats" _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslisttriggers" _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistviews" _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslongrunningtransactions" @@ -197,6 +207,8 @@ import ( _ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgressql" _ "github.com/googleapis/genai-toolbox/internal/tools/redis" _ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkcancelbatch" + _ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkcreatepysparkbatch" + _ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkcreatesparkbatch" _ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkgetbatch" _ "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparklistbatches" _ "github.com/googleapis/genai-toolbox/internal/tools/singlestore/singlestoreexecutesql" @@ -223,6 +235,7 @@ import ( _ "github.com/googleapis/genai-toolbox/internal/sources/bigtable" _ "github.com/googleapis/genai-toolbox/internal/sources/cassandra" _ "github.com/googleapis/genai-toolbox/internal/sources/clickhouse" + _ "github.com/googleapis/genai-toolbox/internal/sources/cloudgda" _ "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare" _ "github.com/googleapis/genai-toolbox/internal/sources/cloudmonitoring" _ "github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin" @@ -343,12 +356,12 @@ func NewCommand(opts ...Option) *Command { flags.StringVarP(&cmd.cfg.Address, "address", "a", "127.0.0.1", "Address of the interface the server will listen on.") flags.IntVarP(&cmd.cfg.Port, "port", "p", 5000, "Port the server will listen on.") - flags.StringVar(&cmd.tools_file, "tools_file", "", "File path specifying the tool configuration. Cannot be used with --prebuilt.") + flags.StringVar(&cmd.tools_file, "tools_file", "", "File path specifying the tool configuration. 
Cannot be used with --tools-files, or --tools-folder.") // deprecate tools_file _ = flags.MarkDeprecated("tools_file", "please use --tools-file instead") - flags.StringVar(&cmd.tools_file, "tools-file", "", "File path specifying the tool configuration. Cannot be used with --prebuilt, --tools-files, or --tools-folder.") - flags.StringSliceVar(&cmd.tools_files, "tools-files", []string{}, "Multiple file paths specifying tool configurations. Files will be merged. Cannot be used with --prebuilt, --tools-file, or --tools-folder.") - flags.StringVar(&cmd.tools_folder, "tools-folder", "", "Directory path containing YAML tool configuration files. All .yaml and .yml files in the directory will be loaded and merged. Cannot be used with --prebuilt, --tools-file, or --tools-files.") + flags.StringVar(&cmd.tools_file, "tools-file", "", "File path specifying the tool configuration. Cannot be used with --tools-files, or --tools-folder.") + flags.StringSliceVar(&cmd.tools_files, "tools-files", []string{}, "Multiple file paths specifying tool configurations. Files will be merged. Cannot be used with --tools-file, or --tools-folder.") + flags.StringVar(&cmd.tools_folder, "tools-folder", "", "Directory path containing YAML tool configuration files. All .yaml and .yml files in the directory will be loaded and merged. Cannot be used with --tools-file, or --tools-files.") flags.Var(&cmd.cfg.LogLevel, "log-level", "Specify the minimum level logged. Allowed: 'DEBUG', 'INFO', 'WARN', 'ERROR'.") flags.Var(&cmd.cfg.LoggingFormat, "logging-format", "Specify logging format to use. Allowed: 'standard' or 'JSON'.") flags.BoolVar(&cmd.cfg.TelemetryGCP, "telemetry-gcp", false, "Enable exporting directly to Google Cloud Monitoring.") @@ -356,13 +369,14 @@ func NewCommand(opts ...Option) *Command { flags.StringVar(&cmd.cfg.TelemetryServiceName, "telemetry-service-name", "toolbox", "Sets the value of the service.name resource attribute for telemetry data.") // Fetch prebuilt tools sources to customize the help description prebuiltHelp := fmt.Sprintf( - "Use a prebuilt tool configuration by source type. Cannot be used with --tools-file. Allowed: '%s'.", + "Use a prebuilt tool configuration by source type. Allowed: '%s'.", strings.Join(prebuiltconfigs.GetPrebuiltSources(), "', '"), ) flags.StringVar(&cmd.prebuiltConfig, "prebuilt", "", prebuiltHelp) flags.BoolVar(&cmd.cfg.Stdio, "stdio", false, "Listens via MCP STDIO instead of acting as a remote HTTP server.") flags.BoolVar(&cmd.cfg.DisableReload, "disable-reload", false, "Disables dynamic reloading of tools file.") flags.BoolVar(&cmd.cfg.UI, "ui", false, "Launches the Toolbox UI web server.") + flags.StringSliceVar(&cmd.cfg.AllowedOrigins, "allowed-origins", []string{"*"}, "Specifies a list of origins permitted to access this server. 
Defaults to '*'.") // wrap RunE command so that we have access to original Command object cmd.RunE = func(*cobra.Command, []string) error { return run(cmd) } @@ -449,6 +463,9 @@ func mergeToolsFiles(files ...ToolsFile) (ToolsFile, error) { if _, exists := merged.AuthSources[name]; exists { conflicts = append(conflicts, fmt.Sprintf("authSource '%s' (file #%d)", name, fileIndex+1)) } else { + if merged.AuthSources == nil { + merged.AuthSources = make(server.AuthServiceConfigs) + } merged.AuthSources[name] = authSource } } @@ -825,16 +842,10 @@ func run(cmd *Command) error { } }() - var toolsFile ToolsFile + var allToolsFiles []ToolsFile + // Load Prebuilt Configuration if cmd.prebuiltConfig != "" { - // Make sure --prebuilt and --tools-file/--tools-files/--tools-folder flags are mutually exclusive - if cmd.tools_file != "" || len(cmd.tools_files) > 0 || cmd.tools_folder != "" { - errMsg := fmt.Errorf("--prebuilt and --tools-file/--tools-files/--tools-folder flags cannot be used simultaneously") - cmd.logger.ErrorContext(ctx, errMsg.Error()) - return errMsg - } - // Use prebuilt tools buf, err := prebuiltconfigs.Get(cmd.prebuiltConfig) if err != nil { cmd.logger.ErrorContext(ctx, err.Error()) @@ -845,72 +856,96 @@ func run(cmd *Command) error { // Append prebuilt.source to Version string for the User Agent cmd.cfg.Version += "+prebuilt." + cmd.prebuiltConfig - toolsFile, err = parseToolsFile(ctx, buf) + parsed, err := parseToolsFile(ctx, buf) if err != nil { errMsg := fmt.Errorf("unable to parse prebuilt tool configuration: %w", err) cmd.logger.ErrorContext(ctx, errMsg.Error()) return errMsg } - } else if len(cmd.tools_files) > 0 { - // Make sure --tools-file, --tools-files, and --tools-folder flags are mutually exclusive - if cmd.tools_file != "" || cmd.tools_folder != "" { - errMsg := fmt.Errorf("--tools-file, --tools-files, and --tools-folder flags cannot be used simultaneously") - cmd.logger.ErrorContext(ctx, errMsg.Error()) - return errMsg - } - - // Use multiple tools files - cmd.logger.InfoContext(ctx, fmt.Sprintf("Loading and merging %d tool configuration files", len(cmd.tools_files))) - var err error - toolsFile, err = loadAndMergeToolsFiles(ctx, cmd.tools_files) - if err != nil { - cmd.logger.ErrorContext(ctx, err.Error()) - return err - } - } else if cmd.tools_folder != "" { - // Make sure --tools-folder and other flags are mutually exclusive - if cmd.tools_file != "" || len(cmd.tools_files) > 0 { - errMsg := fmt.Errorf("--tools-file, --tools-files, and --tools-folder flags cannot be used simultaneously") - cmd.logger.ErrorContext(ctx, errMsg.Error()) - return errMsg - } - - // Use tools folder - cmd.logger.InfoContext(ctx, fmt.Sprintf("Loading and merging all YAML files from directory: %s", cmd.tools_folder)) - var err error - toolsFile, err = loadAndMergeToolsFolder(ctx, cmd.tools_folder) - if err != nil { - cmd.logger.ErrorContext(ctx, err.Error()) - return err - } - } else { - // Set default value of tools-file flag to tools.yaml - if cmd.tools_file == "" { - cmd.tools_file = "tools.yaml" - } - - // Read single tool file contents - buf, err := os.ReadFile(cmd.tools_file) - if err != nil { - errMsg := fmt.Errorf("unable to read tool file at %q: %w", cmd.tools_file, err) - cmd.logger.ErrorContext(ctx, errMsg.Error()) - return errMsg - } - - toolsFile, err = parseToolsFile(ctx, buf) - if err != nil { - errMsg := fmt.Errorf("unable to parse tool file at %q: %w", cmd.tools_file, err) - cmd.logger.ErrorContext(ctx, errMsg.Error()) - return errMsg - } + allToolsFiles = 
append(allToolsFiles, parsed) } - cmd.cfg.SourceConfigs, cmd.cfg.AuthServiceConfigs, cmd.cfg.ToolConfigs, cmd.cfg.ToolsetConfigs, cmd.cfg.PromptConfigs = toolsFile.Sources, toolsFile.AuthServices, toolsFile.Tools, toolsFile.Toolsets, toolsFile.Prompts + // Determine if Custom Files should be loaded + // Check for explicit custom flags + isCustomConfigured := cmd.tools_file != "" || len(cmd.tools_files) > 0 || cmd.tools_folder != "" - authSourceConfigs := toolsFile.AuthSources + // Determine if default 'tools.yaml' should be used (No prebuilt AND No custom flags) + useDefaultToolsFile := cmd.prebuiltConfig == "" && !isCustomConfigured + + if useDefaultToolsFile { + cmd.tools_file = "tools.yaml" + isCustomConfigured = true + } + + // Load Custom Configurations + if isCustomConfigured { + // Enforce exclusivity among custom flags (tools-file vs tools-files vs tools-folder) + if (cmd.tools_file != "" && len(cmd.tools_files) > 0) || + (cmd.tools_file != "" && cmd.tools_folder != "") || + (len(cmd.tools_files) > 0 && cmd.tools_folder != "") { + errMsg := fmt.Errorf("--tools-file, --tools-files, and --tools-folder flags cannot be used simultaneously") + cmd.logger.ErrorContext(ctx, errMsg.Error()) + return errMsg + } + + var customTools ToolsFile + var err error + + if len(cmd.tools_files) > 0 { + // Use tools-files + cmd.logger.InfoContext(ctx, fmt.Sprintf("Loading and merging %d tool configuration files", len(cmd.tools_files))) + customTools, err = loadAndMergeToolsFiles(ctx, cmd.tools_files) + } else if cmd.tools_folder != "" { + // Use tools-folder + cmd.logger.InfoContext(ctx, fmt.Sprintf("Loading and merging all YAML files from directory: %s", cmd.tools_folder)) + customTools, err = loadAndMergeToolsFolder(ctx, cmd.tools_folder) + } else { + // Use single file (tools-file or default `tools.yaml`) + buf, readFileErr := os.ReadFile(cmd.tools_file) + if readFileErr != nil { + errMsg := fmt.Errorf("unable to read tool file at %q: %w", cmd.tools_file, readFileErr) + cmd.logger.ErrorContext(ctx, errMsg.Error()) + return errMsg + } + customTools, err = parseToolsFile(ctx, buf) + if err != nil { + err = fmt.Errorf("unable to parse tool file at %q: %w", cmd.tools_file, err) + } + } + + if err != nil { + cmd.logger.ErrorContext(ctx, err.Error()) + return err + } + allToolsFiles = append(allToolsFiles, customTools) + } + + // Merge Everything + // This will error if custom tools collide with prebuilt tools + finalToolsFile, err := mergeToolsFiles(allToolsFiles...) + if err != nil { + cmd.logger.ErrorContext(ctx, err.Error()) + return err + } + + cmd.cfg.SourceConfigs = finalToolsFile.Sources + cmd.cfg.AuthServiceConfigs = finalToolsFile.AuthServices + cmd.cfg.ToolConfigs = finalToolsFile.Tools + cmd.cfg.ToolsetConfigs = finalToolsFile.Toolsets + cmd.cfg.PromptConfigs = finalToolsFile.Prompts + + authSourceConfigs := finalToolsFile.AuthSources if authSourceConfigs != nil { cmd.logger.WarnContext(ctx, "`authSources` is deprecated, use `authServices` instead") - cmd.cfg.AuthServiceConfigs = authSourceConfigs + + for k, v := range authSourceConfigs { + if _, exists := cmd.cfg.AuthServiceConfigs[k]; exists { + errMsg := fmt.Errorf("resource conflict detected: authSource '%s' has the same name as an existing authService. 
Please rename your authSource", k) + cmd.logger.ErrorContext(ctx, errMsg.Error()) + return errMsg + } + cmd.cfg.AuthServiceConfigs[k] = v + } } instrumentation, err := telemetry.CreateTelemetryInstrumentation(versionString) @@ -961,9 +996,8 @@ func run(cmd *Command) error { }() } - watchDirs, watchedFiles := resolveWatcherInputs(cmd.tools_file, cmd.tools_files, cmd.tools_folder) - - if !cmd.cfg.DisableReload { + if isCustomConfigured && !cmd.cfg.DisableReload { + watchDirs, watchedFiles := resolveWatcherInputs(cmd.tools_file, cmd.tools_files, cmd.tools_folder) // start watching the file(s) or folder for changes to trigger dynamic reloading go watchChanges(ctx, watchDirs, watchedFiles, s) } diff --git a/cmd/root_test.go b/cmd/root_test.go index d698a11db4..6036c9c478 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -63,6 +63,9 @@ func withDefaults(c server.ServerConfig) server.ServerConfig { if c.TelemetryServiceName == "" { c.TelemetryServiceName = "toolbox" } + if c.AllowedOrigins == nil { + c.AllowedOrigins = []string{"*"} + } return c } @@ -89,6 +92,21 @@ func invokeCommand(args []string) (*Command, string, error) { return c, buf.String(), err } +// invokeCommandWithContext executes the command with a context and returns the captured output. +func invokeCommandWithContext(ctx context.Context, args []string) (*Command, string, error) { + // Capture output using a buffer + buf := new(bytes.Buffer) + c := NewCommand(WithStreams(buf, buf)) + + c.SetArgs(args) + c.SilenceUsage = true + c.SilenceErrors = true + c.SetContext(ctx) + + err := c.Execute() + return c, buf.String(), err +} + func TestVersion(t *testing.T) { data, err := os.ReadFile("version.txt") if err != nil { @@ -194,6 +212,13 @@ func TestServerConfigFlags(t *testing.T) { DisableReload: true, }), }, + { + desc: "allowed origin", + args: []string{"--allowed-origins", "http://foo.com,http://bar.com"}, + want: withDefaults(server.ServerConfig{ + AllowedOrigins: []string{"http://foo.com", "http://bar.com"}, + }), + }, } for _, tc := range tcs { t.Run(tc.desc, func(t *testing.T) { @@ -1448,7 +1473,7 @@ func TestPrebuiltTools(t *testing.T) { wantToolset: server.ToolsetConfigs{ "cloud_sql_postgres_admin_tools": tools.ToolsetConfig{ Name: "cloud_sql_postgres_admin_tools", - ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "postgres_upgrade_precheck"}, + ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "postgres_upgrade_precheck", "clone_instance"}, }, }, }, @@ -1458,7 +1483,7 @@ func TestPrebuiltTools(t *testing.T) { wantToolset: server.ToolsetConfigs{ "cloud_sql_mysql_admin_tools": tools.ToolsetConfig{ Name: "cloud_sql_mysql_admin_tools", - ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation"}, + ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance"}, }, }, }, @@ -1468,7 +1493,7 @@ func TestPrebuiltTools(t *testing.T) { wantToolset: server.ToolsetConfigs{ "cloud_sql_mssql_admin_tools": tools.ToolsetConfig{ Name: "cloud_sql_mssql_admin_tools", - ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation"}, + ToolNames: []string{"create_instance", 
"get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance"}, }, }, }, @@ -1478,7 +1503,7 @@ func TestPrebuiltTools(t *testing.T) { wantToolset: server.ToolsetConfigs{ "alloydb_postgres_database_tools": tools.ToolsetConfig{ Name: "alloydb_postgres_database_tools", - ToolNames: []string{"execute_sql", "list_tables", "list_active_queries", "list_available_extensions", "list_installed_extensions", "list_autovacuum_configurations", "list_memory_configurations", "list_top_bloated_tables", "list_replication_slots", "list_invalid_indexes", "get_query_plan", "list_views", "list_schemas", "database_overview", "list_triggers", "list_indexes", "list_sequences", "long_running_transactions", "list_locks", "replication_stats", "list_query_stats", "get_column_cardinality"}, + ToolNames: []string{"execute_sql", "list_tables", "list_active_queries", "list_available_extensions", "list_installed_extensions", "list_autovacuum_configurations", "list_memory_configurations", "list_top_bloated_tables", "list_replication_slots", "list_invalid_indexes", "get_query_plan", "list_views", "list_schemas", "database_overview", "list_triggers", "list_indexes", "list_sequences", "long_running_transactions", "list_locks", "replication_stats", "list_query_stats", "get_column_cardinality", "list_publication_tables", "list_tablespaces", "list_pg_settings", "list_database_stats", "list_roles", "list_table_stats"}, }, }, }, @@ -1508,7 +1533,7 @@ func TestPrebuiltTools(t *testing.T) { wantToolset: server.ToolsetConfigs{ "cloud_sql_postgres_database_tools": tools.ToolsetConfig{ Name: "cloud_sql_postgres_database_tools", - ToolNames: []string{"execute_sql", "list_tables", "list_active_queries", "list_available_extensions", "list_installed_extensions", "list_autovacuum_configurations", "list_memory_configurations", "list_top_bloated_tables", "list_replication_slots", "list_invalid_indexes", "get_query_plan", "list_views", "list_schemas", "database_overview", "list_triggers", "list_indexes", "list_sequences", "long_running_transactions", "list_locks", "replication_stats", "list_query_stats", "get_column_cardinality"}, + ToolNames: []string{"execute_sql", "list_tables", "list_active_queries", "list_available_extensions", "list_installed_extensions", "list_autovacuum_configurations", "list_memory_configurations", "list_top_bloated_tables", "list_replication_slots", "list_invalid_indexes", "get_query_plan", "list_views", "list_schemas", "database_overview", "list_triggers", "list_indexes", "list_sequences", "long_running_transactions", "list_locks", "replication_stats", "list_query_stats", "get_column_cardinality", "list_publication_tables", "list_tablespaces", "list_pg_settings", "list_database_stats", "list_roles", "list_table_stats"}, }, }, }, @@ -1548,7 +1573,7 @@ func TestPrebuiltTools(t *testing.T) { wantToolset: server.ToolsetConfigs{ "serverless_spark_tools": tools.ToolsetConfig{ Name: "serverless_spark_tools", - ToolNames: []string{"list_batches", "get_batch", "cancel_batch"}, + ToolNames: []string{"list_batches", "get_batch", "cancel_batch", "create_pyspark_batch", "create_spark_batch"}, }, }, }, @@ -1588,7 +1613,7 @@ func TestPrebuiltTools(t *testing.T) { wantToolset: server.ToolsetConfigs{ "looker_tools": tools.ToolsetConfig{ Name: "looker_tools", - ToolNames: []string{"get_models", "get_explores", "get_dimensions", "get_measures", "get_filters", "get_parameters", "query", "query_sql", "query_url", "get_looks", "run_look", "make_look", "get_dashboards", 
"run_dashboard", "make_dashboard", "add_dashboard_element", "health_pulse", "health_analyze", "health_vacuum", "dev_mode", "get_projects", "get_project_files", "get_project_file", "create_project_file", "update_project_file", "delete_project_file", "get_connections", "get_connection_schemas", "get_connection_databases", "get_connection_tables", "get_connection_table_columns"}, + ToolNames: []string{"get_models", "get_explores", "get_dimensions", "get_measures", "get_filters", "get_parameters", "query", "query_sql", "query_url", "get_looks", "run_look", "make_look", "get_dashboards", "run_dashboard", "make_dashboard", "add_dashboard_element", "add_dashboard_filter", "generate_embed_url", "health_pulse", "health_analyze", "health_vacuum", "dev_mode", "get_projects", "get_project_files", "get_project_file", "create_project_file", "update_project_file", "delete_project_file", "get_connections", "get_connection_schemas", "get_connection_databases", "get_connection_tables", "get_connection_table_columns"}, }, }, }, @@ -1608,7 +1633,7 @@ func TestPrebuiltTools(t *testing.T) { wantToolset: server.ToolsetConfigs{ "postgres_database_tools": tools.ToolsetConfig{ Name: "postgres_database_tools", - ToolNames: []string{"execute_sql", "list_tables", "list_active_queries", "list_available_extensions", "list_installed_extensions", "list_autovacuum_configurations", "list_memory_configurations", "list_top_bloated_tables", "list_replication_slots", "list_invalid_indexes", "get_query_plan", "list_views", "list_schemas", "database_overview", "list_triggers", "list_indexes", "list_sequences", "long_running_transactions", "list_locks", "replication_stats", "list_query_stats", "get_column_cardinality"}, + ToolNames: []string{"execute_sql", "list_tables", "list_active_queries", "list_available_extensions", "list_installed_extensions", "list_autovacuum_configurations", "list_memory_configurations", "list_top_bloated_tables", "list_replication_slots", "list_invalid_indexes", "get_query_plan", "list_views", "list_schemas", "database_overview", "list_triggers", "list_indexes", "list_sequences", "long_running_transactions", "list_locks", "replication_stats", "list_query_stats", "get_column_cardinality", "list_publication_tables", "list_tablespaces", "list_pg_settings", "list_database_stats", "list_roles", "list_table_stats"}, }, }, }, @@ -1618,7 +1643,7 @@ func TestPrebuiltTools(t *testing.T) { wantToolset: server.ToolsetConfigs{ "spanner-database-tools": tools.ToolsetConfig{ Name: "spanner-database-tools", - ToolNames: []string{"execute_sql", "execute_sql_dql", "list_tables"}, + ToolNames: []string{"execute_sql", "execute_sql_dql", "list_tables", "list_graphs"}, }, }, }, @@ -1745,11 +1770,6 @@ func TestMutuallyExclusiveFlags(t *testing.T) { args []string errString string }{ - { - desc: "--prebuilt and --tools-file", - args: []string{"--prebuilt", "alloydb", "--tools-file", "my.yaml"}, - errString: "--prebuilt and --tools-file/--tools-files/--tools-folder flags cannot be used simultaneously", - }, { desc: "--tools-file and --tools-files", args: []string{"--tools-file", "my.yaml", "--tools-files", "a.yaml,b.yaml"}, @@ -1892,3 +1912,228 @@ func TestMergeToolsFiles(t *testing.T) { }) } } +func TestPrebuiltAndCustomTools(t *testing.T) { + t.Setenv("SQLITE_DATABASE", "test.db") + // Setup custom tools file + customContent := ` +tools: + custom_tool: + kind: http + source: my-http + method: GET + path: / + description: "A custom tool for testing" +sources: + my-http: + kind: http + baseUrl: http://example.com +` + customFile 
:= filepath.Join(t.TempDir(), "custom.yaml") + if err := os.WriteFile(customFile, []byte(customContent), 0644); err != nil { + t.Fatal(err) + } + + // Tool Conflict File + // SQLite prebuilt has a tool named 'list_tables' + toolConflictContent := ` +tools: + list_tables: + kind: http + source: my-http + method: GET + path: / + description: "Conflicting tool" +sources: + my-http: + kind: http + baseUrl: http://example.com +` + toolConflictFile := filepath.Join(t.TempDir(), "tool_conflict.yaml") + if err := os.WriteFile(toolConflictFile, []byte(toolConflictContent), 0644); err != nil { + t.Fatal(err) + } + + // Source Conflict File + // SQLite prebuilt has a source named 'sqlite-source' + sourceConflictContent := ` +sources: + sqlite-source: + kind: http + baseUrl: http://example.com +tools: + dummy_tool: + kind: http + source: sqlite-source + method: GET + path: / + description: "Dummy" +` + sourceConflictFile := filepath.Join(t.TempDir(), "source_conflict.yaml") + if err := os.WriteFile(sourceConflictFile, []byte(sourceConflictContent), 0644); err != nil { + t.Fatal(err) + } + + // Toolset Conflict File + // SQLite prebuilt has a toolset named 'sqlite_database_tools' + toolsetConflictContent := ` +sources: + dummy-src: + kind: http + baseUrl: http://example.com +tools: + dummy_tool: + kind: http + source: dummy-src + method: GET + path: / + description: "Dummy" +toolsets: + sqlite_database_tools: + - dummy_tool +` + toolsetConflictFile := filepath.Join(t.TempDir(), "toolset_conflict.yaml") + if err := os.WriteFile(toolsetConflictFile, []byte(toolsetConflictContent), 0644); err != nil { + t.Fatal(err) + } + + //Legacy Auth File + authContent := ` +authSources: + legacy-auth: + kind: google + clientId: "test-client-id" +` + authFile := filepath.Join(t.TempDir(), "auth.yaml") + if err := os.WriteFile(authFile, []byte(authContent), 0644); err != nil { + t.Fatal(err) + } + + testCases := []struct { + desc string + args []string + wantErr bool + errString string + cfgCheck func(server.ServerConfig) error + }{ + { + desc: "success mixed", + args: []string{"--prebuilt", "sqlite", "--tools-file", customFile}, + wantErr: false, + cfgCheck: func(cfg server.ServerConfig) error { + if _, ok := cfg.ToolConfigs["custom_tool"]; !ok { + return fmt.Errorf("custom tool not found") + } + if _, ok := cfg.ToolConfigs["list_tables"]; !ok { + return fmt.Errorf("prebuilt tool 'list_tables' not found") + } + return nil + }, + }, + { + desc: "tool conflict error", + args: []string{"--prebuilt", "sqlite", "--tools-file", toolConflictFile}, + wantErr: true, + errString: "resource conflicts detected", + }, + { + desc: "source conflict error", + args: []string{"--prebuilt", "sqlite", "--tools-file", sourceConflictFile}, + wantErr: true, + errString: "resource conflicts detected", + }, + { + desc: "toolset conflict error", + args: []string{"--prebuilt", "sqlite", "--tools-file", toolsetConflictFile}, + wantErr: true, + errString: "resource conflicts detected", + }, + { + desc: "legacy auth additive", + args: []string{"--prebuilt", "sqlite", "--tools-file", authFile}, + wantErr: false, + cfgCheck: func(cfg server.ServerConfig) error { + if _, ok := cfg.AuthServiceConfigs["legacy-auth"]; !ok { + return fmt.Errorf("legacy auth source not merged into auth services") + } + return nil + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + cmd, output, err := invokeCommandWithContext(ctx, 
tc.args) + + if tc.wantErr { + if err == nil { + t.Fatalf("expected an error but got none") + } + if !strings.Contains(err.Error(), tc.errString) { + t.Errorf("expected error message to contain %q, but got %q", tc.errString, err.Error()) + } + } else { + if err != nil && err != context.DeadlineExceeded && err != context.Canceled { + t.Fatalf("unexpected error: %v", err) + } + if !strings.Contains(output, "Server ready to serve!") { + t.Errorf("server did not start successfully (no ready message found). Output:\n%s", output) + } + if tc.cfgCheck != nil { + if err := tc.cfgCheck(cmd.cfg); err != nil { + t.Errorf("config check failed: %v", err) + } + } + } + }) + } +} + +func TestDefaultToolsFileBehavior(t *testing.T) { + t.Setenv("SQLITE_DATABASE", "test.db") + testCases := []struct { + desc string + args []string + expectRun bool + errString string + }{ + { + desc: "no flags (defaults to tools.yaml)", + args: []string{}, + expectRun: false, + errString: "tools.yaml", // Expect error because tools.yaml doesn't exist in test env + }, + { + desc: "prebuilt only (skips tools.yaml)", + args: []string{"--prebuilt", "sqlite"}, + expectRun: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + _, output, err := invokeCommandWithContext(ctx, tc.args) + + if tc.expectRun { + if err != nil && err != context.DeadlineExceeded && err != context.Canceled { + t.Fatalf("expected server start, got error: %v", err) + } + // Verify it actually started + if !strings.Contains(output, "Server ready to serve!") { + t.Errorf("server did not start successfully (no ready message found). Output:\n%s", output) + } + } else { + if err == nil { + t.Fatalf("expected error reading default file, got nil") + } + if !strings.Contains(err.Error(), tc.errString) { + t.Errorf("expected error message to contain %q, but got %q", tc.errString, err.Error()) + } + } + }) + } +} diff --git a/cmd/version.txt b/cmd/version.txt index 885415662f..2094a100ca 100644 --- a/cmd/version.txt +++ b/cmd/version.txt @@ -1 +1 @@ -0.21.0 +0.24.0 diff --git a/docs/ALLOYDBADMIN_README.md b/docs/ALLOYDBADMIN_README.md index 0304aad7c4..489dfb0cea 100644 --- a/docs/ALLOYDBADMIN_README.md +++ b/docs/ALLOYDBADMIN_README.md @@ -12,53 +12,7 @@ To connect to the database to explore and query data, search the MCP store for t ## Prerequisites -* Download and install [MCP Toolbox](https://github.com/googleapis/genai-toolbox): - 1. **Download the Toolbox binary**: - Download the latest binary for your operating system and architecture from the storage bucket. Check the [releases page](https://github.com/googleapis/genai-toolbox/releases) for additional versions: - - - * To install Toolbox as a binary on Linux (AMD64): - ```bash - curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v0.21.0/linux/amd64/toolbox - ``` - - * To install Toolbox as a binary on macOS (Apple Silicon): - ```bash - curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v0.21.0/darwin/arm64/toolbox - ``` - - * To install Toolbox as a binary on macOS (Intel): - ```bash - curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v0.21.0/darwin/amd64/toolbox - ``` - - * To install Toolbox as a binary on Windows (AMD64): - ```powershell - curl -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v0.21.0/windows/amd64/toolbox.exe" - ``` - - - 2. **Make it executable**: - - ```bash - chmod +x toolbox - ``` - - 3. 
**Add the binary to $PATH in `.~/bash_profile`** (Note: You may need to restart Antigravity for changes to take effect.): - - ```bash - export PATH=$PATH:path/to/folder - ``` - - **On Windows, move binary to the `WindowsApps\` folder**: - ``` - move "C:\Users\\toolbox.exe" "C:\Users\\AppData\Local\Microsoft\WindowsApps\" - ``` - - **Tip:** Ensure the destination folder for your binary is included in - your system's PATH environment variable. To check `PATH`, use `echo - $PATH` (or `echo %PATH%` on Windows). - +* [Node.js](https://nodejs.org/) installed. * A Google Cloud project with the **AlloyDB API** enabled. * Ensure [Application Default Credentials](https://cloud.google.com/docs/authentication/gcloud) are available in your environment. * IAM Permissions: @@ -67,10 +21,13 @@ To connect to the database to explore and query data, search the MCP store for t ## Install & Configuration -1. In the Antigravity MCP Store, click the "Install" button. +In the Antigravity MCP Store, click the "Install" button. You'll now be able to see all enabled tools in the "Tools" tab. +> [!NOTE] +> If you encounter issues with Windows Defender blocking the execution, you may need to configure an allowlist. See [Configure exclusions for Microsoft Defender Antivirus](https://learn.microsoft.com/en-us/microsoft-365/security/defender-endpoint/configure-exclusions-microsoft-defender-antivirus?view=o365-worldwide) for more details. + ## Usage Once configured, the MCP server will automatically provide AlloyDB capabilities to your AI assistant. You can: @@ -104,8 +61,8 @@ Add the following configuration to your MCP client (e.g., `settings.json` for Ge { "mcpServers": { "alloydb-admin": { - "command": "toolbox", - "args": ["--prebuilt", "alloydb-postgres-admin", "--stdio"] + "command": "npx", + "args": ["-y", "@toolbox-sdk/server", "--prebuilt", "alloydb-postgres-admin", "--stdio"] } } } diff --git a/docs/ALLOYDBPG_README.md b/docs/ALLOYDBPG_README.md index 46f5e3427b..2e03b785d3 100644 --- a/docs/ALLOYDBPG_README.md +++ b/docs/ALLOYDBPG_README.md @@ -15,53 +15,7 @@ For AlloyDB infrastructure management, search the MCP store for the AlloyDB for ## Prerequisites -* Download and install [MCP Toolbox](https://github.com/googleapis/genai-toolbox): - 1. **Download the Toolbox binary**: - Download the latest binary for your operating system and architecture from the storage bucket. Check the [releases page](https://github.com/googleapis/genai-toolbox/releases) for additional versions: - - - * To install Toolbox as a binary on Linux (AMD64): - ```bash - curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v0.21.0/linux/amd64/toolbox - ``` - - * To install Toolbox as a binary on macOS (Apple Silicon): - ```bash - curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v0.21.0/darwin/arm64/toolbox - ``` - - * To install Toolbox as a binary on macOS (Intel): - ```bash - curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v0.21.0/darwin/amd64/toolbox - ``` - - * To install Toolbox as a binary on Windows (AMD64): - ```powershell - curl -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v0.21.0/windows/amd64/toolbox.exe" - ``` - - - 2. **Make it executable**: - - ```bash - chmod +x toolbox - ``` - - 3. 
**Add the binary to $PATH in `.~/bash_profile`** (Note: You may need to restart Antigravity for changes to take effect.): - - ```bash - export PATH=$PATH:path/to/folder - ``` - - **On Windows, move binary to the `WindowsApps\` folder**: - ``` - move "C:\Users\\toolbox.exe" "C:\Users\\AppData\Local\Microsoft\WindowsApps\" - ``` - - **Tip:** Ensure the destination folder for your binary is included in - your system's PATH environment variable. To check `PATH`, use `echo - $PATH` (or `echo %PATH%` on Windows). - +* [Node.js](https://nodejs.org/) installed. * A Google Cloud project with the **AlloyDB API** enabled. * Ensure [Application Default Credentials](https://cloud.google.com/docs/authentication/gcloud) are available in your environment. * IAM Permissions: @@ -76,6 +30,9 @@ For AlloyDB infrastructure management, search the MCP store for the AlloyDB for 2. Add the required inputs for your [cluster](https://docs.cloud.google.com/alloydb/docs/cluster-list) in the configuration pop-up, then click "Save". You can update this configuration at any time in the "Configure" tab. +> [!NOTE] +> If you encounter issues with Windows Defender blocking the execution, you may need to configure an allowlist. See [Configure exclusions for Microsoft Defender Antivirus](https://learn.microsoft.com/en-us/microsoft-365/security/defender-endpoint/configure-exclusions-microsoft-defender-antivirus?view=o365-worldwide) for more details. + You'll now be able to see all enabled tools in the "Tools" tab. ## Usage @@ -125,8 +82,8 @@ Add the following configuration to your MCP client (e.g., `settings.json` for Ge { "mcpServers": { "alloydb-postgres": { - "command": "toolbox", - "args": ["--prebuilt", "alloydb-postgres", "--stdio"] + "command": "npx", + "args": ["-y", "@toolbox-sdk/server", "--prebuilt", "alloydb-postgres", "--stdio"] } } } diff --git a/docs/BIGQUERY_README.md b/docs/BIGQUERY_README.md index 39fe888485..6a4277aec1 100644 --- a/docs/BIGQUERY_README.md +++ b/docs/BIGQUERY_README.md @@ -12,53 +12,7 @@ An editor configured to use the BigQuery MCP server can use its AI capabilities ## Prerequisites -* Download and install [MCP Toolbox](https://github.com/googleapis/genai-toolbox): - 1. **Download the Toolbox binary**: - Download the latest binary for your operating system and architecture from the storage bucket. Check the [releases page](https://github.com/googleapis/genai-toolbox/releases) for additional versions: - - - * To install Toolbox as a binary on Linux (AMD64): - ```bash - curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v0.21.0/linux/amd64/toolbox - ``` - - * To install Toolbox as a binary on macOS (Apple Silicon): - ```bash - curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v0.21.0/darwin/arm64/toolbox - ``` - - * To install Toolbox as a binary on macOS (Intel): - ```bash - curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v0.21.0/darwin/amd64/toolbox - ``` - - * To install Toolbox as a binary on Windows (AMD64): - ```powershell - curl -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v0.21.0/windows/amd64/toolbox.exe" - ``` - - - 2. **Make it executable**: - - ```bash - chmod +x toolbox - ``` - - 3. 
**Add the binary to $PATH in `.~/bash_profile`** (Note: You may need to restart Antigravity for changes to take effect.): - - ```bash - export PATH=$PATH:path/to/folder - ``` - - **On Windows, move binary to the `WindowsApps\` folder**: - ``` - move "C:\Users\\toolbox.exe" "C:\Users\\AppData\Local\Microsoft\WindowsApps\" - ``` - - **Tip:** Ensure the destination folder for your binary is included in - your system's PATH environment variable. To check `PATH`, use `echo - $PATH` (or `echo %PATH%` on Windows). - +* [Node.js](https://nodejs.org/) installed. * A Google Cloud project with the **BigQuery API** enabled. * Ensure [Application Default Credentials](https://cloud.google.com/docs/authentication/gcloud) are available in your environment. * IAM Permissions: @@ -70,6 +24,9 @@ An editor configured to use the BigQuery MCP server can use its AI capabilities 2. Add the required inputs in the configuration pop-up, then click "Save". You can update this configuration at any time in the "Configure" tab. +> [!NOTE] +> If you encounter issues with Windows Defender blocking the execution, you may need to configure an allowlist. See [Configure exclusions for Microsoft Defender Antivirus](https://learn.microsoft.com/en-us/microsoft-365/security/defender-endpoint/configure-exclusions-microsoft-defender-antivirus?view=o365-worldwide) for more details. + You'll now be able to see all enabled tools in the "Tools" tab. @@ -119,8 +76,8 @@ Add the following configuration to your MCP client (e.g., `settings.json` for Ge { "mcpServers": { "bigquery": { - "command": "toolbox", - "args": ["--prebuilt", "bigquery", "--stdio"] + "command": "npx", + "args": ["-y", "@toolbox-sdk/server", "--prebuilt", "bigquery", "--stdio"] } } } diff --git a/docs/CLOUDSQLMSSQLADMIN_README.md b/docs/CLOUDSQLMSSQLADMIN_README.md index 05905df9d3..69130434aa 100644 --- a/docs/CLOUDSQLMSSQLADMIN_README.md +++ b/docs/CLOUDSQLMSSQLADMIN_README.md @@ -12,53 +12,7 @@ To connect to the database to explore and query data, search the MCP store for t ## Prerequisites -* Download and install [MCP Toolbox](https://github.com/googleapis/genai-toolbox): - 1. **Download the Toolbox binary**: - Download the latest binary for your operating system and architecture from the storage bucket. Check the [releases page](https://github.com/googleapis/genai-toolbox/releases) for additional versions: - - - * To install Toolbox as a binary on Linux (AMD64): - ```bash - curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v0.21.0/linux/amd64/toolbox - ``` - - * To install Toolbox as a binary on macOS (Apple Silicon): - ```bash - curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v0.21.0/darwin/arm64/toolbox - ``` - - * To install Toolbox as a binary on macOS (Intel): - ```bash - curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v0.21.0/darwin/amd64/toolbox - ``` - - * To install Toolbox as a binary on Windows (AMD64): - ```powershell - curl -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v0.21.0/windows/amd64/toolbox.exe" - ``` - - - 2. **Make it executable**: - - ```bash - chmod +x toolbox - ``` - - 3. 
**Add the binary to $PATH in `.~/bash_profile`** (Note: You may need to restart Antigravity for changes to take effect.): - - ```bash - export PATH=$PATH:path/to/folder - ``` - - **On Windows, move binary to the `WindowsApps\` folder**: - ``` - move "C:\Users\\toolbox.exe" "C:\Users\\AppData\Local\Microsoft\WindowsApps\" - ``` - - **Tip:** Ensure the destination folder for your binary is included in - your system's PATH environment variable. To check `PATH`, use `echo - $PATH` (or `echo %PATH%` on Windows). - +* [Node.js](https://nodejs.org/) installed. * A Google Cloud project with the **Cloud SQL Admin API** enabled. * Ensure [Application Default Credentials](https://cloud.google.com/docs/authentication/gcloud) are available in your environment. * IAM Permissions: @@ -67,10 +21,13 @@ To connect to the database to explore and query data, search the MCP store for t ## Install & Configuration -1. In the Antigravity MCP Store, click the "Install" button. +In the Antigravity MCP Store, click the "Install" button. You'll now be able to see all enabled tools in the "Tools" tab. +> [!NOTE] +> If you encounter issues with Windows Defender blocking the execution, you may need to configure an allowlist. See [Configure exclusions for Microsoft Defender Antivirus](https://learn.microsoft.com/en-us/microsoft-365/security/defender-endpoint/configure-exclusions-microsoft-defender-antivirus?view=o365-worldwide) for more details. + ## Usage Once configured, the MCP server will automatically provide Cloud SQL for SQL Server capabilities to your AI assistant. You can: @@ -101,8 +58,8 @@ Add the following configuration to your MCP client (e.g., `settings.json` for Ge { "mcpServers": { "cloud-sql-sqlserver-admin": { - "command": "toolbox", - "args": ["--prebuilt", "cloud-sql-mssql-admin", "--stdio"] + "command": "npx", + "args": ["-y", "@toolbox-sdk/server", "--prebuilt", "cloud-sql-mssql-admin", "--stdio"] } } } diff --git a/docs/CLOUDSQLMSSQL_README.md b/docs/CLOUDSQLMSSQL_README.md index 9170faf225..9b1385f8b8 100644 --- a/docs/CLOUDSQLMSSQL_README.md +++ b/docs/CLOUDSQLMSSQL_README.md @@ -13,53 +13,7 @@ For Cloud SQL infrastructure management, search the MCP store for the Cloud SQL ## Prerequisites -* Download and install [MCP Toolbox](https://github.com/googleapis/genai-toolbox): - 1. **Download the Toolbox binary**: - Download the latest binary for your operating system and architecture from the storage bucket. Check the [releases page](https://github.com/googleapis/genai-toolbox/releases) for additional versions: - - - * To install Toolbox as a binary on Linux (AMD64): - ```bash - curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v0.21.0/linux/amd64/toolbox - ``` - - * To install Toolbox as a binary on macOS (Apple Silicon): - ```bash - curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v0.21.0/darwin/arm64/toolbox - ``` - - * To install Toolbox as a binary on macOS (Intel): - ```bash - curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v0.21.0/darwin/amd64/toolbox - ``` - - * To install Toolbox as a binary on Windows (AMD64): - ```powershell - curl -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v0.21.0/windows/amd64/toolbox.exe" - ``` - - - 2. **Make it executable**: - - ```bash - chmod +x toolbox - ``` - - 3. 
**Add the binary to $PATH in `.~/bash_profile`** (Note: You may need to restart Antigravity for changes to take effect.): - - ```bash - export PATH=$PATH:path/to/folder - ``` - - **On Windows, move binary to the `WindowsApps\` folder**: - ``` - move "C:\Users\\toolbox.exe" "C:\Users\\AppData\Local\Microsoft\WindowsApps\" - ``` - - **Tip:** Ensure the destination folder for your binary is included in - your system's PATH environment variable. To check `PATH`, use `echo - $PATH` (or `echo %PATH%` on Windows). - +* [Node.js](https://nodejs.org/) installed. * A Google Cloud project with the **Cloud SQL Admin API** enabled. * Ensure [Application Default Credentials](https://cloud.google.com/docs/authentication/gcloud) are available in your environment. * IAM Permissions: @@ -75,6 +29,9 @@ For Cloud SQL infrastructure management, search the MCP store for the Cloud SQL You'll now be able to see all enabled tools in the "Tools" tab. +> [!NOTE] +> If you encounter issues with Windows Defender blocking the execution, you may need to configure an allowlist. See [Configure exclusions for Microsoft Defender Antivirus](https://learn.microsoft.com/en-us/microsoft-365/security/defender-endpoint/configure-exclusions-microsoft-defender-antivirus?view=o365-worldwide) for more details. + ## Usage Once configured, the MCP server will automatically provide Cloud SQL for SQL Server capabilities to your AI assistant. You can: @@ -112,8 +69,8 @@ Add the following configuration to your MCP client (e.g., `settings.json` for Ge { "mcpServers": { "cloud-sql-mssql": { - "command": "toolbox", - "args": ["--prebuilt", "cloud-sql-mssql", "--stdio"], + "command": "npx", + "args": ["-y", "@toolbox-sdk/server", "--prebuilt", "cloud-sql-mssql", "--stdio"], "env": { "CLOUD_SQL_MSSQL_PROJECT": "your-project-id", "CLOUD_SQL_MSSQL_REGION": "your-region", diff --git a/docs/CLOUDSQLMYSQLADMIN_README.md b/docs/CLOUDSQLMYSQLADMIN_README.md index 3514f07f24..5ab9258a21 100644 --- a/docs/CLOUDSQLMYSQLADMIN_README.md +++ b/docs/CLOUDSQLMYSQLADMIN_README.md @@ -12,53 +12,7 @@ To connect to the database to explore and query data, search the MCP store for t ## Prerequisites -* Download and install [MCP Toolbox](https://github.com/googleapis/genai-toolbox): - 1. **Download the Toolbox binary**: - Download the latest binary for your operating system and architecture from the storage bucket. Check the [releases page](https://github.com/googleapis/genai-toolbox/releases) for additional versions: - - - * To install Toolbox as a binary on Linux (AMD64): - ```bash - curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v0.21.0/linux/amd64/toolbox - ``` - - * To install Toolbox as a binary on macOS (Apple Silicon): - ```bash - curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v0.21.0/darwin/arm64/toolbox - ``` - - * To install Toolbox as a binary on macOS (Intel): - ```bash - curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v0.21.0/darwin/amd64/toolbox - ``` - - * To install Toolbox as a binary on Windows (AMD64): - ```powershell - curl -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v0.21.0/windows/amd64/toolbox.exe" - ``` - - - 2. **Make it executable**: - - ```bash - chmod +x toolbox - ``` - - 3. 
**Add the binary to $PATH in `.~/bash_profile`** (Note: You may need to restart Antigravity for changes to take effect.): - - ```bash - export PATH=$PATH:path/to/folder - ``` - - **On Windows, move binary to the `WindowsApps\` folder**: - ``` - move "C:\Users\\toolbox.exe" "C:\Users\\AppData\Local\Microsoft\WindowsApps\" - ``` - - **Tip:** Ensure the destination folder for your binary is included in - your system's PATH environment variable. To check `PATH`, use `echo - $PATH` (or `echo %PATH%` on Windows). - +* [Node.js](https://nodejs.org/) installed. * A Google Cloud project with the **Cloud SQL Admin API** enabled. * Ensure [Application Default Credentials](https://cloud.google.com/docs/authentication/gcloud) are available in your environment. * IAM Permissions: @@ -67,10 +21,13 @@ To connect to the database to explore and query data, search the MCP store for t ## Install & Configuration -1. In the Antigravity MCP Store, click the "Install" button. +In the Antigravity MCP Store, click the "Install" button. You'll now be able to see all enabled tools in the "Tools" tab. +> [!NOTE] +> If you encounter issues with Windows Defender blocking the execution, you may need to configure an allowlist. See [Configure exclusions for Microsoft Defender Antivirus](https://learn.microsoft.com/en-us/microsoft-365/security/defender-endpoint/configure-exclusions-microsoft-defender-antivirus?view=o365-worldwide) for more details. + ## Usage Once configured, the MCP server will automatically provide Cloud SQL for MySQL capabilities to your AI assistant. You can: @@ -100,8 +57,8 @@ Add the following configuration to your MCP client (e.g., `settings.json` for Ge { "mcpServers": { "cloud-sql-mysql-admin": { - "command": "toolbox", - "args": ["--prebuilt", "cloud-sql-mysql-admin", "--stdio"] + "command": "npx", + "args": ["-y", "@toolbox-sdk/server", "--prebuilt", "cloud-sql-mysql-admin", "--stdio"] } } } diff --git a/docs/CLOUDSQLMYSQL_README.md b/docs/CLOUDSQLMYSQL_README.md index 383c35fb57..4eb5ee1975 100644 --- a/docs/CLOUDSQLMYSQL_README.md +++ b/docs/CLOUDSQLMYSQL_README.md @@ -15,53 +15,7 @@ For Cloud SQL infrastructure management, search the MCP store for the Cloud SQL ## Prerequisites -* Download and install [MCP Toolbox](https://github.com/googleapis/genai-toolbox): - 1. **Download the Toolbox binary**: - Download the latest binary for your operating system and architecture from the storage bucket. Check the [releases page](https://github.com/googleapis/genai-toolbox/releases) for additional versions: - - - * To install Toolbox as a binary on Linux (AMD64): - ```bash - curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v0.21.0/linux/amd64/toolbox - ``` - - * To install Toolbox as a binary on macOS (Apple Silicon): - ```bash - curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v0.21.0/darwin/arm64/toolbox - ``` - - * To install Toolbox as a binary on macOS (Intel): - ```bash - curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v0.21.0/darwin/amd64/toolbox - ``` - - * To install Toolbox as a binary on Windows (AMD64): - ```powershell - curl -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v0.21.0/windows/amd64/toolbox.exe" - ``` - - - 2. **Make it executable**: - - ```bash - chmod +x toolbox - ``` - - 3. 
**Add the binary to $PATH in `.~/bash_profile`** (Note: You may need to restart Antigravity for changes to take effect.): - - ```bash - export PATH=$PATH:path/to/folder - ``` - - **On Windows, move binary to the `WindowsApps\` folder**: - ``` - move "C:\Users\\toolbox.exe" "C:\Users\\AppData\Local\Microsoft\WindowsApps\" - ``` - - **Tip:** Ensure the destination folder for your binary is included in - your system's PATH environment variable. To check `PATH`, use `echo - $PATH` (or `echo %PATH%` on Windows). - +* [Node.js](https://nodejs.org/) installed. * A Google Cloud project with the **Cloud SQL Admin API** enabled. * Ensure [Application Default Credentials](https://cloud.google.com/docs/authentication/gcloud) are available in your environment. * IAM Permissions: @@ -77,6 +31,9 @@ For Cloud SQL infrastructure management, search the MCP store for the Cloud SQL You'll now be able to see all enabled tools in the "Tools" tab. +> [!NOTE] +> If you encounter issues with Windows Defender blocking the execution, you may need to configure an allowlist. See [Configure exclusions for Microsoft Defender Antivirus](https://learn.microsoft.com/en-us/microsoft-365/security/defender-endpoint/configure-exclusions-microsoft-defender-antivirus?view=o365-worldwide) for more details. + ## Usage Once configured, the MCP server will automatically provide Cloud SQL for MySQL capabilities to your AI assistant. You can: @@ -118,8 +75,8 @@ Add the following configuration to your MCP client (e.g., `settings.json` for Ge { "mcpServers": { "cloud-sql-mysql": { - "command": "toolbox", - "args": ["--prebuilt", "cloud-sql-mysql", "--stdio"], + "command": "npx", + "args": ["-y", "@toolbox-sdk/server", "--prebuilt", "cloud-sql-mysql", "--stdio"], "env": { "CLOUD_SQL_MYSQL_PROJECT": "your-project-id", "CLOUD_SQL_MYSQL_REGION": "your-region", diff --git a/docs/CLOUDSQLPGADMIN_README.md b/docs/CLOUDSQLPGADMIN_README.md index 9ab385eeea..c1b594ea49 100644 --- a/docs/CLOUDSQLPGADMIN_README.md +++ b/docs/CLOUDSQLPGADMIN_README.md @@ -12,53 +12,7 @@ To connect to the database to explore and query data, search the MCP store for t ## Prerequisites -* Download and install [MCP Toolbox](https://github.com/googleapis/genai-toolbox): - 1. **Download the Toolbox binary**: - Download the latest binary for your operating system and architecture from the storage bucket. Check the [releases page](https://github.com/googleapis/genai-toolbox/releases) for additional versions: - - - * To install Toolbox as a binary on Linux (AMD64): - ```bash - curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v0.21.0/linux/amd64/toolbox - ``` - - * To install Toolbox as a binary on macOS (Apple Silicon): - ```bash - curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v0.21.0/darwin/arm64/toolbox - ``` - - * To install Toolbox as a binary on macOS (Intel): - ```bash - curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v0.21.0/darwin/amd64/toolbox - ``` - - * To install Toolbox as a binary on Windows (AMD64): - ```powershell - curl -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v0.21.0/windows/amd64/toolbox.exe" - ``` - - - 2. **Make it executable**: - - ```bash - chmod +x toolbox - ``` - - 3. 
**Add the binary to $PATH in `.~/bash_profile`** (Note: You may need to restart Antigravity for changes to take effect.): - - ```bash - export PATH=$PATH:path/to/folder - ``` - - **On Windows, move binary to the `WindowsApps\` folder**: - ``` - move "C:\Users\\toolbox.exe" "C:\Users\\AppData\Local\Microsoft\WindowsApps\" - ``` - - **Tip:** Ensure the destination folder for your binary is included in - your system's PATH environment variable. To check `PATH`, use `echo - $PATH` (or `echo %PATH%` on Windows). - +* [Node.js](https://nodejs.org/) installed. * A Google Cloud project with the **Cloud SQL Admin API** enabled. * Ensure [Application Default Credentials](https://cloud.google.com/docs/authentication/gcloud) are available in your environment. * IAM Permissions: @@ -67,10 +21,13 @@ To connect to the database to explore and query data, search the MCP store for t ## Install & Configuration -1. In the Antigravity MCP Store, click the "Install" button. +In the Antigravity MCP Store, click the "Install" button. You'll now be able to see all enabled tools in the "Tools" tab. +> [!NOTE] +> If you encounter issues with Windows Defender blocking the execution, you may need to configure an allowlist. See [Configure exclusions for Microsoft Defender Antivirus](https://learn.microsoft.com/en-us/microsoft-365/security/defender-endpoint/configure-exclusions-microsoft-defender-antivirus?view=o365-worldwide) for more details. + ## Usage Once configured, the MCP server will automatically provide Cloud SQL for PostgreSQL capabilities to your AI assistant. You can: @@ -100,8 +57,8 @@ Add the following configuration to your MCP client (e.g., `settings.json` for Ge { "mcpServers": { "cloud-sql-postgres-admin": { - "command": "toolbox", - "args": ["--prebuilt", "cloud-sql-postgres-admin", "--stdio"] + "command": "npx", + "args": ["-y", "@toolbox-sdk/server", "--prebuilt", "cloud-sql-postgres-admin", "--stdio"] } } } diff --git a/docs/CLOUDSQLPG_README.md b/docs/CLOUDSQLPG_README.md index b1e9863a89..a30bd002b8 100644 --- a/docs/CLOUDSQLPG_README.md +++ b/docs/CLOUDSQLPG_README.md @@ -15,53 +15,7 @@ For Cloud SQL infrastructure management, search the MCP store for the Cloud SQL ## Prerequisites -* Download and install [MCP Toolbox](https://github.com/googleapis/genai-toolbox): - 1. **Download the Toolbox binary**: - Download the latest binary for your operating system and architecture from the storage bucket. Check the [releases page](https://github.com/googleapis/genai-toolbox/releases) for additional versions: - - - * To install Toolbox as a binary on Linux (AMD64): - ```bash - curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v0.21.0/linux/amd64/toolbox - ``` - - * To install Toolbox as a binary on macOS (Apple Silicon): - ```bash - curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v0.21.0/darwin/arm64/toolbox - ``` - - * To install Toolbox as a binary on macOS (Intel): - ```bash - curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v0.21.0/darwin/amd64/toolbox - ``` - - * To install Toolbox as a binary on Windows (AMD64): - ```powershell - curl -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v0.21.0/windows/amd64/toolbox.exe" - ``` - - - 2. **Make it executable**: - - ```bash - chmod +x toolbox - ``` - - 3. 
**Add the binary to $PATH in `.~/bash_profile`** (Note: You may need to restart Antigravity for changes to take effect.): - - ```bash - export PATH=$PATH:path/to/folder - ``` - - **On Windows, move binary to the `WindowsApps\` folder**: - ``` - move "C:\Users\\toolbox.exe" "C:\Users\\AppData\Local\Microsoft\WindowsApps\" - ``` - - **Tip:** Ensure the destination folder for your binary is included in - your system's PATH environment variable. To check `PATH`, use `echo - $PATH` (or `echo %PATH%` on Windows). - +* [Node.js](https://nodejs.org/) installed. * A Google Cloud project with the **Cloud SQL Admin API** enabled. * Ensure [Application Default Credentials](https://cloud.google.com/docs/authentication/gcloud) are available in your environment. * IAM Permissions: @@ -77,6 +31,9 @@ For Cloud SQL infrastructure management, search the MCP store for the Cloud SQL You'll now be able to see all enabled tools in the "Tools" tab. +> [!NOTE] +> If you encounter issues with Windows Defender blocking the execution, you may need to configure an allowlist. See [Configure exclusions for Microsoft Defender Antivirus](https://learn.microsoft.com/en-us/microsoft-365/security/defender-endpoint/configure-exclusions-microsoft-defender-antivirus?view=o365-worldwide) for more details. + ## Usage Once configured, the MCP server will automatically provide Cloud SQL for PostgreSQL capabilities to your AI assistant. You can: @@ -130,8 +87,8 @@ Add the following configuration to your MCP client (e.g., `settings.json` for Ge { "mcpServers": { "cloud-sql-postgres": { - "command": "toolbox", - "args": ["--prebuilt", "cloud-sql-postgres", "--stdio"], + "command": "npx", + "args": ["-y", "@toolbox-sdk/server", "--prebuilt", "cloud-sql-postgres", "--stdio"], "env": { "CLOUD_SQL_POSTGRES_PROJECT": "your-project-id", "CLOUD_SQL_POSTGRES_REGION": "your-region", diff --git a/docs/DATAPLEX_README.md b/docs/DATAPLEX_README.md index d0253d7291..f40f4eb909 100644 --- a/docs/DATAPLEX_README.md +++ b/docs/DATAPLEX_README.md @@ -11,53 +11,7 @@ An editor configured to use the Dataplex MCP server can use its AI capabilities ## Prerequisites -* Download and install [MCP Toolbox](https://github.com/googleapis/genai-toolbox): - 1. **Download the Toolbox binary**: - Download the latest binary for your operating system and architecture from the storage bucket. Check the [releases page](https://github.com/googleapis/genai-toolbox/releases) for additional versions: - - - * To install Toolbox as a binary on Linux (AMD64): - ```bash - curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v0.21.0/linux/amd64/toolbox - ``` - - * To install Toolbox as a binary on macOS (Apple Silicon): - ```bash - curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v0.21.0/darwin/arm64/toolbox - ``` - - * To install Toolbox as a binary on macOS (Intel): - ```bash - curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v0.21.0/darwin/amd64/toolbox - ``` - - * To install Toolbox as a binary on Windows (AMD64): - ```powershell - curl -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v0.21.0/windows/amd64/toolbox.exe" - ``` - - - 2. **Make it executable**: - - ```bash - chmod +x toolbox - ``` - - 3. 
**Add the binary to $PATH in `.~/bash_profile`** (Note: You may need to restart Antigravity for changes to take effect.): - - ```bash - export PATH=$PATH:path/to/folder - ``` - - **On Windows, move binary to the `WindowsApps\` folder**: - ``` - move "C:\Users\\toolbox.exe" "C:\Users\\AppData\Local\Microsoft\WindowsApps\" - ``` - - **Tip:** Ensure the destination folder for your binary is included in - your system's PATH environment variable. To check `PATH`, use `echo - $PATH` (or `echo %PATH%` on Windows). - +* [Node.js](https://nodejs.org/) installed. * A Google Cloud project with the **Dataplex API** enabled. * Ensure [Application Default Credentials](https://cloud.google.com/docs/authentication/gcloud) are available in your environment. * IAM Permissions: @@ -71,6 +25,9 @@ An editor configured to use the Dataplex MCP server can use its AI capabilities You'll now be able to see all enabled tools in the "Tools" tab. +> [!NOTE] +> If you encounter issues with Windows Defender blocking the execution, you may need to configure an allowlist. See [Configure exclusions for Microsoft Defender Antivirus](https://learn.microsoft.com/en-us/microsoft-365/security/defender-endpoint/configure-exclusions-microsoft-defender-antivirus?view=o365-worldwide) for more details. + ## Usage Once configured, the MCP server will automatically provide Dataplex capabilities to your AI assistant. You can: @@ -102,8 +59,8 @@ Add the following configuration to your MCP client (e.g., `settings.json` for Ge { "mcpServers": { "dataplex": { - "command": "toolbox", - "args": ["--prebuilt", "dataplex", "--stdio"], + "command": "npx", + "args": ["-y", "@toolbox-sdk/server", "--prebuilt", "dataplex", "--stdio"], "env": { "DATAPLEX_PROJECT": "your-project-id" } diff --git a/docs/LOOKER_README.md b/docs/LOOKER_README.md index ffbb359a67..e1fc6bab24 100644 --- a/docs/LOOKER_README.md +++ b/docs/LOOKER_README.md @@ -14,53 +14,7 @@ An editor configured to use the Looker MCP server can use its AI capabilities to ## Prerequisites -* Download and install [MCP Toolbox](https://github.com/googleapis/genai-toolbox): - 1. **Download the Toolbox binary**: - Download the latest binary for your operating system and architecture from the storage bucket. Check the [releases page](https://github.com/googleapis/genai-toolbox/releases) for additional versions: - - - * To install Toolbox as a binary on Linux (AMD64): - ```bash - curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v0.21.0/linux/amd64/toolbox - ``` - - * To install Toolbox as a binary on macOS (Apple Silicon): - ```bash - curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v0.21.0/darwin/arm64/toolbox - ``` - - * To install Toolbox as a binary on macOS (Intel): - ```bash - curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v0.21.0/darwin/amd64/toolbox - ``` - - * To install Toolbox as a binary on Windows (AMD64): - ```powershell - curl -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v0.21.0/windows/amd64/toolbox.exe" - ``` - - - 2. **Make it executable**: - - ```bash - chmod +x toolbox - ``` - - 3. 
**Add the binary to $PATH in `.~/bash_profile`** (Note: You may need to restart Antigravity for changes to take effect.): - - ```bash - export PATH=$PATH:path/to/folder - ``` - - **On Windows, move binary to the `WindowsApps\` folder**: - ``` - move "C:\Users\\toolbox.exe" "C:\Users\\AppData\Local\Microsoft\WindowsApps\" - ``` - - **Tip:** Ensure the destination folder for your binary is included in - your system's PATH environment variable. To check `PATH`, use `echo - $PATH` (or `echo %PATH%` on Windows). - +* [Node.js](https://nodejs.org/) installed. * Access to a Looker instance. * API Credentials (`Client ID` and `Client Secret`) or OAuth configuration. @@ -72,6 +26,9 @@ An editor configured to use the Looker MCP server can use its AI capabilities to You'll now be able to see all enabled tools in the "Tools" tab. +> [!NOTE] +> If you encounter issues with Windows Defender blocking the execution, you may need to configure an allowlist. See [Configure exclusions for Microsoft Defender Antivirus](https://learn.microsoft.com/en-us/microsoft-365/security/defender-endpoint/configure-exclusions-microsoft-defender-antivirus?view=o365-worldwide) for more details. + ## Usage Once configured, the MCP server will automatically provide Looker capabilities to your AI assistant. You can: @@ -118,8 +75,8 @@ Add the following configuration to your MCP client (e.g., `settings.json` for Ge { "mcpServers": { "looker": { - "command": "toolbox", - "args": ["--prebuilt", "looker", "--stdio"], + "command": "npx", + "args": ["-y", "@toolbox-sdk/server", "--prebuilt", "looker", "--stdio"], "env": { "LOOKER_BASE_URL": "https://your.looker.instance.com", "LOOKER_CLIENT_ID": "your-client-id", diff --git a/docs/SPANNER_README.md b/docs/SPANNER_README.md index 16eab5661a..f4bf32cd19 100644 --- a/docs/SPANNER_README.md +++ b/docs/SPANNER_README.md @@ -11,53 +11,7 @@ An editor configured to use the Cloud Spanner MCP server can use its AI capabili ## Prerequisites -* Download and install [MCP Toolbox](https://github.com/googleapis/genai-toolbox): - 1. **Download the Toolbox binary**: - Download the latest binary for your operating system and architecture from the storage bucket. Check the [releases page](https://github.com/googleapis/genai-toolbox/releases) for additional versions: - - - * To install Toolbox as a binary on Linux (AMD64): - ```bash - curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v0.21.0/linux/amd64/toolbox - ``` - - * To install Toolbox as a binary on macOS (Apple Silicon): - ```bash - curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v0.21.0/darwin/arm64/toolbox - ``` - - * To install Toolbox as a binary on macOS (Intel): - ```bash - curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v0.21.0/darwin/amd64/toolbox - ``` - - * To install Toolbox as a binary on Windows (AMD64): - ```powershell - curl -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v0.21.0/windows/amd64/toolbox.exe" - ``` - - - 2. **Make it executable**: - - ```bash - chmod +x toolbox - ``` - - 3. **Add the binary to $PATH in `.~/bash_profile`** (Note: You may need to restart Antigravity for changes to take effect.): - - ```bash - export PATH=$PATH:path/to/folder - ``` - - **On Windows, move binary to the `WindowsApps\` folder**: - ``` - move "C:\Users\\toolbox.exe" "C:\Users\\AppData\Local\Microsoft\WindowsApps\" - ``` - - **Tip:** Ensure the destination folder for your binary is included in - your system's PATH environment variable. 
To check `PATH`, use `echo - $PATH` (or `echo %PATH%` on Windows). - +* [Node.js](https://nodejs.org/) installed. * A Google Cloud project with the **Cloud Spanner API** enabled. * Ensure [Application Default Credentials](https://cloud.google.com/docs/authentication/gcloud) are available in your environment. * IAM Permissions: @@ -72,6 +26,9 @@ An editor configured to use the Cloud Spanner MCP server can use its AI capabili You'll now be able to see all enabled tools in the "Tools" tab. +> [!NOTE] +> If you encounter issues with Windows Defender blocking the execution, you may need to configure an allowlist. See [Configure exclusions for Microsoft Defender Antivirus](https://learn.microsoft.com/en-us/microsoft-365/security/defender-endpoint/configure-exclusions-microsoft-defender-antivirus?view=o365-worldwide) for more details. + ## Usage Once configured, the MCP server will automatically provide Cloud Spanner capabilities to your AI assistant. You can: @@ -84,11 +41,12 @@ Once configured, the MCP server will automatically provide Cloud Spanner capabil The Cloud Spanner MCP server provides the following tools: -| Tool Name | Description | -|:------------------|:-----------------------------------------------------------| -| `execute_sql` | Use this tool to execute DML SQL. | -| `execute_sql_dql` | Use this tool to execute DQL SQL. | -| `list_tables` | Lists detailed schema information for user-created tables. | +| Tool Name | Description | +|:------------------|:-----------------------------------------------------------------| +| `execute_sql` | Use this tool to execute DML SQL. | +| `execute_sql_dql` | Use this tool to execute DQL SQL. | +| `list_tables` | Lists detailed schema information for user-created tables. | +| `list_graphs` | Lists detailed graph schema information for user-created graphs. | ## Custom MCP Server Configuration @@ -107,8 +65,8 @@ Add the following configuration to your MCP client (e.g., `settings.json` for Ge { "mcpServers": { "spanner": { - "command": "toolbox", - "args": ["--prebuilt", "spanner", "--stdio"], + "command": "npx", + "args": ["-y", "@toolbox-sdk/server", "--prebuilt", "spanner", "--stdio"], "env": { "SPANNER_PROJECT": "your-project-id", "SPANNER_INSTANCE": "your-instance-id", diff --git a/docs/TOOLBOX_README.md b/docs/TOOLBOX_README.md index 959cd177f0..170c71e183 100644 --- a/docs/TOOLBOX_README.md +++ b/docs/TOOLBOX_README.md @@ -2,66 +2,23 @@ The MCP Toolbox for Databases Server gives AI-powered development tools the ability to work with your custom tools. It is designed to simplify and secure the development of tools for interacting with databases. -# Prerequisites -* Download and install [MCP Toolbox](https://github.com/googleapis/genai-toolbox): - 1. **Download the Toolbox binary**: - Download the latest binary for your operating system and architecture from the storage bucket. 
Check the [releases page](https://github.com/googleapis/genai-toolbox/releases) for additional versions: - - - * To install Toolbox as a binary on Linux (AMD64): - ```bash - curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v0.21.0/linux/amd64/toolbox - ``` +## Prerequisites - * To install Toolbox as a binary on macOS (Apple Silicon): - ```bash - curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v0.21.0/darwin/arm64/toolbox - ``` - - * To install Toolbox as a binary on macOS (Intel): - ```bash - curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v0.21.0/darwin/amd64/toolbox - ``` - - * To install Toolbox as a binary on Windows (AMD64): - ```powershell - curl -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v0.21.0/windows/amd64/toolbox.exe" - ``` - - - 2. **Make it executable**: - - ```bash - chmod +x toolbox - ``` - - 3. **Add the binary to $PATH in `.~/bash_profile`** (Note: You may need to restart Antigravity for changes to take effect.): - - ```bash - export PATH=$PATH:path/to/folder - ``` - - **On Windows, move binary to the `WindowsApps\` folder**: - ``` - move "C:\Users\\toolbox.exe" "C:\Users\\AppData\Local\Microsoft\WindowsApps\" - ``` - - **Tip:** Ensure the destination folder for your binary is included in - your system's PATH environment variable. To check `PATH`, use `echo - $PATH` (or `echo %PATH%` on Windows). - -* Any required APIs and permissions for connecting to your database. - -> **Note:** If your database instance uses private IPs, you must run the MCP server in the same Virtual Private Cloud (VPC) network. +* [Node.js](https://nodejs.org/) installed. +* A Google Cloud project with relevant APIs enabled. +* Ensure [Application Default Credentials](https://cloud.google.com/docs/authentication/gcloud) are available in your environment. ## Install & Configuration -1. In the Antigravity MCP Store, click the "Install" button. +1. In the Antigravity MCP Store, click the **Install** button. A configuration window will appear. -2. Add your [`tools.yaml` configuration -file](https://googleapis.github.io/genai-toolbox/getting-started/configure/) to -the directory you are running Antigravity +2. Create your [`tools.yaml` configuration file](https://googleapis.github.io/genai-toolbox/getting-started/configure/). + +3. In the configuration window, enter the full absolute path to your `tools.yaml` file and click **Save**. + +> [!NOTE] +> If you encounter issues with Windows Defender blocking the execution, you may need to configure an allowlist. See [Configure exclusions for Microsoft Defender Antivirus](https://learn.microsoft.com/en-us/microsoft-365/security/defender-endpoint/configure-exclusions-microsoft-defender-antivirus?view=o365-worldwide) for more details. ## Usage @@ -73,8 +30,8 @@ Interact with your custom tools using natural language. { "mcpServers": { "mcp-toolbox": { - "command": "toolbox", - "args": ["--tools-file", "you-tool-file.yaml"], + "command": "npx", + "args": ["-y", "@toolbox-sdk/server", "--tools-file", "your-tool-file.yaml"], "env": { "ENV_VAR_NAME": "ENV_VAR_VALUE", } diff --git a/docs/en/concepts/telemetry/index.md b/docs/en/concepts/telemetry/index.md index 862c3832e2..49b7c9edca 100644 --- a/docs/en/concepts/telemetry/index.md +++ b/docs/en/concepts/telemetry/index.md @@ -183,11 +183,11 @@ Protocol (OTLP). 
If you would like to use a collector, please refer to this

The following flags are used to determine Toolbox's telemetry configuration:

-| **flag**                   | **type** | **description**                                                                            |
-|----------------------------|----------|--------------------------------------------------------------------------------------------|
-| `--telemetry-gcp`          | bool     | Enable exporting directly to Google Cloud Monitoring. Default is `false`.                   |
-| `--telemetry-otlp`         | string   | Enable exporting using OpenTelemetry Protocol (OTLP) to the specified endpoint (e.g. "").   |
-| `--telemetry-service-name` | string   | Sets the value of the `service.name` resource attribute. Default is `toolbox`.              |
+| **flag**                   | **type** | **description**                                                                                                                                                                                          |
+|----------------------------|----------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| `--telemetry-gcp`          | bool     | Enable exporting directly to Google Cloud Monitoring. Default is `false`.                                                                                                                                |
+| `--telemetry-otlp`         | string   | Enable exporting using OpenTelemetry Protocol (OTLP) to the specified endpoint (e.g. "127.0.0.1:4318"). To pass an insecure endpoint here, set environment variable `OTEL_EXPORTER_OTLP_INSECURE=true`. |
+| `--telemetry-service-name` | string   | Sets the value of the `service.name` resource attribute. Default is `toolbox`.                                                                                                                           |

In addition to the flags noted above, you can also make additional
configuration for OpenTelemetry via the [General SDK Configuration][sdk-configuration] through
@@ -207,5 +207,5 @@ To enable Google Cloud Exporter:
To enable OTLP Exporter, provide Collector endpoint:

```bash
-./toolbox --telemetry-otlp="http://127.0.0.1:4553"
+./toolbox --telemetry-otlp="127.0.0.1:4553"
```
diff --git a/docs/en/getting-started/colab_quickstart.ipynb b/docs/en/getting-started/colab_quickstart.ipynb
index b9bb327338..a2e2f989e0 100644
--- a/docs/en/getting-started/colab_quickstart.ipynb
+++ b/docs/en/getting-started/colab_quickstart.ipynb
@@ -234,7 +234,7 @@
   },
   "outputs": [],
   "source": [
-    "version = \"0.21.0\" # x-release-please-version\n",
+    "version = \"0.24.0\" # x-release-please-version\n",
     "! curl -O https://storage.googleapis.com/genai-toolbox/v{version}/linux/amd64/toolbox\n",
     "\n",
     "# Make the binary executable\n",
diff --git a/docs/en/getting-started/introduction/_index.md b/docs/en/getting-started/introduction/_index.md
index 88ac7d4ee2..f5f7d76836 100644
--- a/docs/en/getting-started/introduction/_index.md
+++ b/docs/en/getting-started/introduction/_index.md
@@ -71,6 +71,22 @@ redeploying your application.

 ## Getting Started

+### (Non-production) Running Toolbox
+
+You can run Toolbox directly with a [configuration file](../configure.md):
+
+```sh
+npx @toolbox-sdk/server --tools-file tools.yaml
+```
+
+This runs the latest version of the toolbox server with your configuration file.
+
+{{< notice note >}}
+This method should only be used for non-production use cases such as
+experimentation. For any production use-cases, please consider [Installing the
+server](#installing-the-server) and then [running it](#running-the-server).
+{{< /notice >}} + ### Installing the server For the latest version, check the [releases page][releases] and use the @@ -87,7 +103,7 @@ To install Toolbox as a binary on Linux (AMD64): ```sh # see releases page for other versions -export VERSION=0.21.0 +export VERSION=0.24.0 curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/linux/amd64/toolbox chmod +x toolbox ``` @@ -98,7 +114,7 @@ To install Toolbox as a binary on macOS (Apple Silicon): ```sh # see releases page for other versions -export VERSION=0.21.0 +export VERSION=0.24.0 curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/darwin/arm64/toolbox chmod +x toolbox ``` @@ -109,19 +125,29 @@ To install Toolbox as a binary on macOS (Intel): ```sh # see releases page for other versions -export VERSION=0.21.0 +export VERSION=0.24.0 curl -L -o toolbox https://storage.googleapis.com/genai-toolbox/v$VERSION/darwin/amd64/toolbox chmod +x toolbox ``` {{% /tab %}} -{{% tab header="Windows (AMD64)" lang="en" %}} -To install Toolbox as a binary on Windows (AMD64): +{{% tab header="Windows (Command Prompt)" lang="en" %}} +To install Toolbox as a binary on Windows (Command Prompt): + +```cmd +:: see releases page for other versions +set VERSION=0.24.0 +curl -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v%VERSION%/windows/amd64/toolbox.exe" +``` + +{{% /tab %}} +{{% tab header="Windows (PowerShell)" lang="en" %}} +To install Toolbox as a binary on Windows (PowerShell): ```powershell -:: see releases page for other versions -set VERSION=0.21.0 -curl -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v%VERSION%/windows/amd64/toolbox.exe" +# see releases page for other versions +$VERSION = "0.24.0" +curl.exe -o toolbox.exe "https://storage.googleapis.com/genai-toolbox/v$VERSION/windows/amd64/toolbox.exe" ``` {{% /tab %}} @@ -132,7 +158,7 @@ You can also install Toolbox as a container: ```sh # see releases page for other versions -export VERSION=0.21.0 +export VERSION=0.24.0 docker pull us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:$VERSION ``` @@ -151,7 +177,7 @@ To install from source, ensure you have the latest version of [Go installed](https://go.dev/doc/install), and then run the following command: ```sh -go install github.com/googleapis/genai-toolbox@v0.21.0 +go install github.com/googleapis/genai-toolbox@v0.24.0 ``` {{% /tab %}} @@ -294,6 +320,10 @@ let client = new ToolboxClient(URL); const toolboxTools = await client.loadToolset('toolsetName'); {{< /highlight >}} +For more detailed instructions on using the Toolbox Core SDK, see the +[project's +README](https://github.com/googleapis/mcp-toolbox-sdk-js/blob/main/packages/toolbox-core/README.md). + {{% /tab %}} {{% tab header="LangChain/Langraph" lang="en" %}} @@ -318,6 +348,10 @@ const getTool = (toolboxTool) => tool(currTool, { const tools = toolboxTools.map(getTool); {{< /highlight >}} +For more detailed instructions on using the Toolbox Core SDK, see the +[project's +README](https://github.com/googleapis/mcp-toolbox-sdk-js/blob/main/packages/toolbox-core/README.md). + {{% /tab %}} {{% tab header="Genkit" lang="en" %}} @@ -353,6 +387,10 @@ const getTool = (toolboxTool) => ai.defineTool({ const tools = toolboxTools.map(getTool); {{< /highlight >}} +For more detailed instructions on using the Toolbox Core SDK, see the +[project's +README](https://github.com/googleapis/mcp-toolbox-sdk-js/blob/main/packages/toolbox-core/README.md). 
+ {{% /tab %}} {{% tab header="LlamaIndex" lang="en" %}} @@ -380,13 +418,33 @@ const tools = toolboxTools.map(getTool); {{< /highlight >}} -{{% /tab %}} -{{< /tabpane >}} - For more detailed instructions on using the Toolbox Core SDK, see the [project's README](https://github.com/googleapis/mcp-toolbox-sdk-js/blob/main/packages/toolbox-core/README.md). +{{% /tab %}} +{{% tab header="ADK TS" lang="en" %}} + +{{< highlight javascript >}} +import { ToolboxClient } from '@toolbox-sdk/adk'; + +// Replace with the actual URL where your Toolbox service is running +const URL = 'http://127.0.0.1:5000'; + +let client = new ToolboxClient(URL); +const tools = await client.loadToolset(); + +// Use the client and tools as per requirement + +{{< /highlight >}} + +For detailed samples on using the Toolbox JS SDK with ADK JS, see the [project's +README.](https://github.com/googleapis/mcp-toolbox-sdk-js/tree/main/packages/toolbox-adk/README.md) + +{{% /tab %}} +{{< /tabpane >}} + + #### Go Once you've installed the [Toolbox Go diff --git a/docs/en/getting-started/local_quickstart_js.md b/docs/en/getting-started/local_quickstart_js.md index 5a8b06c839..bfec5be5d9 100644 --- a/docs/en/getting-started/local_quickstart_js.md +++ b/docs/en/getting-started/local_quickstart_js.md @@ -40,11 +40,24 @@ from Toolbox. ``` 1. In a new terminal, install the - [SDK](https://www.npmjs.com/package/@toolbox-sdk/core). - - ```bash - npm install @toolbox-sdk/core - ``` + SDK package. + {{< tabpane persist=header >}} +{{< tab header="LangChain" lang="bash" >}} +npm install @toolbox-sdk/core +{{< /tab >}} +{{< tab header="GenkitJS" lang="bash" >}} +npm install @toolbox-sdk/core +{{< /tab >}} +{{< tab header="LlamaIndex" lang="bash" >}} +npm install @toolbox-sdk/core +{{< /tab >}} +{{< tab header="GoogleGenAI" lang="bash" >}} +npm install @toolbox-sdk/core +{{< /tab >}} +{{< tab header="ADK" lang="bash" >}} +npm install @toolbox-sdk/adk +{{< /tab >}} +{{< /tabpane >}} 1. Install other required dependencies @@ -61,6 +74,9 @@ npm install llamaindex @llamaindex/google @llamaindex/workflow {{< tab header="GoogleGenAI" lang="bash" >}} npm install @google/genai {{< /tab >}} +{{< tab header="ADK" lang="bash" >}} +npm install @google/adk +{{< /tab >}} {{< /tabpane >}} 1. Create a new file named `hotelAgent.js` and copy the following code to create @@ -91,6 +107,12 @@ npm install @google/genai {{< /tab >}} +{{< tab header="ADK" lang="js" >}} + +{{< include "quickstart/js/adk/quickstart.js" >}} + +{{< /tab >}} + {{< /tabpane >}} 1. 
Run your agent, and observe the results: diff --git a/docs/en/getting-started/mcp_quickstart/_index.md b/docs/en/getting-started/mcp_quickstart/_index.md index 929d6a9be1..f07528d2bf 100644 --- a/docs/en/getting-started/mcp_quickstart/_index.md +++ b/docs/en/getting-started/mcp_quickstart/_index.md @@ -105,7 +105,7 @@ In this section, we will download Toolbox, configure our tools in a ```bash export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64 - curl -O https://storage.googleapis.com/genai-toolbox/v0.21.0/$OS/toolbox + curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/$OS/toolbox ``` diff --git a/docs/en/getting-started/prompts_quickstart_gemini_cli.md b/docs/en/getting-started/prompts_quickstart_gemini_cli.md new file mode 100644 index 0000000000..2061acd7fa --- /dev/null +++ b/docs/en/getting-started/prompts_quickstart_gemini_cli.md @@ -0,0 +1,245 @@ +--- +title: "Prompts using Gemini CLI" +type: docs +weight: 5 +description: > + How to get started using Toolbox prompts locally with PostgreSQL and [Gemini CLI](https://pypi.org/project/gemini-cli/). +--- + +## Before you begin + +This guide assumes you have already done the following: + +1. Installed [PostgreSQL 16+ and the `psql` client][install-postgres]. + +[install-postgres]: https://www.postgresql.org/download/ + +## Step 1: Set up your database + +In this section, we will create a database, insert some data that needs to be +accessed by our agent, and create a database user for Toolbox to connect with. + +1. Connect to postgres using the `psql` command: + + ```bash + psql -h 127.0.0.1 -U postgres + ``` + + Here, `postgres` denotes the default postgres superuser. + + {{< notice info >}} + +#### **Having trouble connecting?** + +* **Password Prompt:** If you are prompted for a password for the `postgres` + user and do not know it (or a blank password doesn't work), your PostgreSQL + installation might require a password or a different authentication method. +* **`FATAL: role "postgres" does not exist`:** This error means the default + `postgres` superuser role isn't available under that name on your system. +* **`Connection refused`:** Ensure your PostgreSQL server is actually running. + You can typically check with `sudo systemctl status postgresql` and start it + with `sudo systemctl start postgresql` on Linux systems. + +
+ +#### **Common Solution** + +For password issues or if the `postgres` role seems inaccessible directly, try +switching to the `postgres` operating system user first. This user often has +permission to connect without a password for local connections (this is called +peer authentication). + +```bash +sudo -i -u postgres +psql -h 127.0.0.1 +``` + +Once you are in the `psql` shell using this method, you can proceed with the +database creation steps below. Afterwards, type `\q` to exit `psql`, and then +`exit` to return to your normal user shell. + +If desired, once connected to `psql` as the `postgres` OS user, you can set a +password for the `postgres` *database* user using: `ALTER USER postgres WITH +PASSWORD 'your_chosen_password';`. This would allow direct connection with `-U +postgres` and a password next time. + {{< /notice >}} + +1. Create a new database and a new user: + + {{< notice tip >}} + For a real application, it's best to follow the principle of least permission + and only grant the privileges your application needs. + {{< /notice >}} + + ```sql + CREATE USER toolbox_user WITH PASSWORD 'my-password'; + + CREATE DATABASE toolbox_db; + GRANT ALL PRIVILEGES ON DATABASE toolbox_db TO toolbox_user; + + ALTER DATABASE toolbox_db OWNER TO toolbox_user; + ``` + +1. End the database session: + + ```bash + \q + ``` + + (If you used `sudo -i -u postgres` and then `psql`, remember you might also + need to type `exit` after `\q` to leave the `postgres` user's shell + session.) + +1. Connect to your database with your new user: + + ```bash + psql -h 127.0.0.1 -U toolbox_user -d toolbox_db + ``` + +1. Create the required tables using the following commands: + + ```sql + CREATE TABLE users ( + id SERIAL PRIMARY KEY, + username VARCHAR(50) NOT NULL, + email VARCHAR(100) UNIQUE NOT NULL, + created_at TIMESTAMPTZ DEFAULT NOW() + ); + + CREATE TABLE restaurants ( + id SERIAL PRIMARY KEY, + name VARCHAR(100) NOT NULL, + location VARCHAR(100) + ); + + CREATE TABLE reviews ( + id SERIAL PRIMARY KEY, + user_id INT REFERENCES users(id), + restaurant_id INT REFERENCES restaurants(id), + rating INT CHECK (rating >= 1 AND rating <= 5), + review_text TEXT, + is_published BOOLEAN DEFAULT false, + moderation_status VARCHAR(50) DEFAULT 'pending_manual_review', + created_at TIMESTAMPTZ DEFAULT NOW() + ); + ``` + +1. Insert dummy data into the tables. + + ```sql + INSERT INTO users (id, username, email) VALUES + (123, 'jane_d', 'jane.d@example.com'), + (124, 'john_s', 'john.s@example.com'), + (125, 'sam_b', 'sam.b@example.com'); + + INSERT INTO restaurants (id, name, location) VALUES + (455, 'Pizza Palace', '123 Main St'), + (456, 'The Corner Bistro', '456 Oak Ave'), + (457, 'Sushi Spot', '789 Pine Ln'); + + INSERT INTO reviews (user_id, restaurant_id, rating, review_text, is_published, moderation_status) VALUES + (124, 455, 5, 'Best pizza in town! The crust was perfect.', true, 'approved'), + (125, 457, 4, 'Great sushi, very fresh. A bit pricey but worth it.', true, 'approved'), + (123, 457, 5, 'Absolutely loved the dragon roll. Will be back!', true, 'approved'), + (123, 456, 4, 'The atmosphere was lovely and the food was great. My photo upload might have been weird though.', false, 'pending_manual_review'), + (125, 456, 1, 'This review contains inappropriate language.', false, 'rejected'); + ``` + +1. End the database session: + + ```bash + \q + ``` + +## Step 2: Configure Toolbox + +Create a file named `tools.yaml`. 
This file defines the database connection, the +SQL tools available, and the prompts the agents will use. + +```yaml +sources: + my-foodiefind-db: + kind: postgres + host: 127.0.0.1 + port: 5432 + database: toolbox_db + user: toolbox_user + password: my-password +tools: + find_user_by_email: + kind: postgres-sql + source: my-foodiefind-db + description: Find a user's ID by their email address. + parameters: + - name: email + type: string + description: The email address of the user to find. + statement: SELECT id FROM users WHERE email = $1; + find_restaurant_by_name: + kind: postgres-sql + source: my-foodiefind-db + description: Find a restaurant's ID by its exact name. + parameters: + - name: name + type: string + description: The name of the restaurant to find. + statement: SELECT id FROM restaurants WHERE name = $1; + find_review_by_user_and_restaurant: + kind: postgres-sql + source: my-foodiefind-db + description: Find the full record for a specific review using the user's ID and the restaurant's ID. + parameters: + - name: user_id + type: integer + description: The numerical ID of the user. + - name: restaurant_id + type: integer + description: The numerical ID of the restaurant. + statement: SELECT * FROM reviews WHERE user_id = $1 AND restaurant_id = $2; +prompts: + investigate_missing_review: + description: "Investigates a user's missing review by finding the user, restaurant, and the review itself, then analyzing its status." + arguments: + - name: "user_email" + description: "The email of the user who wrote the review." + - name: "restaurant_name" + description: "The name of the restaurant being reviewed." + messages: + - content: >- + **Goal:** Find the review written by the user with email '{{.user_email}}' for the restaurant named '{{.restaurant_name}}' and understand its status. + **Workflow:** + 1. Use the `find_user_by_email` tool with the email '{{.user_email}}' to get the `user_id`. + 2. Use the `find_restaurant_by_name` tool with the name '{{.restaurant_name}}' to get the `restaurant_id`. + 3. Use the `find_review_by_user_and_restaurant` tool with the `user_id` and `restaurant_id` you just found. + 4. Analyze the results from the final tool call. Examine the `is_published` and `moderation_status` fields and explain the review's status to the user in a clear, human-readable sentence. +``` + +## Step 3: Connect to Gemini CLI + +Configure the Gemini CLI to talk to your local Toolbox MCP server. + +1. Open or create your Gemini settings file: `~/.gemini/settings.json`. +2. Add the following configuration to the file: + + ```json + { + "mcpServers": { + "MCPToolbox": { + "httpUrl": "http://localhost:5000/mcp" + } + }, + "mcp": { + "allowed": ["MCPToolbox"] + } + } + ``` +3. Start Gemini CLI using + ```sh + gemini + ``` + In case Gemini CLI is already running, use `/mcp refresh` to refresh the MCP server. + +4. 
Use gemini slash commands to run your prompt: + ```sh + /investigate_missing_review --user_email="jane.d@example.com" --restaurant_name="The Corner Bistro" + ``` diff --git a/docs/en/getting-started/quickstart/go/adkgo/go.mod b/docs/en/getting-started/quickstart/go/adkgo/go.mod index 0a76831b55..84bf3dad72 100644 --- a/docs/en/getting-started/quickstart/go/adkgo/go.mod +++ b/docs/en/getting-started/quickstart/go/adkgo/go.mod @@ -5,7 +5,7 @@ go 1.24.4 require ( github.com/googleapis/mcp-toolbox-sdk-go v0.4.0 google.golang.org/adk v0.1.0 - google.golang.org/genai v1.35.0 + google.golang.org/genai v1.36.0 ) require ( diff --git a/docs/en/getting-started/quickstart/go/adkgo/go.sum b/docs/en/getting-started/quickstart/go/adkgo/go.sum index a433a8c38f..02284fbc2f 100644 --- a/docs/en/getting-started/quickstart/go/adkgo/go.sum +++ b/docs/en/getting-started/quickstart/go/adkgo/go.sum @@ -108,8 +108,8 @@ google.golang.org/adk v0.1.0 h1:+w/fHuqRVolotOATlujRA+2DKUuDrFH2poRdEX2QjB8= google.golang.org/adk v0.1.0/go.mod h1:NvtSLoNx7UzZIiUAI1KoJQLMmt9sG3oCgiCx1TLqKFw= google.golang.org/api v0.255.0 h1:OaF+IbRwOottVCYV2wZan7KUq7UeNUQn1BcPc4K7lE4= google.golang.org/api v0.255.0/go.mod h1:d1/EtvCLdtiWEV4rAEHDHGh2bCnqsWhw+M8y2ECN4a8= -google.golang.org/genai v1.35.0 h1:Jo6g25CzVqFzGrX5mhWyBgQqXAUzxcx5jeK7U74zv9c= -google.golang.org/genai v1.35.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk= +google.golang.org/genai v1.36.0 h1:sJCIjqTAmwrtAIaemtTiKkg2TO1RxnYEusTmEQ3nGxM= +google.golang.org/genai v1.36.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk= google.golang.org/genproto v0.0.0-20251014184007-4626949a642f h1:vLd1CJuJOUgV6qijD7KT5Y2ZtC97ll4dxjTUappMnbo= google.golang.org/genproto v0.0.0-20251014184007-4626949a642f/go.mod h1:PI3KrSadr00yqfv6UDvgZGFsmLqeRIwt8x4p5Oo7CdM= google.golang.org/genproto/googleapis/api v0.0.0-20251014184007-4626949a642f h1:OiFuztEyBivVKDvguQJYWq1yDcfAHIID/FVrPR4oiI0= diff --git a/docs/en/getting-started/quickstart/go/genAI/go.mod b/docs/en/getting-started/quickstart/go/genAI/go.mod index d7580f9431..5e7b480698 100644 --- a/docs/en/getting-started/quickstart/go/genAI/go.mod +++ b/docs/en/getting-started/quickstart/go/genAI/go.mod @@ -4,7 +4,7 @@ go 1.24.6 require ( github.com/googleapis/mcp-toolbox-sdk-go v0.4.0 - google.golang.org/genai v1.35.0 + google.golang.org/genai v1.36.0 ) require ( diff --git a/docs/en/getting-started/quickstart/go/genAI/go.sum b/docs/en/getting-started/quickstart/go/genAI/go.sum index 380d08fac2..f050c83b70 100644 --- a/docs/en/getting-started/quickstart/go/genAI/go.sum +++ b/docs/en/getting-started/quickstart/go/genAI/go.sum @@ -102,8 +102,8 @@ gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= google.golang.org/api v0.255.0 h1:OaF+IbRwOottVCYV2wZan7KUq7UeNUQn1BcPc4K7lE4= google.golang.org/api v0.255.0/go.mod h1:d1/EtvCLdtiWEV4rAEHDHGh2bCnqsWhw+M8y2ECN4a8= -google.golang.org/genai v1.35.0 h1:Jo6g25CzVqFzGrX5mhWyBgQqXAUzxcx5jeK7U74zv9c= -google.golang.org/genai v1.35.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk= +google.golang.org/genai v1.36.0 h1:sJCIjqTAmwrtAIaemtTiKkg2TO1RxnYEusTmEQ3nGxM= +google.golang.org/genai v1.36.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk= google.golang.org/genproto v0.0.0-20251014184007-4626949a642f h1:vLd1CJuJOUgV6qijD7KT5Y2ZtC97ll4dxjTUappMnbo= google.golang.org/genproto v0.0.0-20251014184007-4626949a642f/go.mod h1:PI3KrSadr00yqfv6UDvgZGFsmLqeRIwt8x4p5Oo7CdM= 
google.golang.org/genproto/googleapis/api v0.0.0-20251014184007-4626949a642f h1:OiFuztEyBivVKDvguQJYWq1yDcfAHIID/FVrPR4oiI0= diff --git a/docs/en/getting-started/quickstart/go/genkit/go.mod b/docs/en/getting-started/quickstart/go/genkit/go.mod index 0e323f53ad..41300f5f89 100644 --- a/docs/en/getting-started/quickstart/go/genkit/go.mod +++ b/docs/en/getting-started/quickstart/go/genkit/go.mod @@ -39,11 +39,11 @@ require ( go.opentelemetry.io/otel/metric v1.38.0 // indirect go.opentelemetry.io/otel/sdk v1.38.0 // indirect go.opentelemetry.io/otel/trace v1.38.0 // indirect - golang.org/x/crypto v0.43.0 // indirect - golang.org/x/net v0.46.0 // indirect + golang.org/x/crypto v0.45.0 // indirect + golang.org/x/net v0.47.0 // indirect golang.org/x/oauth2 v0.32.0 // indirect - golang.org/x/sys v0.37.0 // indirect - golang.org/x/text v0.30.0 // indirect + golang.org/x/sys v0.38.0 // indirect + golang.org/x/text v0.31.0 // indirect google.golang.org/api v0.255.0 // indirect google.golang.org/genai v1.34.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20251029180050-ab9386a59fda // indirect diff --git a/docs/en/getting-started/quickstart/go/genkit/go.sum b/docs/en/getting-started/quickstart/go/genkit/go.sum index 4acd085d92..affe1b3a85 100644 --- a/docs/en/getting-started/quickstart/go/genkit/go.sum +++ b/docs/en/getting-started/quickstart/go/genkit/go.sum @@ -123,18 +123,18 @@ go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJr go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= -golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04= -golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0= -golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4= -golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210= +golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= +golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= +golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= +golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= golang.org/x/oauth2 v0.32.0 h1:jsCblLleRMDrxMN29H3z/k1KliIvpLgCkE6R8FXXNgY= golang.org/x/oauth2 v0.32.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= -golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= -golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= -golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ= -golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k= -golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM= +golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= +golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= +golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= +golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= 
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= diff --git a/docs/en/getting-started/quickstart/go/langchain/go.mod b/docs/en/getting-started/quickstart/go/langchain/go.mod index d942e172d8..d3e8823b40 100644 --- a/docs/en/getting-started/quickstart/go/langchain/go.mod +++ b/docs/en/getting-started/quickstart/go/langchain/go.mod @@ -33,12 +33,12 @@ require ( go.opentelemetry.io/otel v1.38.0 // indirect go.opentelemetry.io/otel/metric v1.38.0 // indirect go.opentelemetry.io/otel/trace v1.38.0 // indirect - golang.org/x/crypto v0.43.0 // indirect - golang.org/x/net v0.46.0 // indirect + golang.org/x/crypto v0.45.0 // indirect + golang.org/x/net v0.47.0 // indirect golang.org/x/oauth2 v0.32.0 // indirect - golang.org/x/sync v0.17.0 // indirect - golang.org/x/sys v0.37.0 // indirect - golang.org/x/text v0.30.0 // indirect + golang.org/x/sync v0.18.0 // indirect + golang.org/x/sys v0.38.0 // indirect + golang.org/x/text v0.31.0 // indirect golang.org/x/time v0.14.0 // indirect google.golang.org/api v0.255.0 // indirect google.golang.org/genproto v0.0.0-20251014184007-4626949a642f // indirect diff --git a/docs/en/getting-started/quickstart/go/langchain/go.sum b/docs/en/getting-started/quickstart/go/langchain/go.sum index 9b9bd2922f..0d48f39d2a 100644 --- a/docs/en/getting-started/quickstart/go/langchain/go.sum +++ b/docs/en/getting-started/quickstart/go/langchain/go.sum @@ -100,18 +100,18 @@ go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6 go.opentelemetry.io/otel/sdk/metric v1.38.0/go.mod h1:dg9PBnW9XdQ1Hd6ZnRz689CbtrUp0wMMs9iPcgT9EZA= go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE= go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs= -golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04= -golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0= -golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4= -golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210= +golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= +golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= +golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= +golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= golang.org/x/oauth2 v0.32.0 h1:jsCblLleRMDrxMN29H3z/k1KliIvpLgCkE6R8FXXNgY= golang.org/x/oauth2 v0.32.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= -golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= -golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= -golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ= -golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k= -golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM= +golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= +golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= +golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= +golang.org/x/text 
v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= diff --git a/docs/en/getting-started/quickstart/js/adk/package-lock.json b/docs/en/getting-started/quickstart/js/adk/package-lock.json new file mode 100644 index 0000000000..d921ceb367 --- /dev/null +++ b/docs/en/getting-started/quickstart/js/adk/package-lock.json @@ -0,0 +1,2590 @@ +{ + "name": "adk", + "version": "1.0.0", + "lockfileVersion": 3, + "requires": true, + "packages": { + "": { + "name": "adk", + "version": "1.0.0", + "license": "ISC", + "dependencies": { + "@google/adk": "^0.1.3", + "@toolbox-sdk/adk": "^0.1.5" + } + }, + "node_modules/@google-cloud/paginator": { + "version": "5.0.2", + "resolved": "https://registry.npmjs.org/@google-cloud/paginator/-/paginator-5.0.2.tgz", + "integrity": "sha512-DJS3s0OVH4zFDB1PzjxAsHqJT6sKVbRwwML0ZBP9PbU7Yebtu/7SWMRzvO2J3nUi9pRNITCfu4LJeooM2w4pjg==", + "license": "Apache-2.0", + "peer": true, + "dependencies": { + "arrify": "^2.0.0", + "extend": "^3.0.2" + }, + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/@google-cloud/projectify": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/@google-cloud/projectify/-/projectify-4.0.0.tgz", + "integrity": "sha512-MmaX6HeSvyPbWGwFq7mXdo0uQZLGBYCwziiLIGq5JVX+/bdI3SAq6bP98trV5eTWfLuvsMcIC1YJOF2vfteLFA==", + "license": "Apache-2.0", + "peer": true, + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/@google-cloud/promisify": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/@google-cloud/promisify/-/promisify-4.0.0.tgz", + "integrity": "sha512-Orxzlfb9c67A15cq2JQEyVc7wEsmFBmHjZWZYQMUyJ1qivXyMwdyNOs9odi79hze+2zqdTtu1E19IM/FtqZ10g==", + "license": "Apache-2.0", + "peer": true, + "engines": { + "node": ">=14" + } + }, + "node_modules/@google-cloud/storage": { + "version": "7.18.0", + "resolved": "https://registry.npmjs.org/@google-cloud/storage/-/storage-7.18.0.tgz", + "integrity": "sha512-r3ZwDMiz4nwW6R922Z1pwpePxyRwE5GdevYX63hRmAQUkUQJcBH/79EnQPDv5cOv1mFBgevdNWQfi3tie3dHrQ==", + "license": "Apache-2.0", + "peer": true, + "dependencies": { + "@google-cloud/paginator": "^5.0.0", + "@google-cloud/projectify": "^4.0.0", + "@google-cloud/promisify": "<4.1.0", + "abort-controller": "^3.0.0", + "async-retry": "^1.3.3", + "duplexify": "^4.1.3", + "fast-xml-parser": "^4.4.1", + "gaxios": "^6.0.2", + "google-auth-library": "^9.6.3", + "html-entities": "^2.5.2", + "mime": "^3.0.0", + "p-limit": "^3.0.1", + "retry-request": "^7.0.0", + "teeny-request": "^9.0.0", + "uuid": "^8.0.0" + }, + "engines": { + "node": ">=14" + } + }, + "node_modules/@google-cloud/storage/node_modules/uuid": { + "version": "8.3.2", + "resolved": "https://registry.npmjs.org/uuid/-/uuid-8.3.2.tgz", + "integrity": "sha512-+NYs2QeMWy+GWFOEm9xnn6HCDp0l7QBD7ml8zLUmJ+93Q5NF0NocErnwkTkXVFNiX3/fpC6afS8Dhb/gz7R7eg==", + "license": "MIT", + "peer": true, + "bin": { + "uuid": "dist/bin/uuid" + } + }, + "node_modules/@google/adk": { + "version": "0.1.3", + "resolved": "https://registry.npmjs.org/@google/adk/-/adk-0.1.3.tgz", + "integrity": "sha512-kkI/u9ofgdblBW9uE9B3YruZca1QR4uvsD1vQrDmUlTLkpur39sgjDvIuIqV6zPrzzbS6wBVXZ8BKNTANzMAiw==", + "license": "Apache-2.0", + "peerDependencies": { + "@google-cloud/storage": "^7.17.1", + "@google/genai": "1.14.0", + "@modelcontextprotocol/sdk": "1.17.5", + "@opentelemetry/api": 
"1.9.0", + "zod": "3.25.76" + } + }, + "node_modules/@google/genai": { + "version": "1.14.0", + "resolved": "https://registry.npmjs.org/@google/genai/-/genai-1.14.0.tgz", + "integrity": "sha512-jirYprAAJU1svjwSDVCzyVq+FrJpJd5CSxR/g2Ga/gZ0ZYZpcWjMS75KJl9y71K1mDN+tcx6s21CzCbB2R840g==", + "license": "Apache-2.0", + "dependencies": { + "google-auth-library": "^9.14.2", + "ws": "^8.18.0" + }, + "engines": { + "node": ">=20.0.0" + }, + "peerDependencies": { + "@modelcontextprotocol/sdk": "^1.11.0" + }, + "peerDependenciesMeta": { + "@modelcontextprotocol/sdk": { + "optional": true + } + } + }, + "node_modules/@isaacs/cliui": { + "version": "8.0.2", + "resolved": "https://registry.npmjs.org/@isaacs/cliui/-/cliui-8.0.2.tgz", + "integrity": "sha512-O8jcjabXaleOG9DQ0+ARXWZBTfnP4WNAqzuiJK7ll44AmxGKv/J2M4TPjxjY3znBCfvBXFzucm1twdyFybFqEA==", + "license": "ISC", + "dependencies": { + "string-width": "^5.1.2", + "string-width-cjs": "npm:string-width@^4.2.0", + "strip-ansi": "^7.0.1", + "strip-ansi-cjs": "npm:strip-ansi@^6.0.1", + "wrap-ansi": "^8.1.0", + "wrap-ansi-cjs": "npm:wrap-ansi@^7.0.0" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/@modelcontextprotocol/sdk": { + "version": "1.17.5", + "resolved": "https://registry.npmjs.org/@modelcontextprotocol/sdk/-/sdk-1.17.5.tgz", + "integrity": "sha512-QakrKIGniGuRVfWBdMsDea/dx1PNE739QJ7gCM41s9q+qaCYTHCdsIBXQVVXry3mfWAiaM9kT22Hyz53Uw8mfg==", + "license": "MIT", + "dependencies": { + "ajv": "^6.12.6", + "content-type": "^1.0.5", + "cors": "^2.8.5", + "cross-spawn": "^7.0.5", + "eventsource": "^3.0.2", + "eventsource-parser": "^3.0.0", + "express": "^5.0.1", + "express-rate-limit": "^7.5.0", + "pkce-challenge": "^5.0.0", + "raw-body": "^3.0.0", + "zod": "^3.23.8", + "zod-to-json-schema": "^3.24.1" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/@opentelemetry/api": { + "version": "1.9.0", + "resolved": "https://registry.npmjs.org/@opentelemetry/api/-/api-1.9.0.tgz", + "integrity": "sha512-3giAOQvZiH5F9bMlMiv8+GSPMeqg0dbaeo58/0SlA9sxSqZhnUtxzX9/2FzyhS9sWQf5S0GJE0AKBrFqjpeYcg==", + "license": "Apache-2.0", + "peer": true, + "engines": { + "node": ">=8.0.0" + } + }, + "node_modules/@pkgjs/parseargs": { + "version": "0.11.0", + "resolved": "https://registry.npmjs.org/@pkgjs/parseargs/-/parseargs-0.11.0.tgz", + "integrity": "sha512-+1VkjdD0QBLPodGrJUeqarH8VAIvQODIbwh9XpP5Syisf7YoQgsJKPNFoqqLQlu+VQ/tVSshMR6loPMn8U+dPg==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">=14" + } + }, + "node_modules/@toolbox-sdk/adk": { + "version": "0.1.5", + "resolved": "https://registry.npmjs.org/@toolbox-sdk/adk/-/adk-0.1.5.tgz", + "integrity": "sha512-kvaBWyW1NEax6LlU6TBhx0qECw0rSTH1NtSPhkUgY9Q/YAi2S6gnkS/++zlPL+hdfGpDZOHeCEesjiAepdBA0g==", + "license": "Apache-2.0", + "dependencies": { + "@google/adk": "^0.1.2", + "@google/genai": "^1.14.0", + "@modelcontextprotocol/sdk": "1.17.5", + "@toolbox-sdk/core": "^0.1.2", + "axios": "^1.12.2", + "openapi-types": "^12.1.3", + "zod": "^3.24.4" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@toolbox-sdk/core": { + "version": "0.1.2", + "resolved": "https://registry.npmjs.org/@toolbox-sdk/core/-/core-0.1.2.tgz", + "integrity": "sha512-AKfejUi0KK5W1L5AKwfbYDk9GNU92tRYIOIGvDI93XoZU3IWsPHsPHOoE3DFNpY3HvAUdaCZHj7V4GW0kq7dyg==", + "license": "Apache-2.0", + "dependencies": { + "axios": "^1.9.0", + "google-auth-library": "^10.0.0", + "zod": "^3.24.4" + }, + "engines": { + "node": ">=20.0.0" + } + }, + "node_modules/@toolbox-sdk/core/node_modules/gaxios": { + "version": "7.1.3", 
+ "resolved": "https://registry.npmjs.org/gaxios/-/gaxios-7.1.3.tgz", + "integrity": "sha512-YGGyuEdVIjqxkxVH1pUTMY/XtmmsApXrCVv5EU25iX6inEPbV+VakJfLealkBtJN69AQmh1eGOdCl9Sm1UP6XQ==", + "license": "Apache-2.0", + "dependencies": { + "extend": "^3.0.2", + "https-proxy-agent": "^7.0.1", + "node-fetch": "^3.3.2", + "rimraf": "^5.0.1" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/@toolbox-sdk/core/node_modules/gcp-metadata": { + "version": "8.1.2", + "resolved": "https://registry.npmjs.org/gcp-metadata/-/gcp-metadata-8.1.2.tgz", + "integrity": "sha512-zV/5HKTfCeKWnxG0Dmrw51hEWFGfcF2xiXqcA3+J90WDuP0SvoiSO5ORvcBsifmx/FoIjgQN3oNOGaQ5PhLFkg==", + "license": "Apache-2.0", + "dependencies": { + "gaxios": "^7.0.0", + "google-logging-utils": "^1.0.0", + "json-bigint": "^1.0.0" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/@toolbox-sdk/core/node_modules/google-auth-library": { + "version": "10.5.0", + "resolved": "https://registry.npmjs.org/google-auth-library/-/google-auth-library-10.5.0.tgz", + "integrity": "sha512-7ABviyMOlX5hIVD60YOfHw4/CxOfBhyduaYB+wbFWCWoni4N7SLcV46hrVRktuBbZjFC9ONyqamZITN7q3n32w==", + "license": "Apache-2.0", + "dependencies": { + "base64-js": "^1.3.0", + "ecdsa-sig-formatter": "^1.0.11", + "gaxios": "^7.0.0", + "gcp-metadata": "^8.0.0", + "google-logging-utils": "^1.0.0", + "gtoken": "^8.0.0", + "jws": "^4.0.0" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/@toolbox-sdk/core/node_modules/google-logging-utils": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/google-logging-utils/-/google-logging-utils-1.1.2.tgz", + "integrity": "sha512-YsFPGVgDFf4IzSwbwIR0iaFJQFmR5Jp7V1WuYSjuRgAm9yWqsMhKE9YPlL+wvFLnc/wMiFV4SQUD9Y/JMpxIxQ==", + "license": "Apache-2.0", + "engines": { + "node": ">=14" + } + }, + "node_modules/@toolbox-sdk/core/node_modules/gtoken": { + "version": "8.0.0", + "resolved": "https://registry.npmjs.org/gtoken/-/gtoken-8.0.0.tgz", + "integrity": "sha512-+CqsMbHPiSTdtSO14O51eMNlrp9N79gmeqmXeouJOhfucAedHw9noVe/n5uJk3tbKE6a+6ZCQg3RPhVhHByAIw==", + "license": "MIT", + "dependencies": { + "gaxios": "^7.0.0", + "jws": "^4.0.0" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/@toolbox-sdk/core/node_modules/node-fetch": { + "version": "3.3.2", + "resolved": "https://registry.npmjs.org/node-fetch/-/node-fetch-3.3.2.tgz", + "integrity": "sha512-dRB78srN/l6gqWulah9SrxeYnxeddIG30+GOqK/9OlLVyLg3HPnr6SqOWTWOXKRwC2eGYCkZ59NNuSgvSrpgOA==", + "license": "MIT", + "dependencies": { + "data-uri-to-buffer": "^4.0.0", + "fetch-blob": "^3.1.4", + "formdata-polyfill": "^4.0.10" + }, + "engines": { + "node": "^12.20.0 || ^14.13.1 || >=16.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/node-fetch" + } + }, + "node_modules/@tootallnate/once": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/@tootallnate/once/-/once-2.0.0.tgz", + "integrity": "sha512-XCuKFP5PS55gnMVu3dty8KPatLqUoy/ZYzDzAGCQ8JNFCkLXzmI7vNHCR+XpbZaMWQK/vQubr7PkYq8g470J/A==", + "license": "MIT", + "peer": true, + "engines": { + "node": ">= 10" + } + }, + "node_modules/@types/caseless": { + "version": "0.12.5", + "resolved": "https://registry.npmjs.org/@types/caseless/-/caseless-0.12.5.tgz", + "integrity": "sha512-hWtVTC2q7hc7xZ/RLbxapMvDMgUnDvKvMOpKal4DrMyfGBUfB1oKaZlIRr6mJL+If3bAP6sV/QneGzF6tJjZDg==", + "license": "MIT", + "peer": true + }, + "node_modules/@types/node": { + "version": "24.10.1", + "resolved": "https://registry.npmjs.org/@types/node/-/node-24.10.1.tgz", + "integrity": 
"sha512-GNWcUTRBgIRJD5zj+Tq0fKOJ5XZajIiBroOF0yvj2bSU1WvNdYS/dn9UxwsujGW4JX06dnHyjV2y9rRaybH0iQ==", + "license": "MIT", + "peer": true, + "dependencies": { + "undici-types": "~7.16.0" + } + }, + "node_modules/@types/request": { + "version": "2.48.13", + "resolved": "https://registry.npmjs.org/@types/request/-/request-2.48.13.tgz", + "integrity": "sha512-FGJ6udDNUCjd19pp0Q3iTiDkwhYup7J8hpMW9c4k53NrccQFFWKRho6hvtPPEhnXWKvukfwAlB6DbDz4yhH5Gg==", + "license": "MIT", + "peer": true, + "dependencies": { + "@types/caseless": "*", + "@types/node": "*", + "@types/tough-cookie": "*", + "form-data": "^2.5.5" + } + }, + "node_modules/@types/request/node_modules/form-data": { + "version": "2.5.5", + "resolved": "https://registry.npmjs.org/form-data/-/form-data-2.5.5.tgz", + "integrity": "sha512-jqdObeR2rxZZbPSGL+3VckHMYtu+f9//KXBsVny6JSX/pa38Fy+bGjuG8eW/H6USNQWhLi8Num++cU2yOCNz4A==", + "license": "MIT", + "peer": true, + "dependencies": { + "asynckit": "^0.4.0", + "combined-stream": "^1.0.8", + "es-set-tostringtag": "^2.1.0", + "hasown": "^2.0.2", + "mime-types": "^2.1.35", + "safe-buffer": "^5.2.1" + }, + "engines": { + "node": ">= 0.12" + } + }, + "node_modules/@types/request/node_modules/mime-db": { + "version": "1.52.0", + "resolved": "https://registry.npmjs.org/mime-db/-/mime-db-1.52.0.tgz", + "integrity": "sha512-sPU4uV7dYlvtWJxwwxHD0PuihVNiE7TyAbQ5SWxDCB9mUYvOgroQOwYQQOKPJ8CIbE+1ETVlOoK1UC2nU3gYvg==", + "license": "MIT", + "peer": true, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/@types/request/node_modules/mime-types": { + "version": "2.1.35", + "resolved": "https://registry.npmjs.org/mime-types/-/mime-types-2.1.35.tgz", + "integrity": "sha512-ZDY+bPm5zTTF+YpCrAU9nK0UgICYPT0QtT1NZWFv4s++TNkcgVaT0g6+4R2uI4MjQjzysHB1zxuWL50hzaeXiw==", + "license": "MIT", + "peer": true, + "dependencies": { + "mime-db": "1.52.0" + }, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/@types/tough-cookie": { + "version": "4.0.5", + "resolved": "https://registry.npmjs.org/@types/tough-cookie/-/tough-cookie-4.0.5.tgz", + "integrity": "sha512-/Ad8+nIOV7Rl++6f1BdKxFSMgmoqEoYbHRpPcx3JEfv8VRsQe9Z4mCXeJBzxs7mbHY/XOZZuXlRNfhpVPbs6ZA==", + "license": "MIT", + "peer": true + }, + "node_modules/abort-controller": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/abort-controller/-/abort-controller-3.0.0.tgz", + "integrity": "sha512-h8lQ8tacZYnR3vNQTgibj+tODHI5/+l06Au2Pcriv/Gmet0eaj4TwWH41sO9wnHDiQsEj19q0drzdWdeAHtweg==", + "license": "MIT", + "peer": true, + "dependencies": { + "event-target-shim": "^5.0.0" + }, + "engines": { + "node": ">=6.5" + } + }, + "node_modules/accepts": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/accepts/-/accepts-2.0.0.tgz", + "integrity": "sha512-5cvg6CtKwfgdmVqY1WIiXKc3Q1bkRqGLi+2W/6ao+6Y7gu/RCwRuAhGEzh5B4KlszSuTLgZYuqFqo5bImjNKng==", + "license": "MIT", + "dependencies": { + "mime-types": "^3.0.0", + "negotiator": "^1.0.0" + }, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/agent-base": { + "version": "7.1.4", + "resolved": "https://registry.npmjs.org/agent-base/-/agent-base-7.1.4.tgz", + "integrity": "sha512-MnA+YT8fwfJPgBx3m60MNqakm30XOkyIoH1y6huTQvC0PwZG7ki8NacLBcrPbNoo8vEZy7Jpuk7+jMO+CUovTQ==", + "license": "MIT", + "engines": { + "node": ">= 14" + } + }, + "node_modules/ajv": { + "version": "6.12.6", + "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz", + "integrity": "sha512-j3fVLgvTo527anyYyJOGTYJbG+vnnQYvE0m5mmkc1TK+nxAppkCLMIL0aZ4dblVCNoGShhm+kzE4ZUykBoMg4g==", + "license": "MIT", + "dependencies": { + 
"fast-deep-equal": "^3.1.1", + "fast-json-stable-stringify": "^2.0.0", + "json-schema-traverse": "^0.4.1", + "uri-js": "^4.2.2" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/epoberezkin" + } + }, + "node_modules/ansi-regex": { + "version": "6.2.2", + "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-6.2.2.tgz", + "integrity": "sha512-Bq3SmSpyFHaWjPk8If9yc6svM8c56dB5BAtW4Qbw5jHTwwXXcTLoRMkpDJp6VL0XzlWaCHTXrkFURMYmD0sLqg==", + "license": "MIT", + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/chalk/ansi-regex?sponsor=1" + } + }, + "node_modules/ansi-styles": { + "version": "6.2.3", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-6.2.3.tgz", + "integrity": "sha512-4Dj6M28JB+oAH8kFkTLUo+a2jwOFkuqb3yucU0CANcRRUbxS0cP0nZYCGjcc3BNXwRIsUVmDGgzawme7zvJHvg==", + "license": "MIT", + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/chalk/ansi-styles?sponsor=1" + } + }, + "node_modules/arrify": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/arrify/-/arrify-2.0.1.tgz", + "integrity": "sha512-3duEwti880xqi4eAMN8AyR4a0ByT90zoYdLlevfrvU43vb0YZwZVfxOgxWrLXXXpyugL0hNZc9G6BiB5B3nUug==", + "license": "MIT", + "peer": true, + "engines": { + "node": ">=8" + } + }, + "node_modules/async-retry": { + "version": "1.3.3", + "resolved": "https://registry.npmjs.org/async-retry/-/async-retry-1.3.3.tgz", + "integrity": "sha512-wfr/jstw9xNi/0teMHrRW7dsz3Lt5ARhYNZ2ewpadnhaIp5mbALhOAP+EAdsC7t4Z6wqsDVv9+W6gm1Dk9mEyw==", + "license": "MIT", + "peer": true, + "dependencies": { + "retry": "0.13.1" + } + }, + "node_modules/asynckit": { + "version": "0.4.0", + "resolved": "https://registry.npmjs.org/asynckit/-/asynckit-0.4.0.tgz", + "integrity": "sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q==", + "license": "MIT" + }, + "node_modules/axios": { + "version": "1.13.2", + "resolved": "https://registry.npmjs.org/axios/-/axios-1.13.2.tgz", + "integrity": "sha512-VPk9ebNqPcy5lRGuSlKx752IlDatOjT9paPlm8A7yOuW2Fbvp4X3JznJtT4f0GzGLLiWE9W8onz51SqLYwzGaA==", + "license": "MIT", + "dependencies": { + "follow-redirects": "^1.15.6", + "form-data": "^4.0.4", + "proxy-from-env": "^1.1.0" + } + }, + "node_modules/balanced-match": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-1.0.2.tgz", + "integrity": "sha512-3oSeUO0TMV67hN1AmbXsK4yaqU7tjiHlbxRDZOpH0KW9+CeX4bRAaX0Anxt0tx2MrpRpWwQaPwIlISEJhYU5Pw==", + "license": "MIT" + }, + "node_modules/base64-js": { + "version": "1.5.1", + "resolved": "https://registry.npmjs.org/base64-js/-/base64-js-1.5.1.tgz", + "integrity": "sha512-AKpaYlHn8t4SVbOHCy+b5+KKgvR4vrsD8vbvrbiQJps7fKDTkjkDry6ji0rUJjC0kzbNePLwzxq8iypo41qeWA==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ], + "license": "MIT" + }, + "node_modules/bignumber.js": { + "version": "9.3.1", + "resolved": "https://registry.npmjs.org/bignumber.js/-/bignumber.js-9.3.1.tgz", + "integrity": "sha512-Ko0uX15oIUS7wJ3Rb30Fs6SkVbLmPBAKdlm7q9+ak9bbIeFf0MwuBsQV6z7+X768/cHsfg+WlysDWJcmthjsjQ==", + "license": "MIT", + "engines": { + "node": "*" + } + }, + "node_modules/body-parser": { + "version": "2.2.1", + "resolved": "https://registry.npmjs.org/body-parser/-/body-parser-2.2.1.tgz", + "integrity": 
"sha512-nfDwkulwiZYQIGwxdy0RUmowMhKcFVcYXUU7m4QlKYim1rUtg83xm2yjZ40QjDuc291AJjjeSc9b++AWHSgSHw==", + "license": "MIT", + "dependencies": { + "bytes": "^3.1.2", + "content-type": "^1.0.5", + "debug": "^4.4.3", + "http-errors": "^2.0.0", + "iconv-lite": "^0.7.0", + "on-finished": "^2.4.1", + "qs": "^6.14.0", + "raw-body": "^3.0.1", + "type-is": "^2.0.1" + }, + "engines": { + "node": ">=18" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/express" + } + }, + "node_modules/brace-expansion": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz", + "integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==", + "license": "MIT", + "dependencies": { + "balanced-match": "^1.0.0" + } + }, + "node_modules/buffer-equal-constant-time": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/buffer-equal-constant-time/-/buffer-equal-constant-time-1.0.1.tgz", + "integrity": "sha512-zRpUiDwd/xk6ADqPMATG8vc9VPrkck7T07OIx0gnjmJAnHnTVXNQG3vfvWNuiZIkwu9KrKdA1iJKfsfTVxE6NA==", + "license": "BSD-3-Clause" + }, + "node_modules/bytes": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/bytes/-/bytes-3.1.2.tgz", + "integrity": "sha512-/Nf7TyzTx6S3yRJObOAV7956r8cr2+Oj8AC5dt8wSP3BQAoeX58NoHyCU8P8zGkNXStjTSi6fzO6F0pBdcYbEg==", + "license": "MIT", + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/call-bind-apply-helpers": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/call-bind-apply-helpers/-/call-bind-apply-helpers-1.0.2.tgz", + "integrity": "sha512-Sp1ablJ0ivDkSzjcaJdxEunN5/XvksFJ2sMBFfq6x0ryhQV/2b/KwFe21cMpmHtPOSij8K99/wSfoEuTObmuMQ==", + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0", + "function-bind": "^1.1.2" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/call-bound": { + "version": "1.0.4", + "resolved": "https://registry.npmjs.org/call-bound/-/call-bound-1.0.4.tgz", + "integrity": "sha512-+ys997U96po4Kx/ABpBCqhA9EuxJaQWDQg7295H4hBphv3IZg0boBKuwYpt4YXp6MZ5AmZQnU/tyMTlRpaSejg==", + "license": "MIT", + "dependencies": { + "call-bind-apply-helpers": "^1.0.2", + "get-intrinsic": "^1.3.0" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/color-convert": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", + "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", + "license": "MIT", + "dependencies": { + "color-name": "~1.1.4" + }, + "engines": { + "node": ">=7.0.0" + } + }, + "node_modules/color-name": { + "version": "1.1.4", + "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", + "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", + "license": "MIT" + }, + "node_modules/combined-stream": { + "version": "1.0.8", + "resolved": "https://registry.npmjs.org/combined-stream/-/combined-stream-1.0.8.tgz", + "integrity": "sha512-FQN4MRfuJeHf7cBbBMJFXhKSDq+2kAArBlmRBvcvFE5BB1HZKXtSFASDhdlz9zOYwxh8lDdnvmMOe/+5cdoEdg==", + "license": "MIT", + "dependencies": { + "delayed-stream": "~1.0.0" + }, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/content-disposition": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/content-disposition/-/content-disposition-1.0.0.tgz", + "integrity": 
"sha512-Au9nRL8VNUut/XSzbQA38+M78dzP4D+eqg3gfJHMIHHYa3bg067xj1KxMUWj+VULbiZMowKngFFbKczUrNJ1mg==", + "license": "MIT", + "dependencies": { + "safe-buffer": "5.2.1" + }, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/content-type": { + "version": "1.0.5", + "resolved": "https://registry.npmjs.org/content-type/-/content-type-1.0.5.tgz", + "integrity": "sha512-nTjqfcBFEipKdXCv4YDQWCfmcLZKm81ldF0pAopTvyrFGVbcR6P/VAAd5G7N+0tTr8QqiU0tFadD6FK4NtJwOA==", + "license": "MIT", + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/cookie": { + "version": "0.7.2", + "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.7.2.tgz", + "integrity": "sha512-yki5XnKuf750l50uGTllt6kKILY4nQ1eNIQatoXEByZ5dWgnKqbnqmTrBE5B4N7lrMJKQ2ytWMiTO2o0v6Ew/w==", + "license": "MIT", + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/cookie-signature": { + "version": "1.2.2", + "resolved": "https://registry.npmjs.org/cookie-signature/-/cookie-signature-1.2.2.tgz", + "integrity": "sha512-D76uU73ulSXrD1UXF4KE2TMxVVwhsnCgfAyTg9k8P6KGZjlXKrOLe4dJQKI3Bxi5wjesZoFXJWElNWBjPZMbhg==", + "license": "MIT", + "engines": { + "node": ">=6.6.0" + } + }, + "node_modules/cors": { + "version": "2.8.5", + "resolved": "https://registry.npmjs.org/cors/-/cors-2.8.5.tgz", + "integrity": "sha512-KIHbLJqu73RGr/hnbrO9uBeixNGuvSQjul/jdFvS/KFSIH1hWVd1ng7zOHx+YrEfInLG7q4n6GHQ9cDtxv/P6g==", + "license": "MIT", + "dependencies": { + "object-assign": "^4", + "vary": "^1" + }, + "engines": { + "node": ">= 0.10" + } + }, + "node_modules/cross-spawn": { + "version": "7.0.6", + "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.6.tgz", + "integrity": "sha512-uV2QOWP2nWzsy2aMp8aRibhi9dlzF5Hgh5SHaB9OiTGEyDTiJJyx0uy51QXdyWbtAHNua4XJzUKca3OzKUd3vA==", + "license": "MIT", + "dependencies": { + "path-key": "^3.1.0", + "shebang-command": "^2.0.0", + "which": "^2.0.1" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/data-uri-to-buffer": { + "version": "4.0.1", + "resolved": "https://registry.npmjs.org/data-uri-to-buffer/-/data-uri-to-buffer-4.0.1.tgz", + "integrity": "sha512-0R9ikRb668HB7QDxT1vkpuUBtqc53YyAwMwGeUFKRojY/NWKvdZ+9UYtRfGmhqNbRkTSVpMbmyhXipFFv2cb/A==", + "license": "MIT", + "engines": { + "node": ">= 12" + } + }, + "node_modules/debug": { + "version": "4.4.3", + "resolved": "https://registry.npmjs.org/debug/-/debug-4.4.3.tgz", + "integrity": "sha512-RGwwWnwQvkVfavKVt22FGLw+xYSdzARwm0ru6DhTVA3umU5hZc28V3kO4stgYryrTlLpuvgI9GiijltAjNbcqA==", + "license": "MIT", + "dependencies": { + "ms": "^2.1.3" + }, + "engines": { + "node": ">=6.0" + }, + "peerDependenciesMeta": { + "supports-color": { + "optional": true + } + } + }, + "node_modules/delayed-stream": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/delayed-stream/-/delayed-stream-1.0.0.tgz", + "integrity": "sha512-ZySD7Nf91aLB0RxL4KGrKHBXl7Eds1DAmEdcoVawXnLD7SDhpNgtuII2aAkg7a7QS41jxPSZ17p4VdGnMHk3MQ==", + "license": "MIT", + "engines": { + "node": ">=0.4.0" + } + }, + "node_modules/depd": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/depd/-/depd-2.0.0.tgz", + "integrity": "sha512-g7nH6P6dyDioJogAAGprGpCtVImJhpPk/roCzdb3fIh61/s/nPsfR6onyMwkCAR/OlC3yBC0lESvUoQEAssIrw==", + "license": "MIT", + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/dunder-proto": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/dunder-proto/-/dunder-proto-1.0.1.tgz", + "integrity": "sha512-KIN/nDJBQRcXw0MLVhZE9iQHmG68qAVIBg9CqmUYjmQIhgij9U5MFvrqkUL5FbtyyzZuOeOt0zdeRe4UY7ct+A==", + "license": "MIT", + 
"dependencies": { + "call-bind-apply-helpers": "^1.0.1", + "es-errors": "^1.3.0", + "gopd": "^1.2.0" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/duplexify": { + "version": "4.1.3", + "resolved": "https://registry.npmjs.org/duplexify/-/duplexify-4.1.3.tgz", + "integrity": "sha512-M3BmBhwJRZsSx38lZyhE53Csddgzl5R7xGJNk7CVddZD6CcmwMCH8J+7AprIrQKH7TonKxaCjcv27Qmf+sQ+oA==", + "license": "MIT", + "peer": true, + "dependencies": { + "end-of-stream": "^1.4.1", + "inherits": "^2.0.3", + "readable-stream": "^3.1.1", + "stream-shift": "^1.0.2" + } + }, + "node_modules/eastasianwidth": { + "version": "0.2.0", + "resolved": "https://registry.npmjs.org/eastasianwidth/-/eastasianwidth-0.2.0.tgz", + "integrity": "sha512-I88TYZWc9XiYHRQ4/3c5rjjfgkjhLyW2luGIheGERbNQ6OY7yTybanSpDXZa8y7VUP9YmDcYa+eyq4ca7iLqWA==", + "license": "MIT" + }, + "node_modules/ecdsa-sig-formatter": { + "version": "1.0.11", + "resolved": "https://registry.npmjs.org/ecdsa-sig-formatter/-/ecdsa-sig-formatter-1.0.11.tgz", + "integrity": "sha512-nagl3RYrbNv6kQkeJIpt6NJZy8twLB/2vtz6yN9Z4vRKHN4/QZJIEbqohALSgwKdnksuY3k5Addp5lg8sVoVcQ==", + "license": "Apache-2.0", + "dependencies": { + "safe-buffer": "^5.0.1" + } + }, + "node_modules/ee-first": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/ee-first/-/ee-first-1.1.1.tgz", + "integrity": "sha512-WMwm9LhRUo+WUaRN+vRuETqG89IgZphVSNkdFgeb6sS/E4OrDIN7t48CAewSHXc6C8lefD8KKfr5vY61brQlow==", + "license": "MIT" + }, + "node_modules/emoji-regex": { + "version": "9.2.2", + "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-9.2.2.tgz", + "integrity": "sha512-L18DaJsXSUk2+42pv8mLs5jJT2hqFkFE4j21wOmgbUqsZ2hL72NsUU785g9RXgo3s0ZNgVl42TiHp3ZtOv/Vyg==", + "license": "MIT" + }, + "node_modules/encodeurl": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-2.0.0.tgz", + "integrity": "sha512-Q0n9HRi4m6JuGIV1eFlmvJB7ZEVxu93IrMyiMsGC0lrMJMWzRgx6WGquyfQgZVb31vhGgXnfmPNNXmxnOkRBrg==", + "license": "MIT", + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/end-of-stream": { + "version": "1.4.5", + "resolved": "https://registry.npmjs.org/end-of-stream/-/end-of-stream-1.4.5.tgz", + "integrity": "sha512-ooEGc6HP26xXq/N+GCGOT0JKCLDGrq2bQUZrQ7gyrJiZANJ/8YDTxTpQBXGMn+WbIQXNVpyWymm7KYVICQnyOg==", + "license": "MIT", + "peer": true, + "dependencies": { + "once": "^1.4.0" + } + }, + "node_modules/es-define-property": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/es-define-property/-/es-define-property-1.0.1.tgz", + "integrity": "sha512-e3nRfgfUZ4rNGL232gUgX06QNyyez04KdjFrF+LTRoOXmrOgFKDg4BCdsjW8EnT69eqdYGmRpJwiPVYNrCaW3g==", + "license": "MIT", + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/es-errors": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/es-errors/-/es-errors-1.3.0.tgz", + "integrity": "sha512-Zf5H2Kxt2xjTvbJvP2ZWLEICxA6j+hAmMzIlypy4xcBg1vKVnx89Wy0GbS+kf5cwCVFFzdCFh2XSCFNULS6csw==", + "license": "MIT", + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/es-object-atoms": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/es-object-atoms/-/es-object-atoms-1.1.1.tgz", + "integrity": "sha512-FGgH2h8zKNim9ljj7dankFPcICIK9Cp5bm+c2gQSYePhpaG5+esrLODihIorn+Pe6FGJzWhXQotPv73jTaldXA==", + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/es-set-tostringtag": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/es-set-tostringtag/-/es-set-tostringtag-2.1.0.tgz", + 
"integrity": "sha512-j6vWzfrGVfyXxge+O0x5sh6cvxAog0a/4Rdd2K36zCMV5eJ+/+tOAngRO8cODMNWbVRdVlmGZQL2YS3yR8bIUA==", + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0", + "get-intrinsic": "^1.2.6", + "has-tostringtag": "^1.0.2", + "hasown": "^2.0.2" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/escape-html": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/escape-html/-/escape-html-1.0.3.tgz", + "integrity": "sha512-NiSupZ4OeuGwr68lGIeym/ksIZMJodUGOSCZ/FSnTxcrekbvqrgdUxlJOMpijaKZVjAJrWrGs/6Jy8OMuyj9ow==", + "license": "MIT" + }, + "node_modules/etag": { + "version": "1.8.1", + "resolved": "https://registry.npmjs.org/etag/-/etag-1.8.1.tgz", + "integrity": "sha512-aIL5Fx7mawVa300al2BnEE4iNvo1qETxLrPI/o05L7z6go7fCw1J6EQmbK4FmJ2AS7kgVF/KEZWufBfdClMcPg==", + "license": "MIT", + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/event-target-shim": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/event-target-shim/-/event-target-shim-5.0.1.tgz", + "integrity": "sha512-i/2XbnSz/uxRCU6+NdVJgKWDTM427+MqYbkQzD321DuCQJUqOuJKIA0IM2+W2xtYHdKOmZ4dR6fExsd4SXL+WQ==", + "license": "MIT", + "peer": true, + "engines": { + "node": ">=6" + } + }, + "node_modules/eventsource": { + "version": "3.0.7", + "resolved": "https://registry.npmjs.org/eventsource/-/eventsource-3.0.7.tgz", + "integrity": "sha512-CRT1WTyuQoD771GW56XEZFQ/ZoSfWid1alKGDYMmkt2yl8UXrVR4pspqWNEcqKvVIzg6PAltWjxcSSPrboA4iA==", + "license": "MIT", + "dependencies": { + "eventsource-parser": "^3.0.1" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/eventsource-parser": { + "version": "3.0.6", + "resolved": "https://registry.npmjs.org/eventsource-parser/-/eventsource-parser-3.0.6.tgz", + "integrity": "sha512-Vo1ab+QXPzZ4tCa8SwIHJFaSzy4R6SHf7BY79rFBDf0idraZWAkYrDjDj8uWaSm3S2TK+hJ7/t1CEmZ7jXw+pg==", + "license": "MIT", + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/express": { + "version": "5.1.0", + "resolved": "https://registry.npmjs.org/express/-/express-5.1.0.tgz", + "integrity": "sha512-DT9ck5YIRU+8GYzzU5kT3eHGA5iL+1Zd0EutOmTE9Dtk+Tvuzd23VBU+ec7HPNSTxXYO55gPV/hq4pSBJDjFpA==", + "license": "MIT", + "dependencies": { + "accepts": "^2.0.0", + "body-parser": "^2.2.0", + "content-disposition": "^1.0.0", + "content-type": "^1.0.5", + "cookie": "^0.7.1", + "cookie-signature": "^1.2.1", + "debug": "^4.4.0", + "encodeurl": "^2.0.0", + "escape-html": "^1.0.3", + "etag": "^1.8.1", + "finalhandler": "^2.1.0", + "fresh": "^2.0.0", + "http-errors": "^2.0.0", + "merge-descriptors": "^2.0.0", + "mime-types": "^3.0.0", + "on-finished": "^2.4.1", + "once": "^1.4.0", + "parseurl": "^1.3.3", + "proxy-addr": "^2.0.7", + "qs": "^6.14.0", + "range-parser": "^1.2.1", + "router": "^2.2.0", + "send": "^1.1.0", + "serve-static": "^2.2.0", + "statuses": "^2.0.1", + "type-is": "^2.0.1", + "vary": "^1.1.2" + }, + "engines": { + "node": ">= 18" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/express" + } + }, + "node_modules/express-rate-limit": { + "version": "7.5.1", + "resolved": "https://registry.npmjs.org/express-rate-limit/-/express-rate-limit-7.5.1.tgz", + "integrity": "sha512-7iN8iPMDzOMHPUYllBEsQdWVB6fPDMPqwjBaFrgr4Jgr/+okjvzAy+UHlYYL/Vs0OsOrMkwS6PJDkFlJwoxUnw==", + "license": "MIT", + "engines": { + "node": ">= 16" + }, + "funding": { + "url": "https://github.com/sponsors/express-rate-limit" + }, + "peerDependencies": { + "express": ">= 4.11" + } + }, + "node_modules/extend": { + "version": "3.0.2", + "resolved": 
"https://registry.npmjs.org/extend/-/extend-3.0.2.tgz", + "integrity": "sha512-fjquC59cD7CyW6urNXK0FBufkZcoiGG80wTuPujX590cB5Ttln20E2UB4S/WARVqhXffZl2LNgS+gQdPIIim/g==", + "license": "MIT" + }, + "node_modules/fast-deep-equal": { + "version": "3.1.3", + "resolved": "https://registry.npmjs.org/fast-deep-equal/-/fast-deep-equal-3.1.3.tgz", + "integrity": "sha512-f3qQ9oQy9j2AhBe/H9VC91wLmKBCCU/gDOnKNAYG5hswO7BLKj09Hc5HYNz9cGI++xlpDCIgDaitVs03ATR84Q==", + "license": "MIT" + }, + "node_modules/fast-json-stable-stringify": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/fast-json-stable-stringify/-/fast-json-stable-stringify-2.1.0.tgz", + "integrity": "sha512-lhd/wF+Lk98HZoTCtlVraHtfh5XYijIjalXck7saUtuanSDyLMxnHhSXEDJqHxD7msR8D0uCmqlkwjCV8xvwHw==", + "license": "MIT" + }, + "node_modules/fast-xml-parser": { + "version": "4.5.3", + "resolved": "https://registry.npmjs.org/fast-xml-parser/-/fast-xml-parser-4.5.3.tgz", + "integrity": "sha512-RKihhV+SHsIUGXObeVy9AXiBbFwkVk7Syp8XgwN5U3JV416+Gwp/GO9i0JYKmikykgz/UHRrrV4ROuZEo/T0ig==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/NaturalIntelligence" + } + ], + "license": "MIT", + "peer": true, + "dependencies": { + "strnum": "^1.1.1" + }, + "bin": { + "fxparser": "src/cli/cli.js" + } + }, + "node_modules/fetch-blob": { + "version": "3.2.0", + "resolved": "https://registry.npmjs.org/fetch-blob/-/fetch-blob-3.2.0.tgz", + "integrity": "sha512-7yAQpD2UMJzLi1Dqv7qFYnPbaPx7ZfFK6PiIxQ4PfkGPyNyl2Ugx+a/umUonmKqjhM4DnfbMvdX6otXq83soQQ==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/jimmywarting" + }, + { + "type": "paypal", + "url": "https://paypal.me/jimmywarting" + } + ], + "license": "MIT", + "dependencies": { + "node-domexception": "^1.0.0", + "web-streams-polyfill": "^3.0.3" + }, + "engines": { + "node": "^12.20 || >= 14.13" + } + }, + "node_modules/finalhandler": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/finalhandler/-/finalhandler-2.1.0.tgz", + "integrity": "sha512-/t88Ty3d5JWQbWYgaOGCCYfXRwV1+be02WqYYlL6h0lEiUAMPM8o8qKGO01YIkOHzka2up08wvgYD0mDiI+q3Q==", + "license": "MIT", + "dependencies": { + "debug": "^4.4.0", + "encodeurl": "^2.0.0", + "escape-html": "^1.0.3", + "on-finished": "^2.4.1", + "parseurl": "^1.3.3", + "statuses": "^2.0.1" + }, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/follow-redirects": { + "version": "1.15.11", + "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.11.tgz", + "integrity": "sha512-deG2P0JfjrTxl50XGCDyfI97ZGVCxIpfKYmfyrQ54n5FO/0gfIES8C/Psl6kWVDolizcaaxZJnTS0QSMxvnsBQ==", + "funding": [ + { + "type": "individual", + "url": "https://github.com/sponsors/RubenVerborgh" + } + ], + "license": "MIT", + "engines": { + "node": ">=4.0" + }, + "peerDependenciesMeta": { + "debug": { + "optional": true + } + } + }, + "node_modules/foreground-child": { + "version": "3.3.1", + "resolved": "https://registry.npmjs.org/foreground-child/-/foreground-child-3.3.1.tgz", + "integrity": "sha512-gIXjKqtFuWEgzFRJA9WCQeSJLZDjgJUOMCMzxtvFq/37KojM1BFGufqsCy0r4qSQmYLsZYMeyRqzIWOMup03sw==", + "license": "ISC", + "dependencies": { + "cross-spawn": "^7.0.6", + "signal-exit": "^4.0.1" + }, + "engines": { + "node": ">=14" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/form-data": { + "version": "4.0.4", + "resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.4.tgz", + "integrity": 
"sha512-KrGhL9Q4zjj0kiUt5OO4Mr/A/jlI2jDYs5eHBpYHPcBEVSiipAvn2Ko2HnPe20rmcuuvMHNdZFp+4IlGTMF0Ow==", + "license": "MIT", + "dependencies": { + "asynckit": "^0.4.0", + "combined-stream": "^1.0.8", + "es-set-tostringtag": "^2.1.0", + "hasown": "^2.0.2", + "mime-types": "^2.1.12" + }, + "engines": { + "node": ">= 6" + } + }, + "node_modules/form-data/node_modules/mime-db": { + "version": "1.52.0", + "resolved": "https://registry.npmjs.org/mime-db/-/mime-db-1.52.0.tgz", + "integrity": "sha512-sPU4uV7dYlvtWJxwwxHD0PuihVNiE7TyAbQ5SWxDCB9mUYvOgroQOwYQQOKPJ8CIbE+1ETVlOoK1UC2nU3gYvg==", + "license": "MIT", + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/form-data/node_modules/mime-types": { + "version": "2.1.35", + "resolved": "https://registry.npmjs.org/mime-types/-/mime-types-2.1.35.tgz", + "integrity": "sha512-ZDY+bPm5zTTF+YpCrAU9nK0UgICYPT0QtT1NZWFv4s++TNkcgVaT0g6+4R2uI4MjQjzysHB1zxuWL50hzaeXiw==", + "license": "MIT", + "dependencies": { + "mime-db": "1.52.0" + }, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/formdata-polyfill": { + "version": "4.0.10", + "resolved": "https://registry.npmjs.org/formdata-polyfill/-/formdata-polyfill-4.0.10.tgz", + "integrity": "sha512-buewHzMvYL29jdeQTVILecSaZKnt/RJWjoZCF5OW60Z67/GmSLBkOFM7qh1PI3zFNtJbaZL5eQu1vLfazOwj4g==", + "license": "MIT", + "dependencies": { + "fetch-blob": "^3.1.2" + }, + "engines": { + "node": ">=12.20.0" + } + }, + "node_modules/forwarded": { + "version": "0.2.0", + "resolved": "https://registry.npmjs.org/forwarded/-/forwarded-0.2.0.tgz", + "integrity": "sha512-buRG0fpBtRHSTCOASe6hD258tEubFoRLb4ZNA6NxMVHNw2gOcwHo9wyablzMzOA5z9xA9L1KNjk/Nt6MT9aYow==", + "license": "MIT", + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/fresh": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/fresh/-/fresh-2.0.0.tgz", + "integrity": "sha512-Rx/WycZ60HOaqLKAi6cHRKKI7zxWbJ31MhntmtwMoaTeF7XFH9hhBp8vITaMidfljRQ6eYWCKkaTK+ykVJHP2A==", + "license": "MIT", + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/function-bind": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/function-bind/-/function-bind-1.1.2.tgz", + "integrity": "sha512-7XHNxH7qX9xG5mIwxkhumTox/MIRNcOgDrxWsMt2pAr23WHp6MrRlN7FBSFpCpr+oVO0F744iUgR82nJMfG2SA==", + "license": "MIT", + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/gaxios": { + "version": "6.7.1", + "resolved": "https://registry.npmjs.org/gaxios/-/gaxios-6.7.1.tgz", + "integrity": "sha512-LDODD4TMYx7XXdpwxAVRAIAuB0bzv0s+ywFonY46k126qzQHT9ygyoa9tncmOiQmmDrik65UYsEkv3lbfqQ3yQ==", + "license": "Apache-2.0", + "dependencies": { + "extend": "^3.0.2", + "https-proxy-agent": "^7.0.1", + "is-stream": "^2.0.0", + "node-fetch": "^2.6.9", + "uuid": "^9.0.1" + }, + "engines": { + "node": ">=14" + } + }, + "node_modules/gcp-metadata": { + "version": "6.1.1", + "resolved": "https://registry.npmjs.org/gcp-metadata/-/gcp-metadata-6.1.1.tgz", + "integrity": "sha512-a4tiq7E0/5fTjxPAaH4jpjkSv/uCaU2p5KC6HVGrvl0cDjA8iBZv4vv1gyzlmK0ZUKqwpOyQMKzZQe3lTit77A==", + "license": "Apache-2.0", + "dependencies": { + "gaxios": "^6.1.1", + "google-logging-utils": "^0.0.2", + "json-bigint": "^1.0.0" + }, + "engines": { + "node": ">=14" + } + }, + "node_modules/get-intrinsic": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.3.0.tgz", + "integrity": "sha512-9fSjSaos/fRIVIp+xSJlE6lfwhES7LNtKaCBIamHsjr2na1BiABJPo0mOjjz8GJDURarmCPGqaiVg5mfjb98CQ==", + "license": "MIT", + "dependencies": { + "call-bind-apply-helpers": 
"^1.0.2", + "es-define-property": "^1.0.1", + "es-errors": "^1.3.0", + "es-object-atoms": "^1.1.1", + "function-bind": "^1.1.2", + "get-proto": "^1.0.1", + "gopd": "^1.2.0", + "has-symbols": "^1.1.0", + "hasown": "^2.0.2", + "math-intrinsics": "^1.1.0" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/get-proto": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/get-proto/-/get-proto-1.0.1.tgz", + "integrity": "sha512-sTSfBjoXBp89JvIKIefqw7U2CCebsc74kiY6awiGogKtoSGbgjYE/G/+l9sF3MWFPNc9IcoOC4ODfKHfxFmp0g==", + "license": "MIT", + "dependencies": { + "dunder-proto": "^1.0.1", + "es-object-atoms": "^1.0.0" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/glob": { + "version": "10.5.0", + "resolved": "https://registry.npmjs.org/glob/-/glob-10.5.0.tgz", + "integrity": "sha512-DfXN8DfhJ7NH3Oe7cFmu3NCu1wKbkReJ8TorzSAFbSKrlNaQSKfIzqYqVY8zlbs2NLBbWpRiU52GX2PbaBVNkg==", + "license": "ISC", + "dependencies": { + "foreground-child": "^3.1.0", + "jackspeak": "^3.1.2", + "minimatch": "^9.0.4", + "minipass": "^7.1.2", + "package-json-from-dist": "^1.0.0", + "path-scurry": "^1.11.1" + }, + "bin": { + "glob": "dist/esm/bin.mjs" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/google-auth-library": { + "version": "9.15.1", + "resolved": "https://registry.npmjs.org/google-auth-library/-/google-auth-library-9.15.1.tgz", + "integrity": "sha512-Jb6Z0+nvECVz+2lzSMt9u98UsoakXxA2HGHMCxh+so3n90XgYWkq5dur19JAJV7ONiJY22yBTyJB1TSkvPq9Ng==", + "license": "Apache-2.0", + "dependencies": { + "base64-js": "^1.3.0", + "ecdsa-sig-formatter": "^1.0.11", + "gaxios": "^6.1.1", + "gcp-metadata": "^6.1.0", + "gtoken": "^7.0.0", + "jws": "^4.0.0" + }, + "engines": { + "node": ">=14" + } + }, + "node_modules/google-logging-utils": { + "version": "0.0.2", + "resolved": "https://registry.npmjs.org/google-logging-utils/-/google-logging-utils-0.0.2.tgz", + "integrity": "sha512-NEgUnEcBiP5HrPzufUkBzJOD/Sxsco3rLNo1F1TNf7ieU8ryUzBhqba8r756CjLX7rn3fHl6iLEwPYuqpoKgQQ==", + "license": "Apache-2.0", + "engines": { + "node": ">=14" + } + }, + "node_modules/gopd": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/gopd/-/gopd-1.2.0.tgz", + "integrity": "sha512-ZUKRh6/kUFoAiTAtTYPZJ3hw9wNxx+BIBOijnlG9PnrJsCcSjs1wyyD6vJpaYtgnzDrKYRSqf3OO6Rfa93xsRg==", + "license": "MIT", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/gtoken": { + "version": "7.1.0", + "resolved": "https://registry.npmjs.org/gtoken/-/gtoken-7.1.0.tgz", + "integrity": "sha512-pCcEwRi+TKpMlxAQObHDQ56KawURgyAf6jtIY046fJ5tIv3zDe/LEIubckAO8fj6JnAxLdmWkUfNyulQ2iKdEw==", + "license": "MIT", + "dependencies": { + "gaxios": "^6.0.0", + "jws": "^4.0.0" + }, + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/has-symbols": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/has-symbols/-/has-symbols-1.1.0.tgz", + "integrity": "sha512-1cDNdwJ2Jaohmb3sg4OmKaMBwuC48sYni5HUw2DvsC8LjGTLK9h+eb1X6RyuOHe4hT0ULCW68iomhjUoKUqlPQ==", + "license": "MIT", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/has-tostringtag": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/has-tostringtag/-/has-tostringtag-1.0.2.tgz", + "integrity": "sha512-NqADB8VjPFLM2V0VvHUewwwsw0ZWBaIdgo+ieHtK3hasLz4qeCRjYcqfB6AQrBggRKppKF8L52/VqdVsO47Dlw==", + "license": "MIT", + 
"dependencies": { + "has-symbols": "^1.0.3" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/hasown": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/hasown/-/hasown-2.0.2.tgz", + "integrity": "sha512-0hJU9SCPvmMzIBdZFqNPXWa6dqh7WdH0cII9y+CyS8rG3nL48Bclra9HmKhVVUHyPWNH5Y7xDwAB7bfgSjkUMQ==", + "license": "MIT", + "dependencies": { + "function-bind": "^1.1.2" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/html-entities": { + "version": "2.6.0", + "resolved": "https://registry.npmjs.org/html-entities/-/html-entities-2.6.0.tgz", + "integrity": "sha512-kig+rMn/QOVRvr7c86gQ8lWXq+Hkv6CbAH1hLu+RG338StTpE8Z0b44SDVaqVu7HGKf27frdmUYEs9hTUX/cLQ==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/mdevils" + }, + { + "type": "patreon", + "url": "https://patreon.com/mdevils" + } + ], + "license": "MIT", + "peer": true + }, + "node_modules/http-errors": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/http-errors/-/http-errors-2.0.0.tgz", + "integrity": "sha512-FtwrG/euBzaEjYeRqOgly7G0qviiXoJWnvEH2Z1plBdXgbyjv34pHTSb9zoeHMyDy33+DWy5Wt9Wo+TURtOYSQ==", + "license": "MIT", + "dependencies": { + "depd": "2.0.0", + "inherits": "2.0.4", + "setprototypeof": "1.2.0", + "statuses": "2.0.1", + "toidentifier": "1.0.1" + }, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/http-errors/node_modules/statuses": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/statuses/-/statuses-2.0.1.tgz", + "integrity": "sha512-RwNA9Z/7PrK06rYLIzFMlaF+l73iwpzsqRIFgbMLbTcLD6cOao82TaWefPXQvB2fOC4AjuYSEndS7N/mTCbkdQ==", + "license": "MIT", + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/http-proxy-agent": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/http-proxy-agent/-/http-proxy-agent-5.0.0.tgz", + "integrity": "sha512-n2hY8YdoRE1i7r6M0w9DIw5GgZN0G25P8zLCRQ8rjXtTU3vsNFBI/vWK/UIeE6g5MUUz6avwAPXmL6Fy9D/90w==", + "license": "MIT", + "peer": true, + "dependencies": { + "@tootallnate/once": "2", + "agent-base": "6", + "debug": "4" + }, + "engines": { + "node": ">= 6" + } + }, + "node_modules/http-proxy-agent/node_modules/agent-base": { + "version": "6.0.2", + "resolved": "https://registry.npmjs.org/agent-base/-/agent-base-6.0.2.tgz", + "integrity": "sha512-RZNwNclF7+MS/8bDg70amg32dyeZGZxiDuQmZxKLAlQjr3jGyLx+4Kkk58UO7D2QdgFIQCovuSuZESne6RG6XQ==", + "license": "MIT", + "peer": true, + "dependencies": { + "debug": "4" + }, + "engines": { + "node": ">= 6.0.0" + } + }, + "node_modules/https-proxy-agent": { + "version": "7.0.6", + "resolved": "https://registry.npmjs.org/https-proxy-agent/-/https-proxy-agent-7.0.6.tgz", + "integrity": "sha512-vK9P5/iUfdl95AI+JVyUuIcVtd4ofvtrOr3HNtM2yxC9bnMbEdp3x01OhQNnjb8IJYi38VlTE3mBXwcfvywuSw==", + "license": "MIT", + "dependencies": { + "agent-base": "^7.1.2", + "debug": "4" + }, + "engines": { + "node": ">= 14" + } + }, + "node_modules/iconv-lite": { + "version": "0.7.0", + "resolved": "https://registry.npmjs.org/iconv-lite/-/iconv-lite-0.7.0.tgz", + "integrity": "sha512-cf6L2Ds3h57VVmkZe+Pn+5APsT7FpqJtEhhieDCvrE2MK5Qk9MyffgQyuxQTm6BChfeZNtcOLHp9IcWRVcIcBQ==", + "license": "MIT", + "dependencies": { + "safer-buffer": ">= 2.1.2 < 3.0.0" + }, + "engines": { + "node": ">=0.10.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/express" + } + }, + "node_modules/inherits": { + "version": "2.0.4", + "resolved": 
"https://registry.npmjs.org/inherits/-/inherits-2.0.4.tgz", + "integrity": "sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ==", + "license": "ISC" + }, + "node_modules/ipaddr.js": { + "version": "1.9.1", + "resolved": "https://registry.npmjs.org/ipaddr.js/-/ipaddr.js-1.9.1.tgz", + "integrity": "sha512-0KI/607xoxSToH7GjN1FfSbLoU0+btTicjsQSWQlh/hZykN8KpmMf7uYwPW3R+akZ6R/w18ZlXSHBYXiYUPO3g==", + "license": "MIT", + "engines": { + "node": ">= 0.10" + } + }, + "node_modules/is-fullwidth-code-point": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/is-fullwidth-code-point/-/is-fullwidth-code-point-3.0.0.tgz", + "integrity": "sha512-zymm5+u+sCsSWyD9qNaejV3DFvhCKclKdizYaJUuHA83RLjb7nSuGnddCHGv0hk+KY7BMAlsWeK4Ueg6EV6XQg==", + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/is-promise": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/is-promise/-/is-promise-4.0.0.tgz", + "integrity": "sha512-hvpoI6korhJMnej285dSg6nu1+e6uxs7zG3BYAm5byqDsgJNWwxzM6z6iZiAgQR4TJ30JmBTOwqZUw3WlyH3AQ==", + "license": "MIT" + }, + "node_modules/is-stream": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/is-stream/-/is-stream-2.0.1.tgz", + "integrity": "sha512-hFoiJiTl63nn+kstHGBtewWSKnQLpyb155KHheA1l39uvtO9nWIop1p3udqPcUd/xbF1VLMO4n7OI6p7RbngDg==", + "license": "MIT", + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/isexe": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/isexe/-/isexe-2.0.0.tgz", + "integrity": "sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw==", + "license": "ISC" + }, + "node_modules/jackspeak": { + "version": "3.4.3", + "resolved": "https://registry.npmjs.org/jackspeak/-/jackspeak-3.4.3.tgz", + "integrity": "sha512-OGlZQpz2yfahA/Rd1Y8Cd9SIEsqvXkLVoSw/cgwhnhFMDbsQFeZYoJJ7bIZBS9BcamUW96asq/npPWugM+RQBw==", + "license": "BlueOak-1.0.0", + "dependencies": { + "@isaacs/cliui": "^8.0.2" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + }, + "optionalDependencies": { + "@pkgjs/parseargs": "^0.11.0" + } + }, + "node_modules/json-bigint": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/json-bigint/-/json-bigint-1.0.0.tgz", + "integrity": "sha512-SiPv/8VpZuWbvLSMtTDU8hEfrZWg/mH/nV/b4o0CYbSxu1UIQPLdwKOCIyLQX+VIPO5vrLX3i8qtqFyhdPSUSQ==", + "license": "MIT", + "dependencies": { + "bignumber.js": "^9.0.0" + } + }, + "node_modules/json-schema-traverse": { + "version": "0.4.1", + "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-0.4.1.tgz", + "integrity": "sha512-xbbCH5dCYU5T8LcEhhuh7HJ88HXuW3qsI3Y0zOZFKfZEHcpWiHU/Jxzk629Brsab/mMiHQti9wMP+845RPe3Vg==", + "license": "MIT" + }, + "node_modules/jwa": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/jwa/-/jwa-2.0.1.tgz", + "integrity": "sha512-hRF04fqJIP8Abbkq5NKGN0Bbr3JxlQ+qhZufXVr0DvujKy93ZCbXZMHDL4EOtodSbCWxOqR8MS1tXA5hwqCXDg==", + "license": "MIT", + "dependencies": { + "buffer-equal-constant-time": "^1.0.1", + "ecdsa-sig-formatter": "1.0.11", + "safe-buffer": "^5.0.1" + } + }, + "node_modules/jws": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/jws/-/jws-4.0.0.tgz", + "integrity": "sha512-KDncfTmOZoOMTFG4mBlG0qUIOlc03fmzH+ru6RgYVZhPkyiy/92Owlt/8UEN+a4TXR1FQetfIpJE8ApdvdVxTg==", + "license": "MIT", + "dependencies": { + "jwa": "^2.0.0", + "safe-buffer": "^5.0.1" + } + }, + "node_modules/lru-cache": { + 
"version": "10.4.3", + "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-10.4.3.tgz", + "integrity": "sha512-JNAzZcXrCt42VGLuYz0zfAzDfAvJWW6AfYlDBQyDV5DClI2m5sAmK+OIO7s59XfsRsWHp02jAJrRadPRGTt6SQ==", + "license": "ISC" + }, + "node_modules/math-intrinsics": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/math-intrinsics/-/math-intrinsics-1.1.0.tgz", + "integrity": "sha512-/IXtbwEk5HTPyEwyKX6hGkYXxM9nbj64B+ilVJnC/R6B0pH5G4V3b0pVbL7DBj4tkhBAppbQUlf6F6Xl9LHu1g==", + "license": "MIT", + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/media-typer": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/media-typer/-/media-typer-1.1.0.tgz", + "integrity": "sha512-aisnrDP4GNe06UcKFnV5bfMNPBUw4jsLGaWwWfnH3v02GnBuXX2MCVn5RbrWo0j3pczUilYblq7fQ7Nw2t5XKw==", + "license": "MIT", + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/merge-descriptors": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/merge-descriptors/-/merge-descriptors-2.0.0.tgz", + "integrity": "sha512-Snk314V5ayFLhp3fkUREub6WtjBfPdCPY1Ln8/8munuLuiYhsABgBVWsozAG+MWMbVEvcdcpbi9R7ww22l9Q3g==", + "license": "MIT", + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/mime": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/mime/-/mime-3.0.0.tgz", + "integrity": "sha512-jSCU7/VB1loIWBZe14aEYHU/+1UMEHoaO7qxCOVJOw9GgH72VAWppxNcjU+x9a2k3GSIBXNKxXQFqRvvZ7vr3A==", + "license": "MIT", + "peer": true, + "bin": { + "mime": "cli.js" + }, + "engines": { + "node": ">=10.0.0" + } + }, + "node_modules/mime-db": { + "version": "1.54.0", + "resolved": "https://registry.npmjs.org/mime-db/-/mime-db-1.54.0.tgz", + "integrity": "sha512-aU5EJuIN2WDemCcAp2vFBfp/m4EAhWJnUNSSw0ixs7/kXbd6Pg64EmwJkNdFhB8aWt1sH2CTXrLxo/iAGV3oPQ==", + "license": "MIT", + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/mime-types": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/mime-types/-/mime-types-3.0.1.tgz", + "integrity": "sha512-xRc4oEhT6eaBpU1XF7AjpOFD+xQmXNB5OVKwp4tqCuBpHLS/ZbBDrc07mYTDqVMg6PfxUjjNp85O6Cd2Z/5HWA==", + "license": "MIT", + "dependencies": { + "mime-db": "^1.54.0" + }, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/minimatch": { + "version": "9.0.5", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz", + "integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==", + "license": "ISC", + "dependencies": { + "brace-expansion": "^2.0.1" + }, + "engines": { + "node": ">=16 || 14 >=14.17" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/minipass": { + "version": "7.1.2", + "resolved": "https://registry.npmjs.org/minipass/-/minipass-7.1.2.tgz", + "integrity": "sha512-qOOzS1cBTWYF4BH8fVePDBOO9iptMnGUEZwNc/cMWnTV2nVLZ7VoNWEPHkYczZA0pdoA7dl6e7FL659nX9S2aw==", + "license": "ISC", + "engines": { + "node": ">=16 || 14 >=14.17" + } + }, + "node_modules/ms": { + "version": "2.1.3", + "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz", + "integrity": "sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==", + "license": "MIT" + }, + "node_modules/negotiator": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/negotiator/-/negotiator-1.0.0.tgz", + "integrity": "sha512-8Ofs/AUQh8MaEcrlq5xOX0CQ9ypTF5dl78mjlMNfOK08fzpgTHQRQPBxcPlEtIw0yRpws+Zo/3r+5WRby7u3Gg==", + "license": "MIT", + "engines": { + 
"node": ">= 0.6" + } + }, + "node_modules/node-domexception": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/node-domexception/-/node-domexception-1.0.0.tgz", + "integrity": "sha512-/jKZoMpw0F8GRwl4/eLROPA3cfcXtLApP0QzLmUT/HuPCZWyB7IY9ZrMeKw2O/nFIqPQB3PVM9aYm0F312AXDQ==", + "deprecated": "Use your platform's native DOMException instead", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/jimmywarting" + }, + { + "type": "github", + "url": "https://paypal.me/jimmywarting" + } + ], + "license": "MIT", + "engines": { + "node": ">=10.5.0" + } + }, + "node_modules/node-fetch": { + "version": "2.7.0", + "resolved": "https://registry.npmjs.org/node-fetch/-/node-fetch-2.7.0.tgz", + "integrity": "sha512-c4FRfUm/dbcWZ7U+1Wq0AwCyFL+3nt2bEw05wfxSz+DWpWsitgmSgYmy2dQdWyKC1694ELPqMs/YzUSNozLt8A==", + "license": "MIT", + "dependencies": { + "whatwg-url": "^5.0.0" + }, + "engines": { + "node": "4.x || >=6.0.0" + }, + "peerDependencies": { + "encoding": "^0.1.0" + }, + "peerDependenciesMeta": { + "encoding": { + "optional": true + } + } + }, + "node_modules/object-assign": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/object-assign/-/object-assign-4.1.1.tgz", + "integrity": "sha512-rJgTQnkUnH1sFw8yT6VSU3zD3sWmu6sZhIseY8VX+GRu3P6F7Fu+JNDoXfklElbLJSnc3FUQHVe4cU5hj+BcUg==", + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/object-inspect": { + "version": "1.13.4", + "resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.13.4.tgz", + "integrity": "sha512-W67iLl4J2EXEGTbfeHCffrjDfitvLANg0UlX3wFUUSTx92KXRFegMHUVgSqE+wvhAbi4WqjGg9czysTV2Epbew==", + "license": "MIT", + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/on-finished": { + "version": "2.4.1", + "resolved": "https://registry.npmjs.org/on-finished/-/on-finished-2.4.1.tgz", + "integrity": "sha512-oVlzkg3ENAhCk2zdv7IJwd/QUD4z2RxRwpkcGY8psCVcCYZNq4wYnVWALHM+brtuJjePWiYF/ClmuDr8Ch5+kg==", + "license": "MIT", + "dependencies": { + "ee-first": "1.1.1" + }, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/once": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/once/-/once-1.4.0.tgz", + "integrity": "sha512-lNaJgI+2Q5URQBkccEKHTQOPaXdUxnZZElQTZY0MFUAuaEqe1E+Nyvgdz/aIyNi6Z9MzO5dv1H8n58/GELp3+w==", + "license": "ISC", + "dependencies": { + "wrappy": "1" + } + }, + "node_modules/openapi-types": { + "version": "12.1.3", + "resolved": "https://registry.npmjs.org/openapi-types/-/openapi-types-12.1.3.tgz", + "integrity": "sha512-N4YtSYJqghVu4iek2ZUvcN/0aqH1kRDuNqzcycDxhOUpg7GdvLa2F3DgS6yBNhInhv2r/6I0Flkn7CqL8+nIcw==", + "license": "MIT" + }, + "node_modules/p-limit": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/p-limit/-/p-limit-3.1.0.tgz", + "integrity": "sha512-TYOanM3wGwNGsZN2cVTYPArw454xnXj5qmWF1bEoAc4+cU/ol7GVh7odevjp1FNHduHc3KZMcFduxU5Xc6uJRQ==", + "license": "MIT", + "peer": true, + "dependencies": { + "yocto-queue": "^0.1.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/package-json-from-dist": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/package-json-from-dist/-/package-json-from-dist-1.0.1.tgz", + "integrity": "sha512-UEZIS3/by4OC8vL3P2dTXRETpebLI2NiI5vIrjaD/5UtrkFX/tNbwjTSRAGC/+7CAo2pIcBaRgWmcBBHcsaCIw==", + "license": "BlueOak-1.0.0" + }, + "node_modules/parseurl": { + "version": "1.3.3", + "resolved": 
"https://registry.npmjs.org/parseurl/-/parseurl-1.3.3.tgz", + "integrity": "sha512-CiyeOxFT/JZyN5m0z9PfXw4SCBJ6Sygz1Dpl0wqjlhDEGGBP1GnsUVEL0p63hoG1fcj3fHynXi9NYO4nWOL+qQ==", + "license": "MIT", + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/path-key": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/path-key/-/path-key-3.1.1.tgz", + "integrity": "sha512-ojmeN0qd+y0jszEtoY48r0Peq5dwMEkIlCOu6Q5f41lfkswXuKtYrhgoTpLnyIcHm24Uhqx+5Tqm2InSwLhE6Q==", + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/path-scurry": { + "version": "1.11.1", + "resolved": "https://registry.npmjs.org/path-scurry/-/path-scurry-1.11.1.tgz", + "integrity": "sha512-Xa4Nw17FS9ApQFJ9umLiJS4orGjm7ZzwUrwamcGQuHSzDyth9boKDaycYdDcZDuqYATXw4HFXgaqWTctW/v1HA==", + "license": "BlueOak-1.0.0", + "dependencies": { + "lru-cache": "^10.2.0", + "minipass": "^5.0.0 || ^6.0.2 || ^7.0.0" + }, + "engines": { + "node": ">=16 || 14 >=14.18" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/path-to-regexp": { + "version": "8.3.0", + "resolved": "https://registry.npmjs.org/path-to-regexp/-/path-to-regexp-8.3.0.tgz", + "integrity": "sha512-7jdwVIRtsP8MYpdXSwOS0YdD0Du+qOoF/AEPIt88PcCFrZCzx41oxku1jD88hZBwbNUIEfpqvuhjFaMAqMTWnA==", + "license": "MIT", + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/express" + } + }, + "node_modules/pkce-challenge": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/pkce-challenge/-/pkce-challenge-5.0.0.tgz", + "integrity": "sha512-ueGLflrrnvwB3xuo/uGob5pd5FN7l0MsLf0Z87o/UQmRtwjvfylfc9MurIxRAWywCYTgrvpXBcqjV4OfCYGCIQ==", + "license": "MIT", + "engines": { + "node": ">=16.20.0" + } + }, + "node_modules/proxy-addr": { + "version": "2.0.7", + "resolved": "https://registry.npmjs.org/proxy-addr/-/proxy-addr-2.0.7.tgz", + "integrity": "sha512-llQsMLSUDUPT44jdrU/O37qlnifitDP+ZwrmmZcoSKyLKvtZxpyV0n2/bD/N4tBAAZ/gJEdZU7KMraoK1+XYAg==", + "license": "MIT", + "dependencies": { + "forwarded": "0.2.0", + "ipaddr.js": "1.9.1" + }, + "engines": { + "node": ">= 0.10" + } + }, + "node_modules/proxy-from-env": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/proxy-from-env/-/proxy-from-env-1.1.0.tgz", + "integrity": "sha512-D+zkORCbA9f1tdWRK0RaCR3GPv50cMxcrz4X8k5LTSUD1Dkw47mKJEZQNunItRTkWwgtaUSo1RVFRIG9ZXiFYg==", + "license": "MIT" + }, + "node_modules/punycode": { + "version": "2.3.1", + "resolved": "https://registry.npmjs.org/punycode/-/punycode-2.3.1.tgz", + "integrity": "sha512-vYt7UD1U9Wg6138shLtLOvdAu+8DsC/ilFtEVHcH+wydcSpNE20AfSOduf6MkRFahL5FY7X1oU7nKVZFtfq8Fg==", + "license": "MIT", + "engines": { + "node": ">=6" + } + }, + "node_modules/qs": { + "version": "6.14.0", + "resolved": "https://registry.npmjs.org/qs/-/qs-6.14.0.tgz", + "integrity": "sha512-YWWTjgABSKcvs/nWBi9PycY/JiPJqOD4JA6o9Sej2AtvSGarXxKC3OQSk4pAarbdQlKAh5D4FCQkJNkW+GAn3w==", + "license": "BSD-3-Clause", + "dependencies": { + "side-channel": "^1.1.0" + }, + "engines": { + "node": ">=0.6" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/range-parser": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/range-parser/-/range-parser-1.2.1.tgz", + "integrity": "sha512-Hrgsx+orqoygnmhFbKaHE6c296J+HTAQXoxEF6gNupROmmGJRoyzfG3ccAveqCBrwr/2yxQ5BVd/GTl5agOwSg==", + "license": "MIT", + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/raw-body": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/raw-body/-/raw-body-3.0.1.tgz", + 
"integrity": "sha512-9G8cA+tuMS75+6G/TzW8OtLzmBDMo8p1JRxN5AZ+LAp8uxGA8V8GZm4GQ4/N5QNQEnLmg6SS7wyuSmbKepiKqA==", + "license": "MIT", + "dependencies": { + "bytes": "3.1.2", + "http-errors": "2.0.0", + "iconv-lite": "0.7.0", + "unpipe": "1.0.0" + }, + "engines": { + "node": ">= 0.10" + } + }, + "node_modules/readable-stream": { + "version": "3.6.2", + "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-3.6.2.tgz", + "integrity": "sha512-9u/sniCrY3D5WdsERHzHE4G2YCXqoG5FTHUiCC4SIbr6XcLZBY05ya9EKjYek9O5xOAwjGq+1JdGBAS7Q9ScoA==", + "license": "MIT", + "peer": true, + "dependencies": { + "inherits": "^2.0.3", + "string_decoder": "^1.1.1", + "util-deprecate": "^1.0.1" + }, + "engines": { + "node": ">= 6" + } + }, + "node_modules/retry": { + "version": "0.13.1", + "resolved": "https://registry.npmjs.org/retry/-/retry-0.13.1.tgz", + "integrity": "sha512-XQBQ3I8W1Cge0Seh+6gjj03LbmRFWuoszgK9ooCpwYIrhhoO80pfq4cUkU5DkknwfOfFteRwlZ56PYOGYyFWdg==", + "license": "MIT", + "peer": true, + "engines": { + "node": ">= 4" + } + }, + "node_modules/retry-request": { + "version": "7.0.2", + "resolved": "https://registry.npmjs.org/retry-request/-/retry-request-7.0.2.tgz", + "integrity": "sha512-dUOvLMJ0/JJYEn8NrpOaGNE7X3vpI5XlZS/u0ANjqtcZVKnIxP7IgCFwrKTxENw29emmwug53awKtaMm4i9g5w==", + "license": "MIT", + "peer": true, + "dependencies": { + "@types/request": "^2.48.8", + "extend": "^3.0.2", + "teeny-request": "^9.0.0" + }, + "engines": { + "node": ">=14" + } + }, + "node_modules/rimraf": { + "version": "5.0.10", + "resolved": "https://registry.npmjs.org/rimraf/-/rimraf-5.0.10.tgz", + "integrity": "sha512-l0OE8wL34P4nJH/H2ffoaniAokM2qSmrtXHmlpvYr5AVVX8msAyW0l8NVJFDxlSK4u3Uh/f41cQheDVdnYijwQ==", + "license": "ISC", + "dependencies": { + "glob": "^10.3.7" + }, + "bin": { + "rimraf": "dist/esm/bin.mjs" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/router": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/router/-/router-2.2.0.tgz", + "integrity": "sha512-nLTrUKm2UyiL7rlhapu/Zl45FwNgkZGaCpZbIHajDYgwlJCOzLSk+cIPAnsEqV955GjILJnKbdQC1nVPz+gAYQ==", + "license": "MIT", + "dependencies": { + "debug": "^4.4.0", + "depd": "^2.0.0", + "is-promise": "^4.0.0", + "parseurl": "^1.3.3", + "path-to-regexp": "^8.0.0" + }, + "engines": { + "node": ">= 18" + } + }, + "node_modules/safe-buffer": { + "version": "5.2.1", + "resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.2.1.tgz", + "integrity": "sha512-rp3So07KcdmmKbGvgaNxQSJr7bGVSVk5S9Eq1F+ppbRo70+YeaDxkw5Dd8NPN+GD6bjnYm2VuPuCXmpuYvmCXQ==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/feross" + }, + { + "type": "patreon", + "url": "https://www.patreon.com/feross" + }, + { + "type": "consulting", + "url": "https://feross.org/support" + } + ], + "license": "MIT" + }, + "node_modules/safer-buffer": { + "version": "2.1.2", + "resolved": "https://registry.npmjs.org/safer-buffer/-/safer-buffer-2.1.2.tgz", + "integrity": "sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg==", + "license": "MIT" + }, + "node_modules/send": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/send/-/send-1.2.0.tgz", + "integrity": "sha512-uaW0WwXKpL9blXE2o0bRhoL2EGXIrZxQ2ZQ4mgcfoBxdFmQold+qWsD2jLrfZ0trjKL6vOw0j//eAwcALFjKSw==", + "license": "MIT", + "dependencies": { + "debug": "^4.3.5", + "encodeurl": "^2.0.0", + "escape-html": "^1.0.3", + "etag": "^1.8.1", + "fresh": "^2.0.0", + "http-errors": "^2.0.0", + "mime-types": 
"^3.0.1", + "ms": "^2.1.3", + "on-finished": "^2.4.1", + "range-parser": "^1.2.1", + "statuses": "^2.0.1" + }, + "engines": { + "node": ">= 18" + } + }, + "node_modules/serve-static": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/serve-static/-/serve-static-2.2.0.tgz", + "integrity": "sha512-61g9pCh0Vnh7IutZjtLGGpTA355+OPn2TyDv/6ivP2h/AdAVX9azsoxmg2/M6nZeQZNYBEwIcsne1mJd9oQItQ==", + "license": "MIT", + "dependencies": { + "encodeurl": "^2.0.0", + "escape-html": "^1.0.3", + "parseurl": "^1.3.3", + "send": "^1.2.0" + }, + "engines": { + "node": ">= 18" + } + }, + "node_modules/setprototypeof": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/setprototypeof/-/setprototypeof-1.2.0.tgz", + "integrity": "sha512-E5LDX7Wrp85Kil5bhZv46j8jOeboKq5JMmYM3gVGdGH8xFpPWXUMsNrlODCrkoxMEeNi/XZIwuRvY4XNwYMJpw==", + "license": "ISC" + }, + "node_modules/shebang-command": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/shebang-command/-/shebang-command-2.0.0.tgz", + "integrity": "sha512-kHxr2zZpYtdmrN1qDjrrX/Z1rR1kG8Dx+gkpK1G4eXmvXswmcE1hTWBWYUzlraYw1/yZp6YuDY77YtvbN0dmDA==", + "license": "MIT", + "dependencies": { + "shebang-regex": "^3.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/shebang-regex": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/shebang-regex/-/shebang-regex-3.0.0.tgz", + "integrity": "sha512-7++dFhtcx3353uBaq8DDR4NuxBetBzC7ZQOhmTQInHEd6bSrXdiEyzCvG07Z44UYdLShWUyXt5M/yhz8ekcb1A==", + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/side-channel": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/side-channel/-/side-channel-1.1.0.tgz", + "integrity": "sha512-ZX99e6tRweoUXqR+VBrslhda51Nh5MTQwou5tnUDgbtyM0dBgmhEDtWGP/xbKn6hqfPRHujUNwz5fy/wbbhnpw==", + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0", + "object-inspect": "^1.13.3", + "side-channel-list": "^1.0.0", + "side-channel-map": "^1.0.1", + "side-channel-weakmap": "^1.0.2" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/side-channel-list": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/side-channel-list/-/side-channel-list-1.0.0.tgz", + "integrity": "sha512-FCLHtRD/gnpCiCHEiJLOwdmFP+wzCmDEkc9y7NsYxeF4u7Btsn1ZuwgwJGxImImHicJArLP4R0yX4c2KCrMrTA==", + "license": "MIT", + "dependencies": { + "es-errors": "^1.3.0", + "object-inspect": "^1.13.3" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/side-channel-map": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/side-channel-map/-/side-channel-map-1.0.1.tgz", + "integrity": "sha512-VCjCNfgMsby3tTdo02nbjtM/ewra6jPHmpThenkTYh8pG9ucZ/1P8So4u4FGBek/BjpOVsDCMoLA/iuBKIFXRA==", + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.2", + "es-errors": "^1.3.0", + "get-intrinsic": "^1.2.5", + "object-inspect": "^1.13.3" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/side-channel-weakmap": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/side-channel-weakmap/-/side-channel-weakmap-1.0.2.tgz", + "integrity": "sha512-WPS/HvHQTYnHisLo9McqBHOJk2FkHO/tlpvldyrnem4aeQp4hai3gythswg6p01oSoTl58rcpiFAjF2br2Ak2A==", + "license": "MIT", + "dependencies": { + "call-bound": "^1.0.2", + "es-errors": "^1.3.0", + "get-intrinsic": "^1.2.5", + "object-inspect": "^1.13.3", + 
"side-channel-map": "^1.0.1" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/signal-exit": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/signal-exit/-/signal-exit-4.1.0.tgz", + "integrity": "sha512-bzyZ1e88w9O1iNJbKnOlvYTrWPDl46O1bG0D3XInv+9tkPrxrN8jUUTiFlDkkmKWgn1M6CfIA13SuGqOa9Korw==", + "license": "ISC", + "engines": { + "node": ">=14" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, + "node_modules/statuses": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/statuses/-/statuses-2.0.2.tgz", + "integrity": "sha512-DvEy55V3DB7uknRo+4iOGT5fP1slR8wQohVdknigZPMpMstaKJQWhwiYBACJE3Ul2pTnATihhBYnRhZQHGBiRw==", + "license": "MIT", + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/stream-events": { + "version": "1.0.5", + "resolved": "https://registry.npmjs.org/stream-events/-/stream-events-1.0.5.tgz", + "integrity": "sha512-E1GUzBSgvct8Jsb3v2X15pjzN1tYebtbLaMg+eBOUOAxgbLoSbT2NS91ckc5lJD1KfLjId+jXJRgo0qnV5Nerg==", + "license": "MIT", + "peer": true, + "dependencies": { + "stubs": "^3.0.0" + } + }, + "node_modules/stream-shift": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/stream-shift/-/stream-shift-1.0.3.tgz", + "integrity": "sha512-76ORR0DO1o1hlKwTbi/DM3EXWGf3ZJYO8cXX5RJwnul2DEg2oyoZyjLNoQM8WsvZiFKCRfC1O0J7iCvie3RZmQ==", + "license": "MIT", + "peer": true + }, + "node_modules/string_decoder": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/string_decoder/-/string_decoder-1.3.0.tgz", + "integrity": "sha512-hkRX8U1WjJFd8LsDJ2yQ/wWWxaopEsABU1XfkM8A+j0+85JAGppt16cr1Whg6KIbb4okU6Mql6BOj+uup/wKeA==", + "license": "MIT", + "peer": true, + "dependencies": { + "safe-buffer": "~5.2.0" + } + }, + "node_modules/string-width": { + "version": "5.1.2", + "resolved": "https://registry.npmjs.org/string-width/-/string-width-5.1.2.tgz", + "integrity": "sha512-HnLOCR3vjcY8beoNLtcjZ5/nxn2afmME6lhrDrebokqMap+XbeW8n9TXpPDOqdGK5qcI3oT0GKTW6wC7EMiVqA==", + "license": "MIT", + "dependencies": { + "eastasianwidth": "^0.2.0", + "emoji-regex": "^9.2.2", + "strip-ansi": "^7.0.1" + }, + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/string-width-cjs": { + "name": "string-width", + "version": "4.2.3", + "resolved": "https://registry.npmjs.org/string-width/-/string-width-4.2.3.tgz", + "integrity": "sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==", + "license": "MIT", + "dependencies": { + "emoji-regex": "^8.0.0", + "is-fullwidth-code-point": "^3.0.0", + "strip-ansi": "^6.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/string-width-cjs/node_modules/ansi-regex": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-5.0.1.tgz", + "integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==", + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/string-width-cjs/node_modules/emoji-regex": { + "version": "8.0.0", + "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-8.0.0.tgz", + "integrity": "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==", + "license": "MIT" + }, + "node_modules/string-width-cjs/node_modules/strip-ansi": { + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz", + 
"integrity": "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==", + "license": "MIT", + "dependencies": { + "ansi-regex": "^5.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/strip-ansi": { + "version": "7.1.2", + "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-7.1.2.tgz", + "integrity": "sha512-gmBGslpoQJtgnMAvOVqGZpEz9dyoKTCzy2nfz/n8aIFhN/jCE/rCmcxabB6jOOHV+0WNnylOxaxBQPSvcWklhA==", + "license": "MIT", + "dependencies": { + "ansi-regex": "^6.0.1" + }, + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/chalk/strip-ansi?sponsor=1" + } + }, + "node_modules/strip-ansi-cjs": { + "name": "strip-ansi", + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz", + "integrity": "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==", + "license": "MIT", + "dependencies": { + "ansi-regex": "^5.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/strip-ansi-cjs/node_modules/ansi-regex": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-5.0.1.tgz", + "integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==", + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/strnum": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/strnum/-/strnum-1.1.2.tgz", + "integrity": "sha512-vrN+B7DBIoTTZjnPNewwhx6cBA/H+IS7rfW68n7XxC1y7uoiGQBxaKzqucGUgavX15dJgiGztLJ8vxuEzwqBdA==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/NaturalIntelligence" + } + ], + "license": "MIT", + "peer": true + }, + "node_modules/stubs": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/stubs/-/stubs-3.0.0.tgz", + "integrity": "sha512-PdHt7hHUJKxvTCgbKX9C1V/ftOcjJQgz8BZwNfV5c4B6dcGqlpelTbJ999jBGZ2jYiPAwcX5dP6oBwVlBlUbxw==", + "license": "MIT", + "peer": true + }, + "node_modules/teeny-request": { + "version": "9.0.0", + "resolved": "https://registry.npmjs.org/teeny-request/-/teeny-request-9.0.0.tgz", + "integrity": "sha512-resvxdc6Mgb7YEThw6G6bExlXKkv6+YbuzGg9xuXxSgxJF7Ozs+o8Y9+2R3sArdWdW8nOokoQb1yrpFB0pQK2g==", + "license": "Apache-2.0", + "peer": true, + "dependencies": { + "http-proxy-agent": "^5.0.0", + "https-proxy-agent": "^5.0.0", + "node-fetch": "^2.6.9", + "stream-events": "^1.0.5", + "uuid": "^9.0.0" + }, + "engines": { + "node": ">=14" + } + }, + "node_modules/teeny-request/node_modules/agent-base": { + "version": "6.0.2", + "resolved": "https://registry.npmjs.org/agent-base/-/agent-base-6.0.2.tgz", + "integrity": "sha512-RZNwNclF7+MS/8bDg70amg32dyeZGZxiDuQmZxKLAlQjr3jGyLx+4Kkk58UO7D2QdgFIQCovuSuZESne6RG6XQ==", + "license": "MIT", + "peer": true, + "dependencies": { + "debug": "4" + }, + "engines": { + "node": ">= 6.0.0" + } + }, + "node_modules/teeny-request/node_modules/https-proxy-agent": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/https-proxy-agent/-/https-proxy-agent-5.0.1.tgz", + "integrity": "sha512-dFcAjpTQFgoLMzC2VwU+C/CbS7uRL0lWmxDITmqm7C+7F0Odmj6s9l6alZc6AELXhrnggM2CeWSXHGOdX2YtwA==", + "license": "MIT", + "peer": true, + "dependencies": { + "agent-base": "6", + "debug": "4" + }, + "engines": { + "node": ">= 6" + } + }, + "node_modules/toidentifier": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/toidentifier/-/toidentifier-1.0.1.tgz", + "integrity": 
"sha512-o5sSPKEkg/DIQNmH43V0/uerLrpzVedkUh8tGNvaeXpfpuwjKenlSox/2O/BTlZUtEe+JG7s5YhEz608PlAHRA==", + "license": "MIT", + "engines": { + "node": ">=0.6" + } + }, + "node_modules/tr46": { + "version": "0.0.3", + "resolved": "https://registry.npmjs.org/tr46/-/tr46-0.0.3.tgz", + "integrity": "sha512-N3WMsuqV66lT30CrXNbEjx4GEwlow3v6rr4mCcv6prnfwhS01rkgyFdjPNBYd9br7LpXV1+Emh01fHnq2Gdgrw==", + "license": "MIT" + }, + "node_modules/type-is": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/type-is/-/type-is-2.0.1.tgz", + "integrity": "sha512-OZs6gsjF4vMp32qrCbiVSkrFmXtG/AZhY3t0iAMrMBiAZyV9oALtXO8hsrHbMXF9x6L3grlFuwW2oAz7cav+Gw==", + "license": "MIT", + "dependencies": { + "content-type": "^1.0.5", + "media-typer": "^1.1.0", + "mime-types": "^3.0.0" + }, + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/undici-types": { + "version": "7.16.0", + "resolved": "https://registry.npmjs.org/undici-types/-/undici-types-7.16.0.tgz", + "integrity": "sha512-Zz+aZWSj8LE6zoxD+xrjh4VfkIG8Ya6LvYkZqtUQGJPZjYl53ypCaUwWqo7eI0x66KBGeRo+mlBEkMSeSZ38Nw==", + "license": "MIT", + "peer": true + }, + "node_modules/unpipe": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/unpipe/-/unpipe-1.0.0.tgz", + "integrity": "sha512-pjy2bYhSsufwWlKwPc+l3cN7+wuJlK6uz0YdJEOlQDbl6jo/YlPi4mb8agUkVC8BF7V8NuzeyPNqRksA3hztKQ==", + "license": "MIT", + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/uri-js": { + "version": "4.4.1", + "resolved": "https://registry.npmjs.org/uri-js/-/uri-js-4.4.1.tgz", + "integrity": "sha512-7rKUyy33Q1yc98pQ1DAmLtwX109F7TIfWlW1Ydo8Wl1ii1SeHieeh0HHfPeL2fMXK6z0s8ecKs9frCuLJvndBg==", + "license": "BSD-2-Clause", + "dependencies": { + "punycode": "^2.1.0" + } + }, + "node_modules/util-deprecate": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/util-deprecate/-/util-deprecate-1.0.2.tgz", + "integrity": "sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw==", + "license": "MIT", + "peer": true + }, + "node_modules/uuid": { + "version": "9.0.1", + "resolved": "https://registry.npmjs.org/uuid/-/uuid-9.0.1.tgz", + "integrity": "sha512-b+1eJOlsR9K8HJpow9Ok3fiWOWSIcIzXodvv0rQjVoOVNpWMpxf1wZNpt4y9h10odCNrqnYp1OBzRktckBe3sA==", + "funding": [ + "https://github.com/sponsors/broofa", + "https://github.com/sponsors/ctavan" + ], + "license": "MIT", + "bin": { + "uuid": "dist/bin/uuid" + } + }, + "node_modules/vary": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/vary/-/vary-1.1.2.tgz", + "integrity": "sha512-BNGbWLfd0eUPabhkXUVm0j8uuvREyTh5ovRa/dyow/BqAbZJyC+5fU+IzQOzmAKzYqYRAISoRhdQr3eIZ/PXqg==", + "license": "MIT", + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/web-streams-polyfill": { + "version": "3.3.3", + "resolved": "https://registry.npmjs.org/web-streams-polyfill/-/web-streams-polyfill-3.3.3.tgz", + "integrity": "sha512-d2JWLCivmZYTSIoge9MsgFCZrt571BikcWGYkjC1khllbTeDlGqZ2D8vD8E/lJa8WGWbb7Plm8/XJYV7IJHZZw==", + "license": "MIT", + "engines": { + "node": ">= 8" + } + }, + "node_modules/webidl-conversions": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/webidl-conversions/-/webidl-conversions-3.0.1.tgz", + "integrity": "sha512-2JAn3z8AR6rjK8Sm8orRC0h/bcl/DqL7tRPdGZ4I1CjdF+EaMLmYxBHyXuKL849eucPFhvBoxMsflfOb8kxaeQ==", + "license": "BSD-2-Clause" + }, + "node_modules/whatwg-url": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/whatwg-url/-/whatwg-url-5.0.0.tgz", + "integrity": 
"sha512-saE57nupxk6v3HY35+jzBwYa0rKSy0XR8JSxZPwgLr7ys0IBzhGviA1/TUGJLmSVqs8pb9AnvICXEuOHLprYTw==", + "license": "MIT", + "dependencies": { + "tr46": "~0.0.3", + "webidl-conversions": "^3.0.0" + } + }, + "node_modules/which": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/which/-/which-2.0.2.tgz", + "integrity": "sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA==", + "license": "ISC", + "dependencies": { + "isexe": "^2.0.0" + }, + "bin": { + "node-which": "bin/node-which" + }, + "engines": { + "node": ">= 8" + } + }, + "node_modules/wrap-ansi": { + "version": "8.1.0", + "resolved": "https://registry.npmjs.org/wrap-ansi/-/wrap-ansi-8.1.0.tgz", + "integrity": "sha512-si7QWI6zUMq56bESFvagtmzMdGOtoxfR+Sez11Mobfc7tm+VkUckk9bW2UeffTGVUbOksxmSw0AA2gs8g71NCQ==", + "license": "MIT", + "dependencies": { + "ansi-styles": "^6.1.0", + "string-width": "^5.0.1", + "strip-ansi": "^7.0.1" + }, + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/chalk/wrap-ansi?sponsor=1" + } + }, + "node_modules/wrap-ansi-cjs": { + "name": "wrap-ansi", + "version": "7.0.0", + "resolved": "https://registry.npmjs.org/wrap-ansi/-/wrap-ansi-7.0.0.tgz", + "integrity": "sha512-YVGIj2kamLSTxw6NsZjoBxfSwsn0ycdesmc4p+Q21c5zPuZ1pl+NfxVdxPtdHvmNVOQ6XSYG4AUtyt/Fi7D16Q==", + "license": "MIT", + "dependencies": { + "ansi-styles": "^4.0.0", + "string-width": "^4.1.0", + "strip-ansi": "^6.0.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/chalk/wrap-ansi?sponsor=1" + } + }, + "node_modules/wrap-ansi-cjs/node_modules/ansi-regex": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-5.0.1.tgz", + "integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==", + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/wrap-ansi-cjs/node_modules/ansi-styles": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", + "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", + "license": "MIT", + "dependencies": { + "color-convert": "^2.0.1" + }, + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/chalk/ansi-styles?sponsor=1" + } + }, + "node_modules/wrap-ansi-cjs/node_modules/emoji-regex": { + "version": "8.0.0", + "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-8.0.0.tgz", + "integrity": "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==", + "license": "MIT" + }, + "node_modules/wrap-ansi-cjs/node_modules/string-width": { + "version": "4.2.3", + "resolved": "https://registry.npmjs.org/string-width/-/string-width-4.2.3.tgz", + "integrity": "sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==", + "license": "MIT", + "dependencies": { + "emoji-regex": "^8.0.0", + "is-fullwidth-code-point": "^3.0.0", + "strip-ansi": "^6.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/wrap-ansi-cjs/node_modules/strip-ansi": { + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz", + "integrity": "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==", + "license": "MIT", + "dependencies": { + "ansi-regex": "^5.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/wrappy": { + "version": 
"1.0.2", + "resolved": "https://registry.npmjs.org/wrappy/-/wrappy-1.0.2.tgz", + "integrity": "sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ==", + "license": "ISC" + }, + "node_modules/ws": { + "version": "8.18.3", + "resolved": "https://registry.npmjs.org/ws/-/ws-8.18.3.tgz", + "integrity": "sha512-PEIGCY5tSlUt50cqyMXfCzX+oOPqN0vuGqWzbcJ2xvnkzkq46oOpz7dQaTDBdfICb4N14+GARUDw2XV2N4tvzg==", + "license": "MIT", + "engines": { + "node": ">=10.0.0" + }, + "peerDependencies": { + "bufferutil": "^4.0.1", + "utf-8-validate": ">=5.0.2" + }, + "peerDependenciesMeta": { + "bufferutil": { + "optional": true + }, + "utf-8-validate": { + "optional": true + } + } + }, + "node_modules/yocto-queue": { + "version": "0.1.0", + "resolved": "https://registry.npmjs.org/yocto-queue/-/yocto-queue-0.1.0.tgz", + "integrity": "sha512-rVksvsnNCdJ/ohGc6xgPwyN8eheCxsiLM8mxuE/t/mOVqJewPuO1miLpTHQiRgTKCLexL4MeAFVagts7HmNZ2Q==", + "license": "MIT", + "peer": true, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/zod": { + "version": "3.25.76", + "resolved": "https://registry.npmjs.org/zod/-/zod-3.25.76.tgz", + "integrity": "sha512-gzUt/qt81nXsFGKIFcC3YnfEAx5NkunCfnDlvuBSSFS02bcXu4Lmea0AFIUwbLWxWPx3d9p8S5QoaujKcNQxcQ==", + "license": "MIT", + "funding": { + "url": "https://github.com/sponsors/colinhacks" + } + }, + "node_modules/zod-to-json-schema": { + "version": "3.24.6", + "resolved": "https://registry.npmjs.org/zod-to-json-schema/-/zod-to-json-schema-3.24.6.tgz", + "integrity": "sha512-h/z3PKvcTcTetyjl1fkj79MHNEjm+HpD6NXheWjzOekY7kV+lwDYnHw+ivHkijnCSMz1yJaWBD9vu/Fcmk+vEg==", + "license": "ISC", + "peerDependencies": { + "zod": "^3.24.1" + } + } + } +} diff --git a/docs/en/getting-started/quickstart/js/adk/package.json b/docs/en/getting-started/quickstart/js/adk/package.json new file mode 100644 index 0000000000..ad3e081cb9 --- /dev/null +++ b/docs/en/getting-started/quickstart/js/adk/package.json @@ -0,0 +1,17 @@ +{ + "name": "adk", + "version": "1.0.0", + "description": "", + "main": "quickstart.js", + "type": "module", + "scripts": { + "test": "node --test" + }, + "keywords": [], + "author": "", + "license": "ISC", + "dependencies": { + "@google/adk": "^0.1.3", + "@toolbox-sdk/adk": "^0.1.5" + } +} diff --git a/docs/en/getting-started/quickstart/js/adk/quickstart.js b/docs/en/getting-started/quickstart/js/adk/quickstart.js new file mode 100644 index 0000000000..3e6645f6ba --- /dev/null +++ b/docs/en/getting-started/quickstart/js/adk/quickstart.js @@ -0,0 +1,56 @@ +import { InMemoryRunner, LlmAgent, LogLevel } from '@google/adk'; +import { ToolboxClient } from '@toolbox-sdk/adk'; + +const prompt = ` +You're a helpful hotel assistant. You handle hotel searching, booking, and +cancellations. When the user searches for a hotel, mention its name, id, +location and price tier. Always mention hotel ids while performing any +searches. This is very important for any operations. For any bookings or +cancellations, please provide the appropriate confirmation. Be sure to +update checkin or checkout dates if mentioned by the user. +Don't ask for confirmations from the user. +`; + +const queries = [ + "Find hotels with Basel in its name.", + "Can you book the Hilton Basel for me?", + "Oh wait, this is too expensive. 
Please cancel it and book the Hyatt Regency instead.", + "My check in dates would be from April 10, 2024 to April 19, 2024.", +]; + +process.env.GOOGLE_GENAI_API_KEY = process.env.GOOGLE_API_KEY || 'your-api-key'; // Replace it with your API key + +export async function main() { + const userId = 'test_user'; + const client = new ToolboxClient('http://127.0.0.1:5000'); + const tools = await client.loadToolset("my-toolset"); + + const rootAgent = new LlmAgent({ + name: 'hotel_agent', + model: 'gemini-2.5-flash', + description: 'Agent for hotel bookings and administration.', + instruction: prompt, + tools: tools, + }); + + const appName = rootAgent.name; + const runner = new InMemoryRunner({ agent: rootAgent, appName, logLevel: LogLevel.ERROR, }); + const session = await runner.sessionService.createSession({ appName, userId }); + + for (const query of queries) { + await runPrompt(runner, userId, session.id, query); + } +} + +async function runPrompt(runner, userId, sessionId, prompt) { + const content = { role: 'user', parts: [{ text: prompt }] }; + const stream = runner.runAsync({ userId, sessionId, newMessage: content }); + const responses = await Array.fromAsync(stream); + const accumulatedResponse = responses + .flatMap((e) => e.content?.parts?.map((p) => p.text) ?? []) + .join(''); + + console.log(`\nMODEL RESPONSE: ${accumulatedResponse}\n`); +} + +main(); \ No newline at end of file diff --git a/docs/en/getting-started/quickstart/js/langchain/package-lock.json b/docs/en/getting-started/quickstart/js/langchain/package-lock.json index 7cece23be2..7c52d6e598 100644 --- a/docs/en/getting-started/quickstart/js/langchain/package-lock.json +++ b/docs/en/getting-started/quickstart/js/langchain/package-lock.json @@ -872,11 +872,12 @@ } }, "node_modules/jws": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/jws/-/jws-4.0.0.tgz", - "integrity": "sha512-KDncfTmOZoOMTFG4mBlG0qUIOlc03fmzH+ru6RgYVZhPkyiy/92Owlt/8UEN+a4TXR1FQetfIpJE8ApdvdVxTg==", + "version": "4.0.1", + "resolved": "https://registry.npmjs.org/jws/-/jws-4.0.1.tgz", + "integrity": "sha512-EKI/M/yqPncGUUh44xz0PxSidXFr/+r0pA70+gIYhjv+et7yxM+s29Y+VGDkovRofQem0fs7Uvf4+YmAdyRduA==", + "license": "MIT", "dependencies": { - "jwa": "^2.0.0", + "jwa": "^2.0.1", "safe-buffer": "^5.0.1" } }, diff --git a/docs/en/getting-started/quickstart/python/llamaindex/requirements.txt b/docs/en/getting-started/quickstart/python/llamaindex/requirements.txt index 82f1615a38..c065d5dad7 100644 --- a/docs/en/getting-started/quickstart/python/llamaindex/requirements.txt +++ b/docs/en/getting-started/quickstart/python/llamaindex/requirements.txt @@ -1,4 +1,4 @@ -llama-index==0.14.8 +llama-index==0.14.10 llama-index-llms-google-genai==0.7.3 toolbox-llamaindex==0.5.3 pytest==9.0.1 diff --git a/docs/en/getting-started/quickstart/shared/configure_toolbox.md b/docs/en/getting-started/quickstart/shared/configure_toolbox.md index c5f6485825..dda247e2ef 100644 --- a/docs/en/getting-started/quickstart/shared/configure_toolbox.md +++ b/docs/en/getting-started/quickstart/shared/configure_toolbox.md @@ -13,7 +13,7 @@ In this section, we will download Toolbox, configure our tools in a ```bash export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64 - curl -O https://storage.googleapis.com/genai-toolbox/v0.21.0/$OS/toolbox + curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/$OS/toolbox ``` diff --git a/docs/en/how-to/connect-ide/cloud_sql_mssql_admin_mcp.md b/docs/en/how-to/connect-ide/cloud_sql_mssql_admin_mcp.md 
index 58fdd8cf4b..03abf8422e 100644 --- a/docs/en/how-to/connect-ide/cloud_sql_mssql_admin_mcp.md +++ b/docs/en/how-to/connect-ide/cloud_sql_mssql_admin_mcp.md @@ -52,6 +52,7 @@ instance, database and users: * All `editor` and `viewer` tools * `create_instance` * `create_user` + * `clone_instance` ## Install MCP Toolbox @@ -297,6 +298,7 @@ instances and interacting with your database: * **list_databases**: Lists all databases for a Cloud SQL instance. * **create_user**: Creates a new user in a Cloud SQL instance. * **wait_for_operation**: Waits for a Cloud SQL operation to complete. +* **clone_instance**: Creates a clone of an existing Cloud SQL for SQL Server instance. {{< notice note >}} Prebuilt tools are pre-1.0, so expect some tool changes between versions. LLMs diff --git a/docs/en/how-to/connect-ide/cloud_sql_mysql_admin_mcp.md b/docs/en/how-to/connect-ide/cloud_sql_mysql_admin_mcp.md index f6287262bc..85eb041213 100644 --- a/docs/en/how-to/connect-ide/cloud_sql_mysql_admin_mcp.md +++ b/docs/en/how-to/connect-ide/cloud_sql_mysql_admin_mcp.md @@ -52,6 +52,7 @@ database and users: * All `editor` and `viewer` tools * `create_instance` * `create_user` + * `clone_instance` ## Install MCP Toolbox @@ -297,6 +298,7 @@ instances and interacting with your database: * **list_databases**: Lists all databases for a Cloud SQL instance. * **create_user**: Creates a new user in a Cloud SQL instance. * **wait_for_operation**: Waits for a Cloud SQL operation to complete. +* **clone_instance**: Creates a clone of an existing Cloud SQL for MySQL instance. {{< notice note >}} Prebuilt tools are pre-1.0, so expect some tool changes between versions. LLMs diff --git a/docs/en/how-to/connect-ide/cloud_sql_pg_admin_mcp.md b/docs/en/how-to/connect-ide/cloud_sql_pg_admin_mcp.md index 29f6e3791f..fabeb0d8be 100644 --- a/docs/en/how-to/connect-ide/cloud_sql_pg_admin_mcp.md +++ b/docs/en/how-to/connect-ide/cloud_sql_pg_admin_mcp.md @@ -52,6 +52,7 @@ instance, database and users: * All `editor` and `viewer` tools * `create_instance` * `create_user` + * `clone_instance` ## Install MCP Toolbox @@ -297,6 +298,7 @@ instances and interacting with your database: * **list_databases**: Lists all databases for a Cloud SQL instance. * **create_user**: Creates a new user in a Cloud SQL instance. * **wait_for_operation**: Waits for a Cloud SQL operation to complete. +* **clone_instance**: Creates a clone of an existing Cloud SQL for PostgreSQL instance. {{< notice note >}} Prebuilt tools are pre-1.0, so expect some tool changes between versions. LLMs diff --git a/docs/en/how-to/connect-ide/looker_mcp.md b/docs/en/how-to/connect-ide/looker_mcp.md index 64cf1f3440..82b8575d1c 100644 --- a/docs/en/how-to/connect-ide/looker_mcp.md +++ b/docs/en/how-to/connect-ide/looker_mcp.md @@ -18,6 +18,7 @@ to expose your developer assistant tools to a Looker instance: * [Cline][cline] (VS Code extension) * [Claude desktop][claudedesktop] * [Claude code][claudecode] +* [Antigravity][antigravity] [toolbox]: https://github.com/googleapis/genai-toolbox [gemini-cli]: #configure-your-mcp-client @@ -27,6 +28,7 @@ to expose your developer assistant tools to a Looker instance: [cline]: #configure-your-mcp-client [claudedesktop]: #configure-your-mcp-client [claudecode]: #configure-your-mcp-client +[antigravity]: #connect-with-antigravity ## Set up Looker @@ -38,6 +40,55 @@ to expose your developer assistant tools to a Looker instance: listening at a different port, and you will need to use `https://looker.example.com:19999` instead. 
+## Connect with Antigravity + +You can connect Looker to Antigravity in the following ways: + +* Using the MCP Store +* Using a custom configuration + +{{< notice note >}} +You don't need to download the MCP Toolbox binary to use these methods. +{{< /notice >}} + +{{< tabpane text=true >}} +{{% tab header="MCP Store" lang="en" %}} +The most straightforward way to connect to Looker in Antigravity is by using the built-in MCP Store. + +1. Open Antigravity and open the editor's agent panel. +1. Click the **"..."** icon at the top of the panel and select **MCP Servers**. +1. Locate **Looker** in the list of available servers and click Install. +1. Follow the on-screen prompts to securely link your accounts where applicable. + +After you install Looker in the MCP Store, resources and tools from the server are automatically available to the editor. + +{{% /tab %}} +{{% tab header="Custom config" lang="en" %}} + To connect to a custom MCP server, follow these steps: + +1. Open Antigravity and navigate to the MCP store using the **"..."** drop-down at the top of the editor's agent panel. +1. To open the **mcp_config.json** file, click **MCP Servers** and then click **Manage MCP Servers > View raw config**. +1. Add the following configuration, replace the environment variables with your values, and save. + + ```json + { + "mcpServers": { + "looker": { + "command": "npx", + "args": ["-y", "@toolbox-sdk/server", "--prebuilt", "looker", "--stdio"], + "env": { + "LOOKER_BASE_URL": "https://looker.example.com", + "LOOKER_CLIENT_ID": "your-client-id", + "LOOKER_CLIENT_SECRET": "your-client-secret" + } + } + } + } + ``` + +{{% /tab %}} +{{< /tabpane >}} + ## Install MCP Toolbox 1. Download the latest version of Toolbox as a binary. Select the [correct @@ -49,19 +100,19 @@ to expose your developer assistant tools to a Looker instance: {{< tabpane persist=header >}} {{< tab header="linux/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.21.0/linux/amd64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/linux/amd64/toolbox {{< /tab >}} {{< tab header="darwin/arm64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.21.0/darwin/arm64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/arm64/toolbox {{< /tab >}} {{< tab header="darwin/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.21.0/darwin/amd64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/amd64/toolbox {{< /tab >}} {{< tab header="windows/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.21.0/windows/amd64/toolbox.exe +curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/windows/amd64/toolbox.exe {{< /tab >}} {{< /tabpane >}} @@ -290,7 +341,7 @@ assistant to list models, explores, dimensions, and measures. Run a query, retrieve the SQL for a query, and run a saved Look. The full tool list is available in the [Prebuilt Tools -Reference](../../reference/prebuilt-tools/#looker). +Reference](../../reference/prebuilt-tools.md/#looker). The following tools are available to the LLM: @@ -323,6 +374,8 @@ instance and create new saved content. data 1. **make_dashboard**: Create a saved dashboard in Looker and return the URL 1. **add_dashboard_element**: Add a tile to a dashboard +1. **add_dashboard_filter**: Add a filter to a dashboard +1. 
**generate_embed_url**: Generate an embed url for content ### Looker Instance Health Tools diff --git a/docs/en/how-to/connect-ide/mssql_mcp.md b/docs/en/how-to/connect-ide/mssql_mcp.md index b965182349..defb5f0e18 100644 --- a/docs/en/how-to/connect-ide/mssql_mcp.md +++ b/docs/en/how-to/connect-ide/mssql_mcp.md @@ -45,19 +45,19 @@ instance: {{< tabpane persist=header >}} {{< tab header="linux/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.21.0/linux/amd64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/linux/amd64/toolbox {{< /tab >}} {{< tab header="darwin/arm64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.21.0/darwin/arm64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/arm64/toolbox {{< /tab >}} {{< tab header="darwin/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.21.0/darwin/amd64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/amd64/toolbox {{< /tab >}} {{< tab header="windows/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.21.0/windows/amd64/toolbox.exe +curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/windows/amd64/toolbox.exe {{< /tab >}} {{< /tabpane >}} diff --git a/docs/en/how-to/connect-ide/mysql_mcp.md b/docs/en/how-to/connect-ide/mysql_mcp.md index da609cdce7..0d8d5a1ba5 100644 --- a/docs/en/how-to/connect-ide/mysql_mcp.md +++ b/docs/en/how-to/connect-ide/mysql_mcp.md @@ -43,19 +43,19 @@ expose your developer assistant tools to a MySQL instance: {{< tabpane persist=header >}} {{< tab header="linux/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.21.0/linux/amd64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/linux/amd64/toolbox {{< /tab >}} {{< tab header="darwin/arm64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.21.0/darwin/arm64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/arm64/toolbox {{< /tab >}} {{< tab header="darwin/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.21.0/darwin/amd64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/amd64/toolbox {{< /tab >}} {{< tab header="windows/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.21.0/windows/amd64/toolbox.exe +curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/windows/amd64/toolbox.exe {{< /tab >}} {{< /tabpane >}} diff --git a/docs/en/how-to/connect-ide/neo4j_mcp.md b/docs/en/how-to/connect-ide/neo4j_mcp.md index 98a8742dda..56795aef0f 100644 --- a/docs/en/how-to/connect-ide/neo4j_mcp.md +++ b/docs/en/how-to/connect-ide/neo4j_mcp.md @@ -44,19 +44,19 @@ expose your developer assistant tools to a Neo4j instance: {{< tabpane persist=header >}} {{< tab header="linux/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.21.0/linux/amd64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/linux/amd64/toolbox {{< /tab >}} {{< tab header="darwin/arm64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.21.0/darwin/arm64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/arm64/toolbox {{< /tab >}} {{< tab header="darwin/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.21.0/darwin/amd64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/amd64/toolbox {{< 
/tab >}} {{< tab header="windows/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.21.0/windows/amd64/toolbox.exe +curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/windows/amd64/toolbox.exe {{< /tab >}} {{< /tabpane >}} diff --git a/docs/en/how-to/connect-ide/postgres_mcp.md b/docs/en/how-to/connect-ide/postgres_mcp.md index ed9ffd9490..6ec92b948e 100644 --- a/docs/en/how-to/connect-ide/postgres_mcp.md +++ b/docs/en/how-to/connect-ide/postgres_mcp.md @@ -56,19 +56,19 @@ Omni](https://cloud.google.com/alloydb/omni/current/docs/overview). {{< tabpane persist=header >}} {{< tab header="linux/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.21.0/linux/amd64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/linux/amd64/toolbox {{< /tab >}} {{< tab header="darwin/arm64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.21.0/darwin/arm64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/arm64/toolbox {{< /tab >}} {{< tab header="darwin/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.21.0/darwin/amd64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/amd64/toolbox {{< /tab >}} {{< tab header="windows/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.21.0/windows/amd64/toolbox.exe +curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/windows/amd64/toolbox.exe {{< /tab >}} {{< /tabpane >}} diff --git a/docs/en/how-to/connect-ide/sqlite_mcp.md b/docs/en/how-to/connect-ide/sqlite_mcp.md index 25ad412283..1493a71885 100644 --- a/docs/en/how-to/connect-ide/sqlite_mcp.md +++ b/docs/en/how-to/connect-ide/sqlite_mcp.md @@ -43,19 +43,19 @@ to expose your developer assistant tools to a SQLite instance: {{< tabpane persist=header >}} {{< tab header="linux/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.21.0/linux/amd64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/linux/amd64/toolbox {{< /tab >}} {{< tab header="darwin/arm64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.21.0/darwin/arm64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/arm64/toolbox {{< /tab >}} {{< tab header="darwin/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.21.0/darwin/amd64/toolbox +curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/darwin/amd64/toolbox {{< /tab >}} {{< tab header="windows/amd64" lang="bash" >}} -curl -O https://storage.googleapis.com/genai-toolbox/v0.21.0/windows/amd64/toolbox.exe +curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/windows/amd64/toolbox.exe {{< /tab >}} {{< /tabpane >}} diff --git a/docs/en/how-to/deploy_docker.md b/docs/en/how-to/deploy_docker.md index f7f6ac8827..ff12367572 100644 --- a/docs/en/how-to/deploy_docker.md +++ b/docs/en/how-to/deploy_docker.md @@ -67,6 +67,13 @@ networks: ``` + {{< notice tip >}} +To prevent DNS rebinding attack, use the `--allowed-origins` flag to specify a +list of origins permitted to access the server. E.g. `command: [ "toolbox", +"--tools-file", "/config/tools.yaml", "--address", "0.0.0.0", +"--allowed-origins", "https://foo.bar"]` +{{< /notice >}} + 1. 
Run the following command to bring up the Toolbox and Postgres instance ```bash diff --git a/docs/en/how-to/deploy_gke.md b/docs/en/how-to/deploy_gke.md index d717dff787..4c18e9bfb1 100644 --- a/docs/en/how-to/deploy_gke.md +++ b/docs/en/how-to/deploy_gke.md @@ -188,6 +188,12 @@ description: > path: tools.yaml ``` + {{< notice tip >}} +To prevent DNS rebinding attack, use the `--allowed-origins` flag to specify a +list of origins permitted to access the server. E.g. `args: ["--address", +"0.0.0.0", "--allowed-origins", "https://foo.bar"]` +{{< /notice >}} + 1. Create the deployment. ```bash diff --git a/docs/en/how-to/deploy_toolbox.md b/docs/en/how-to/deploy_toolbox.md index a5941362f6..455f6bd3ff 100644 --- a/docs/en/how-to/deploy_toolbox.md +++ b/docs/en/how-to/deploy_toolbox.md @@ -104,7 +104,7 @@ section. export IMAGE=us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:latest ``` -{{< notice note >}} + {{< notice note >}} **The `$PORT` Environment Variable** Google Cloud Run dictates the port your application must listen on by setting the `$PORT` environment variable inside your container. This value defaults to @@ -140,6 +140,45 @@ deployment will time out. # --allow-unauthenticated # https://cloud.google.com/run/docs/authenticating/public#gcloud ``` +### Update deployed server to be secure + +To prevent DNS rebinding attack, use the `--allowed-origins` flag to specify a +list of origins permitted to access the server. In order to do that, you will +have to re-deploy the cloud run service with the new flag. + +1. Set an environment variable to the cloud run url: + + ```bash + export URL= + ``` + +2. Redeploy Toolbox: + + ```bash + gcloud run deploy toolbox \ + --image $IMAGE \ + --service-account toolbox-identity \ + --region us-central1 \ + --set-secrets "/app/tools.yaml=tools:latest" \ + --args="--tools-file=/app/tools.yaml","--address=0.0.0.0","--port=8080","--allowed-origins=$URL" + # --allow-unauthenticated # https://cloud.google.com/run/docs/authenticating/public#gcloud + ``` + + If you are using a VPC network, use the command below: + + ```bash + gcloud run deploy toolbox \ + --image $IMAGE \ + --service-account toolbox-identity \ + --region us-central1 \ + --set-secrets "/app/tools.yaml=tools:latest" \ + --args="--tools-file=/app/tools.yaml","--address=0.0.0.0","--port=8080","--allowed-origins=$URL" \ + # TODO(dev): update the following to match your VPC if necessary + --network default \ + --subnet default + # --allow-unauthenticated # https://cloud.google.com/run/docs/authenticating/public#gcloud + ``` + ## Connecting with Toolbox Client SDK You can connect to Toolbox Cloud Run instances directly through the SDK. diff --git a/docs/en/how-to/export_telemetry.md b/docs/en/how-to/export_telemetry.md index 0265ce27fb..f9d8c88404 100644 --- a/docs/en/how-to/export_telemetry.md +++ b/docs/en/how-to/export_telemetry.md @@ -79,12 +79,16 @@ There are a couple of steps to run and use a Collector. ``` 1. Run toolbox with the `--telemetry-otlp` flag. Configure it to send them to - `http://127.0.0.1:4553` (for HTTP) or the Collector's URL. + `127.0.0.1:4553` (for HTTP) or the Collector's URL. ```bash - ./toolbox --telemetry-otlp=http://127.0.0.1:4553 + ./toolbox --telemetry-otlp=127.0.0.1:4553 ``` + {{< notice tip >}} + To pass an insecure endpoint, set environment variable `OTEL_EXPORTER_OTLP_INSECURE=true`. + {{< /notice >}} + 1. Once telemetry datas are collected, you can view them in your telemetry backend. 
If you are using GCP exporters, telemetry will be visible in GCP dashboard at [Metrics Explorer][metrics-explorer] and [Trace diff --git a/docs/en/reference/cli.md b/docs/en/reference/cli.md index 91c4f2b240..1c9829995e 100644 --- a/docs/en/reference/cli.md +++ b/docs/en/reference/cli.md @@ -16,15 +16,16 @@ description: > | | `--log-level` | Specify the minimum level logged. Allowed: 'DEBUG', 'INFO', 'WARN', 'ERROR'. | `info` | | | `--logging-format` | Specify logging format to use. Allowed: 'standard' or 'JSON'. | `standard` | | `-p` | `--port` | Port the server will listen on. | `5000` | -| | `--prebuilt` | Use a prebuilt tool configuration by source type. Cannot be used with --tools-file. See [Prebuilt Tools Reference](prebuilt-tools.md) for allowed values. | | +| | `--prebuilt` | Use a prebuilt tool configuration by source type. See [Prebuilt Tools Reference](prebuilt-tools.md) for allowed values. | | | | `--stdio` | Listens via MCP STDIO instead of acting as a remote HTTP server. | | | | `--telemetry-gcp` | Enable exporting directly to Google Cloud Monitoring. | | | | `--telemetry-otlp` | Enable exporting using OpenTelemetry Protocol (OTLP) to the specified endpoint (e.g. 'http://127.0.0.1:4318') | | | | `--telemetry-service-name` | Sets the value of the service.name resource attribute for telemetry data. | `toolbox` | -| | `--tools-file` | File path specifying the tool configuration. Cannot be used with --prebuilt, --tools-files, or --tools-folder. | | -| | `--tools-files` | Multiple file paths specifying tool configurations. Files will be merged. Cannot be used with --prebuilt, --tools-file, or --tools-folder. | | -| | `--tools-folder` | Directory path containing YAML tool configuration files. All .yaml and .yml files in the directory will be loaded and merged. Cannot be used with --prebuilt, --tools-file, or --tools-files. | | +| | `--tools-file` | File path specifying the tool configuration. Cannot be used with --tools-files or --tools-folder. | | +| | `--tools-files` | Multiple file paths specifying tool configurations. Files will be merged. Cannot be used with --tools-file or --tools-folder. | | +| | `--tools-folder` | Directory path containing YAML tool configuration files. All .yaml and .yml files in the directory will be loaded and merged. Cannot be used with --tools-file or --tools-files. | | | | `--ui` | Launches the Toolbox UI web server. | | +| | `--allowed-origins` | Specifies a list of origins permitted to access this server. | `*` | | `-v` | `--version` | version for toolbox | | ## Examples @@ -45,6 +46,9 @@ description: > ```bash # Basic server with custom port configuration ./toolbox --tools-file "tools.yaml" --port 8080 + +# Server with prebuilt + custom tools configurations +./toolbox --tools-file tools.yaml --prebuilt alloydb-postgres ``` ### Tool Configuration Sources @@ -71,8 +75,8 @@ The CLI supports multiple mutually exclusive ways to specify tool configurations {{< notice tip >}} The CLI enforces mutual exclusivity between configuration source flags, -preventing simultaneous use of `--prebuilt` with file-based options, and -ensuring only one of `--tools-file`, `--tools-files`, or `--tools-folder` is +preventing simultaneous use of the file-based options ensuring only one of +`--tools-file`, `--tools-files`, or `--tools-folder` is used at a time. 
{{< /notice >}} diff --git a/docs/en/reference/prebuilt-tools.md b/docs/en/reference/prebuilt-tools.md index 61313851b8..b340ac055a 100644 --- a/docs/en/reference/prebuilt-tools.md +++ b/docs/en/reference/prebuilt-tools.md @@ -13,6 +13,12 @@ allowing developers to interact with and take action on databases. See guides, [Connect from your IDE](../how-to/connect-ide/_index.md), for details on how to connect your AI tools (IDEs) to databases via Toolbox and MCP. +{{< notice tip >}} +You can now use `--prebuilt` along `--tools-file`, `--tools-files`, or +`--tools-folder` to combine prebuilt configs with custom tools. +See [Usage Examples](../reference/cli.md#examples). +{{< /notice >}} + ## AlloyDB Postgres * `--prebuilt` value: `alloydb-postgres` @@ -50,6 +56,12 @@ details on how to connect your AI tools (IDEs) to databases via Toolbox and MCP. * `list_triggers`: Lists triggers in the database. * `list_indexes`: List available user indexes in a PostgreSQL database. * `list_sequences`: List sequences in a PostgreSQL database. + * `list_publication_tables`: List publication tables in a PostgreSQL database. + * `list_tablespaces`: Lists tablespaces in the database. + * `list_pg_settings`: List configuration parameters for the PostgreSQL server. + * `list_database_stats`: Lists the key performance and activity statistics for + each database in the AlloyDB instance. + * `list_roles`: Lists all the user-created roles in PostgreSQL database. ## AlloyDB Postgres Admin @@ -178,6 +190,8 @@ details on how to connect your AI tools (IDEs) to databases via Toolbox and MCP. * All `editor` and `viewer` tools * `create_instance` * `create_user` + * `clone_instance` + * **Tools:** * `create_instance`: Creates a new Cloud SQL for MySQL instance. * `get_instance`: Gets information about a Cloud SQL instance. @@ -186,6 +200,7 @@ details on how to connect your AI tools (IDEs) to databases via Toolbox and MCP. * `list_databases`: Lists all databases for a Cloud SQL instance. * `create_user`: Creates a new user in a Cloud SQL instance. * `wait_for_operation`: Waits for a Cloud SQL operation to complete. + * `clone_instance`: Creates a clone for an existing Cloud SQL for MySQL instance. ## Cloud SQL for PostgreSQL @@ -224,6 +239,12 @@ details on how to connect your AI tools (IDEs) to databases via Toolbox and MCP. * `list_triggers`: Lists triggers in the database. * `list_indexes`: List available user indexes in a PostgreSQL database. * `list_sequences`: List sequences in a PostgreSQL database. + * `list_publication_tables`: List publication tables in a PostgreSQL database. + * `list_tablespaces`: Lists tablespaces in the database. + * `list_pg_settings`: List configuration parameters for the PostgreSQL server. + * `list_database_stats`: Lists the key performance and activity statistics for + each database in the postgreSQL instance. + * `list_roles`: Lists all the user-created roles in PostgreSQL database. ## Cloud SQL for PostgreSQL Observability @@ -257,6 +278,7 @@ details on how to connect your AI tools (IDEs) to databases via Toolbox and MCP. * All `editor` and `viewer` tools * `create_instance` * `create_user` + * `clone_instance` * **Tools:** * `create_instance`: Creates a new Cloud SQL for PostgreSQL instance. * `get_instance`: Gets information about a Cloud SQL instance. @@ -265,6 +287,7 @@ details on how to connect your AI tools (IDEs) to databases via Toolbox and MCP. * `list_databases`: Lists all databases for a Cloud SQL instance. * `create_user`: Creates a new user in a Cloud SQL instance. 
* `wait_for_operation`: Waits for a Cloud SQL operation to complete. + * `clone_instance`: Creates a clone for an existing Cloud SQL for PostgreSQL instance. ## Cloud SQL for SQL Server @@ -316,6 +339,7 @@ details on how to connect your AI tools (IDEs) to databases via Toolbox and MCP. * All `editor` and `viewer` tools * `create_instance` * `create_user` + * `clone_instance` * **Tools:** * `create_instance`: Creates a new Cloud SQL for SQL Server instance. * `get_instance`: Gets information about a Cloud SQL instance. @@ -324,6 +348,7 @@ details on how to connect your AI tools (IDEs) to databases via Toolbox and MCP. * `list_databases`: Lists all databases for a Cloud SQL instance. * `create_user`: Creates a new user in a Cloud SQL instance. * `wait_for_operation`: Waits for a Cloud SQL operation to complete. + * `clone_instance`: Creates a clone for an existing Cloud SQL for SQL Server instance. ## Dataplex @@ -397,6 +422,8 @@ details on how to connect your AI tools (IDEs) to databases via Toolbox and MCP. * `run_dashboard`: Runs the queries associated with a dashboard. * `make_dashboard`: Creates a new dashboard. * `add_dashboard_element`: Adds a tile to a dashboard. + * `add_dashboard_filter`: Adds a filter to a dashboard. + * `generate_embed_url`: Generate an embed url for content. * `health_pulse`: Test the health of a Looker instance. * `health_analyze`: Analyze the LookML usage of a Looker instance. * `health_vacuum`: Suggest LookML elements that can be removed. @@ -443,8 +470,8 @@ details on how to connect your AI tools (IDEs) to databases via Toolbox and MCP. * `--prebuilt` value: `mssql` * **Environment Variables:** - * `MSSQL_HOST`: The hostname or IP address of the SQL Server instance. - * `MSSQL_PORT`: The port number for the SQL Server instance. + * `MSSQL_HOST`: (Optional) The hostname or IP address of the SQL Server instance. + * `MSSQL_PORT`: (Optional) The port number for the SQL Server instance. * `MSSQL_DATABASE`: The name of the database to connect to. * `MSSQL_USER`: The database username. * `MSSQL_PASSWORD`: The password for the database user. @@ -525,6 +552,12 @@ details on how to connect your AI tools (IDEs) to databases via Toolbox and MCP. * `list_triggers`: Lists triggers in the database. * `list_indexes`: List available user indexes in a PostgreSQL database. * `list_sequences`: List sequences in a PostgreSQL database. + * `list_publication_tables`: List publication tables in a PostgreSQL database. + * `list_tablespaces`: Lists tablespaces in the database. + * `list_pg_settings`: List configuration parameters for the PostgreSQL server. + * `list_database_stats`: Lists the key performance and activity statistics for + each database in the PostgreSQL server. + * `list_roles`: Lists all the user-created roles in PostgreSQL database. ## Google Cloud Serverless for Apache Spark @@ -556,6 +589,7 @@ details on how to connect your AI tools (IDEs) to databases via Toolbox and MCP. * `execute_sql`: Executes a DML SQL query. * `execute_sql_dql`: Executes a DQL SQL query. * `list_tables`: Lists tables in the database. + * `list_graphs`: Lists graphs in the database. ## Spanner (PostgreSQL dialect) diff --git a/docs/en/resources/sources/alloydb-pg.md b/docs/en/resources/sources/alloydb-pg.md index 1b65dd540d..5ead2db770 100644 --- a/docs/en/resources/sources/alloydb-pg.md +++ b/docs/en/resources/sources/alloydb-pg.md @@ -77,6 +77,25 @@ cluster][alloydb-free-trial]. 
- [`postgres-get-column-cardinality`](../tools/postgres/postgres-get-column-cardinality.md) List cardinality of columns in a table in a PostgreSQL database. +- [`postgres-list-table-stats`](../tools/postgres/postgres-list-table-stats.md) + List statistics of a table in a PostgreSQL database. + +- [`postgres-list-publication-tables`](../tools/postgres/postgres-list-publication-tables.md) + List publication tables in a PostgreSQL database. + +- [`postgres-list-tablespaces`](../tools/postgres/postgres-list-tablespaces.md) + List tablespaces in an AlloyDB for PostgreSQL database. + +- [`postgres-list-pg-settings`](../tools/postgres/postgres-list-pg-settings.md) + List configuration parameters for the PostgreSQL server. + +- [`postgres-list-database-stats`](../tools/postgres/postgres-list-database-stats.md) + Lists the key performance and activity statistics for each database in the AlloyDB + instance. + +- [`postgres-list-roles`](../tools/postgres/postgres-list-roles.md) + Lists all the user-created roles in PostgreSQL database.. + ### Pre-built Configurations - [AlloyDB using MCP](https://googleapis.github.io/genai-toolbox/how-to/connect-ide/alloydb_pg_mcp/) diff --git a/docs/en/resources/sources/cloud-gda.md b/docs/en/resources/sources/cloud-gda.md new file mode 100644 index 0000000000..dc400f17e8 --- /dev/null +++ b/docs/en/resources/sources/cloud-gda.md @@ -0,0 +1,40 @@ +--- +title: "Gemini Data Analytics" +type: docs +weight: 1 +description: > + A "cloud-gemini-data-analytics" source provides a client for the Gemini Data Analytics API. +aliases: + - /resources/sources/cloud-gemini-data-analytics +--- + +## About + +The `cloud-gemini-data-analytics` source provides a client to interact with the [Gemini Data Analytics API](https://docs.cloud.google.com/gemini/docs/conversational-analytics-api/reference/rest). This allows tools to send natural language queries to the API. + +Authentication can be handled in two ways: + +1. **Application Default Credentials (ADC) (Recommended):** By default, the source uses ADC to authenticate with the API. The Toolbox server will fetch the credentials from its running environment (server-side authentication). This is the recommended method. +2. **Client-side OAuth:** If `useClientOAuth` is set to `true`, the source expects the authentication token to be provided by the caller when making a request to the Toolbox server (typically via an HTTP Bearer token). The Toolbox server will then forward this token to the underlying Gemini Data Analytics API calls. + +## Example + +```yaml +sources: + my-gda-source: + kind: cloud-gemini-data-analytics + projectId: my-project-id + + my-oauth-gda-source: + kind: cloud-gemini-data-analytics + projectId: my-project-id + useClientOAuth: true +``` + +## Reference + +| **field** | **type** | **required** | **description** | +| -------------- | :------: | :----------: | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| kind | string | true | Must be "cloud-gemini-data-analytics". | +| projectId | string | true | The Google Cloud Project ID where the API is enabled. | +| useClientOAuth | boolean | false | If true, the source uses the token provided by the caller (forwarded to the API). Otherwise, it uses server-side Application Default Credentials (ADC). Defaults to `false`. 
| diff --git a/docs/en/resources/sources/cloud-sql-mysql.md b/docs/en/resources/sources/cloud-sql-mysql.md index 93b2f71b41..e9f89f22a9 100644 --- a/docs/en/resources/sources/cloud-sql-mysql.md +++ b/docs/en/resources/sources/cloud-sql-mysql.md @@ -31,6 +31,9 @@ to a database by following these instructions][csql-mysql-quickstart]. - [`mysql-list-active-queries`](../tools/mysql/mysql-list-active-queries.md) List active queries in Cloud SQL for MySQL. +- [`mysql-get-query-plan`](../tools/mysql/mysql-get-query-plan.md) + Provide information about how MySQL executes a SQL statement (EXPLAIN). + - [`mysql-list-tables`](../tools/mysql/mysql-list-tables.md) List tables in a Cloud SQL for MySQL database. @@ -88,13 +91,40 @@ mTLS. [public-ip]: https://cloud.google.com/sql/docs/mysql/configure-ip [conn-overview]: https://cloud.google.com/sql/docs/mysql/connect-overview -### Database User +### Authentication -Currently, this source only uses standard authentication. You will need to [create -a MySQL user][cloud-sql-users] to login to the database with. +This source supports both password-based authentication and IAM +authentication (using your [Application Default Credentials][adc]). + +#### Standard Authentication + +To connect using user/password, [create +a MySQL user][cloud-sql-users] and input your credentials in the `user` and +`password` fields. + +```yaml +user: ${USER_NAME} +password: ${PASSWORD} +``` [cloud-sql-users]: https://cloud.google.com/sql/docs/mysql/create-manage-users +#### IAM Authentication + +To connect using IAM authentication: + +1. Prepare your database instance and user following this [guide][iam-guide]. +2. You could choose one of the two ways to log in: + - Specify your IAM email as the `user`. + - Leave your `user` field blank. Toolbox will fetch the [ADC][adc] + automatically and log in using the email associated with it. + +3. Leave the `password` field blank. + +[iam-guide]: https://cloud.google.com/sql/docs/mysql/iam-logins +[cloudsql-users]: https://cloud.google.com/sql/docs/mysql/create-manage-users + + ## Example ```yaml @@ -124,6 +154,6 @@ instead of hardcoding your secrets into the configuration file. | region | string | true | Name of the GCP region that the cluster was created in (e.g. "us-central1"). | | instance | string | true | Name of the Cloud SQL instance within the cluster (e.g. "my-instance"). | | database | string | true | Name of the MySQL database to connect to (e.g. "my_db"). | -| user | string | true | Name of the MySQL user to connect as (e.g. "my-pg-user"). | -| password | string | true | Password of the MySQL user (e.g. "my-password"). | +| user | string | false | Name of the MySQL user to connect as (e.g "my-mysql-user"). Defaults to IAM auth using [ADC][adc] email if unspecified. | +| password | string | false | Password of the MySQL user (e.g. "my-password"). Defaults to attempting IAM authentication if unspecified. | | ipType | string | false | IP Type of the Cloud SQL instance, must be either `public`, `private`, or `psc`. Default: `public`. | diff --git a/docs/en/resources/sources/cloud-sql-pg.md b/docs/en/resources/sources/cloud-sql-pg.md index bd18ae942b..880f996942 100644 --- a/docs/en/resources/sources/cloud-sql-pg.md +++ b/docs/en/resources/sources/cloud-sql-pg.md @@ -58,6 +58,7 @@ to a database by following these instructions][csql-pg-quickstart]. - [`postgres-list-sequences`](../tools/postgres/postgres-list-sequences.md) List sequences in a PostgreSQL database. 
+ - [`postgres-long-running-transactions`](../tools/postgres/postgres-long-running-transactions.md) List long running transactions in a PostgreSQL database. @@ -73,6 +74,25 @@ to a database by following these instructions][csql-pg-quickstart]. - [`postgres-get-column-cardinality`](../tools/postgres/postgres-get-column-cardinality.md) List cardinality of columns in a table in a PostgreSQL database. +- [`postgres-list-table-stats`](../tools/postgres/postgres-list-table-stats.md) + List statistics of a table in a PostgreSQL database. + +- [`postgres-list-publication-tables`](../tools/postgres/postgres-list-publication-tables.md) + List publication tables in a PostgreSQL database. + +- [`postgres-list-tablespaces`](../tools/postgres/postgres-list-tablespaces.md) + List tablespaces in a PostgreSQL database. + +- [`postgres-list-pg-settings`](../tools/postgres/postgres-list-pg-settings.md) + List configuration parameters for the PostgreSQL server. + +- [`postgres-list-database-stats`](../tools/postgres/postgres-list-database-stats.md) + Lists the key performance and activity statistics for each database in the postgreSQL + instance. + +- [`postgres-list-roles`](../tools/postgres/postgres-list-roles.md) + Lists all the user-created roles in PostgreSQL database.. + ### Pre-built Configurations - [Cloud SQL for Postgres using diff --git a/docs/en/resources/sources/looker.md b/docs/en/resources/sources/looker.md index b6400d906a..75bebf37ea 100644 --- a/docs/en/resources/sources/looker.md +++ b/docs/en/resources/sources/looker.md @@ -91,18 +91,17 @@ instead of hardcoding your secrets into the configuration file. ## Reference -| **field** | **type** | **required** | **description** | -|----------------------|:--------:|:------------:|-------------------------------------------------------------------------------------------| -| kind | string | true | Must be "looker". | -| base_url | string | true | The URL of your Looker server with no trailing /. | -| client_id | string | false | The client id assigned by Looker. | -| client_secret | string | false | The client secret assigned by Looker. | -| verify_ssl | string | false | Whether to check the ssl certificate of the server. | -| project | string | false | The project id to use in Google Cloud. | -| location | string | false | The location to use in Google Cloud. (default: us) | -| timeout | string | false | Maximum time to wait for query execution (e.g. "30s", "2m"). By default, 120s is applied. | -| use_client_oauth | string | false | Use OAuth tokens instead of client_id and client_secret. (default: false) If a header | -| | | | name is provided, it will be used instead of "Authorization". | -| show_hidden_models | string | false | Show or hide hidden models. (default: true) | -| show_hidden_explores | string | false | Show or hide hidden explores. (default: true) | -| show_hidden_fields | string | false | Show or hide hidden fields. (default: true) | \ No newline at end of file +| **field** | **type** | **required** | **description** | +|----------------------|:--------:|:------------:|-----------------------------------------------------------------------------------------------------------------------------------------------------| +| kind | string | true | Must be "looker". | +| base_url | string | true | The URL of your Looker server with no trailing /. | +| client_id | string | false | The client id assigned by Looker. | +| client_secret | string | false | The client secret assigned by Looker. 
| +| verify_ssl | string | false | Whether to check the ssl certificate of the server. | +| project | string | false | The project id to use in Google Cloud. | +| location | string | false | The location to use in Google Cloud. (default: us) | +| timeout | string | false | Maximum time to wait for query execution (e.g. "30s", "2m"). By default, 120s is applied. | +| use_client_oauth | string | false | Use OAuth tokens instead of client_id and client_secret. (default: false) If a header name is provided, it will be used instead of "Authorization". | +| show_hidden_models | string | false | Show or hide hidden models. (default: true) | +| show_hidden_explores | string | false | Show or hide hidden explores. (default: true) | +| show_hidden_fields | string | false | Show or hide hidden fields. (default: true) | \ No newline at end of file diff --git a/docs/en/resources/sources/mariadb.md b/docs/en/resources/sources/mariadb.md new file mode 100644 index 0000000000..c956274fde --- /dev/null +++ b/docs/en/resources/sources/mariadb.md @@ -0,0 +1,78 @@ +--- +title: "MariaDB" +type: docs +weight: 1 +description: > + MariaDB is an open-source relational database compatible with MySQL. + +--- +## About + +MariaDB is a relational database management system derived from MySQL. It +implements the MySQL protocol and client libraries and supports modern SQL +features with a focus on performance and reliability. + +**Note**: MariaDB is supported using the MySQL source. +## Available Tools + +- [`mysql-sql`](../tools/mysql/mysql-sql.md) + Execute pre-defined prepared SQL queries in MariaDB. + +- [`mysql-execute-sql`](../tools/mysql/mysql-execute-sql.md) + Run parameterized SQL queries in MariaDB. + +- [`mysql-list-active-queries`](../tools/mysql/mysql-list-active-queries.md) + List active queries in MariaDB. + +- [`mysql-list-tables`](../tools/mysql/mysql-list-tables.md) + List tables in a MariaDB database. + +- [`mysql-list-tables-missing-unique-indexes`](../tools/mysql/mysql-list-tables-missing-unique-indexes.md) + List tables in a MariaDB database that do not have primary or unique indices. + +- [`mysql-list-table-fragmentation`](../tools/mysql/mysql-list-table-fragmentation.md) + List table fragmentation in MariaDB tables. + +## Requirements + +### Database User + +This source only uses standard authentication. You will need to [create a +MariaDB user][mariadb-users] to log in to the database. + +[mariadb-users]: https://mariadb.com/kb/en/create-user/ + +## Example + +```yaml +sources: + my_mariadb_db: + kind: mysql + host: 127.0.0.1 + port: 3306 + database: my_db + user: ${MARIADB_USER} + password: ${MARIADB_PASS} + # Optional TLS and other driver parameters. For example, enable preferred TLS: + # queryParams: + # tls: preferred + queryTimeout: 30s # Optional: query timeout duration +``` + +{{< notice tip >}} +Use environment variables instead of committing credentials to source files. +{{< /notice >}} + + +## Reference + +| **field** | **type** | **required** | **description** | +| ------------ | :------: | :----------: | ----------------------------------------------------------------------------------------------- | +| kind | string | true | Must be `mysql`. | +| host | string | true | IP address to connect to (e.g. "127.0.0.1"). | +| port | string | true | Port to connect to (e.g. "3307"). | +| database | string | true | Name of the MariaDB database to connect to (e.g. "my_db"). | +| user | string | true | Name of the MariaDB user to connect as (e.g. "my-mysql-user"). 
| +| password | string | true | Password of the MariaDB user (e.g. "my-password"). | +| queryTimeout | string | false | Maximum time to wait for query execution (e.g. "30s", "2m"). By default, no timeout is applied. | +| queryParams | map | false | Arbitrary DSN parameters passed to the driver (e.g. `tls: preferred`, `charset: utf8mb4`). Useful for enabling TLS or other connection options. | diff --git a/docs/en/resources/sources/mysql.md b/docs/en/resources/sources/mysql.md index 44d46195ac..95f2b96d7c 100644 --- a/docs/en/resources/sources/mysql.md +++ b/docs/en/resources/sources/mysql.md @@ -25,6 +25,9 @@ reliability, performance, and ease of use. - [`mysql-list-active-queries`](../tools/mysql/mysql-list-active-queries.md) List active queries in MySQL. +- [`mysql-get-query-plan`](../tools/mysql/mysql-get-query-plan.md) + Provide information about how MySQL executes a SQL statement (EXPLAIN). + - [`mysql-list-tables`](../tools/mysql/mysql-list-tables.md) List tables in a MySQL database. diff --git a/docs/en/resources/sources/oracle.md b/docs/en/resources/sources/oracle.md index 4932ea6e22..51fa18fe13 100644 --- a/docs/en/resources/sources/oracle.md +++ b/docs/en/resources/sources/oracle.md @@ -18,10 +18,10 @@ DW) database workloads. ## Available Tools - [`oracle-sql`](../tools/oracle/oracle-sql.md) - Execute pre-defined prepared SQL queries in Oracle. + Execute pre-defined prepared SQL queries in Oracle. - [`oracle-execute-sql`](../tools/oracle/oracle-execute-sql.md) - Run parameterized SQL queries in Oracle. + Run parameterized SQL queries in Oracle. ## Requirements @@ -33,6 +33,25 @@ user][oracle-users] to log in to the database with the necessary permissions. [oracle-users]: https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/CREATE-USER.html +### Oracle Driver Requirement (Conditional) + +The Oracle source offers two connection drivers: + +1. **Pure Go Driver (`useOCI: false`, default):** Uses the `go-ora` library. + This driver is simpler and does not require any local Oracle software + installation, but it **lacks support for advanced features** like Oracle + Wallets or Kerberos authentication. + +2. **OCI-Based Driver (`useOCI: true`):** Uses the `godror` library, which + provides access to **advanced Oracle features** like Digital Wallet support. + +If you set `useOCI: true`, you **must** install the **Oracle Instant Client** +libraries on the machine where this tool runs. + +You can download the Instant Client from the official Oracle website: [Oracle +Instant Client +Downloads](https://www.oracle.com/database/technologies/instant-client/downloads.html) + ## Connection Methods You can configure the connection to your Oracle database using one of the @@ -66,12 +85,15 @@ using a TNS (Transparent Network Substrate) alias. containing it. This setting will override the `TNS_ADMIN` environment variable. -## Example +## Examples + +This example demonstrates the four connection methods you could choose from: ```yaml sources: my-oracle-source: kind: oracle + # --- Choose one connection method --- # 1. Host, Port, and Service Name host: 127.0.0.1 @@ -88,6 +110,43 @@ sources: user: ${USER_NAME} password: ${PASSWORD} + # Optional: Set to true to use the OCI-based driver for advanced features (Requires Oracle Instant Client) +``` + +### Using an Oracle Wallet + +Oracle Wallet allows you to store credentails used for database connection. Depending whether you are using an OCI-based driver, the wallet configuration is different. 
+ +#### Pure Go Driver (`useOCI: false`) - Oracle Wallet + +The `go-ora` driver uses the `walletLocation` field to connect to a database secured with an Oracle Wallet without standard username and password. + +```yaml +sources: + pure-go-wallet: + kind: oracle + connectionString: "127.0.0.1:1521/XEPDB1" + user: ${USER_NAME} + password: ${PASSWORD} + # The TNS Alias is often required to connect to a service registered in tnsnames.ora + tnsAlias: "SECURE_DB_ALIAS" + walletLocation: "/path/to/my/wallet/directory" +``` + +#### OCI-Based Driver (`useOCI: true`) - Oracle Wallet + +For the OCI-based driver, wallet authentication is triggered by setting tnsAdmin to the wallet directory and connecting via a tnsAlias. + +```yaml +sources: + oci-wallet: + kind: oracle + connectionString: "127.0.0.1:1521/XEPDB1" + user: ${USER_NAME} + password: ${PASSWORD} + tnsAlias: "WALLET_DB_ALIAS" + tnsAdmin: "/opt/oracle/wallet" # Directory containing tnsnames.ora, sqlnet.ora, and wallet files + useOCI: true ``` {{< notice tip >}} @@ -97,14 +156,15 @@ instead of hardcoding your secrets into the configuration file. ## Reference -| **field** | **type** | **required** | **description** | -|------------------|:--------:|:------------:|-----------------------------------------------------------------------------------------------------------------------------| -| kind | string | true | Must be "oracle". | -| user | string | true | Name of the Oracle user to connect as (e.g. "my-oracle-user"). | -| password | string | true | Password of the Oracle user (e.g. "my-password"). | -| host | string | false | IP address or hostname to connect to (e.g. "127.0.0.1"). Required if not using `connectionString` or `tnsAlias`. | -| port | integer | false | Port to connect to (e.g. "1521"). Required if not using `connectionString` or `tnsAlias`. | -| serviceName | string | false | The Oracle service name of the database to connect to. Required if not using `connectionString` or `tnsAlias`. | -| connectionString | string | false | A direct connection string (e.g. "hostname:port/servicename"). Use as an alternative to `host`, `port`, and `serviceName`. | -| tnsAlias | string | false | A TNS alias from a `tnsnames.ora` file. Use as an alternative to `host`/`port` or `connectionString`. | -| tnsAdmin | string | false | Path to the directory containing the `tnsnames.ora` file. This overrides the `TNS_ADMIN` environment variable if it is set. | +| **field** | **type** | **required** | **description** | +|------------------|:--------:|:------------:|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| kind | string | true | Must be "oracle". | +| user | string | true | Name of the Oracle user to connect as (e.g. "my-oracle-user"). | +| password | string | true | Password of the Oracle user (e.g. "my-password"). | +| host | string | false | IP address or hostname to connect to (e.g. "127.0.0.1"). Required if not using `connectionString` or `tnsAlias`. | +| port | integer | false | Port to connect to (e.g. "1521"). Required if not using `connectionString` or `tnsAlias`. | +| serviceName | string | false | The Oracle service name of the database to connect to. Required if not using `connectionString` or `tnsAlias`. | +| connectionString | string | false | A direct connection string (e.g. "hostname:port/servicename"). Use as an alternative to `host`, `port`, and `serviceName`. 
| +| tnsAlias | string | false | A TNS alias from a `tnsnames.ora` file. Use as an alternative to `host`/`port` or `connectionString`. | +| tnsAdmin | string | false | Path to the directory containing the `tnsnames.ora` file. This overrides the `TNS_ADMIN` environment variable if it is set. | +| useOCI | bool | false | If true, uses the OCI-based driver (godror) which supports Oracle Wallet/Kerberos but requires the Oracle Instant Client libraries to be installed. Defaults to false (pure Go driver). | diff --git a/docs/en/resources/sources/postgres.md b/docs/en/resources/sources/postgres.md index 8422eadd81..de19be506a 100644 --- a/docs/en/resources/sources/postgres.md +++ b/docs/en/resources/sources/postgres.md @@ -68,6 +68,25 @@ reputation for reliability, feature robustness, and performance. - [`postgres-get-column-cardinality`](../tools/postgres/postgres-get-column-cardinality.md) List cardinality of columns in a table in a PostgreSQL database. +- [`postgres-list-table-stats`](../tools/postgres/postgres-list-table-stats.md) + List statistics of a table in a PostgreSQL database. + +- [`postgres-list-publication-tables`](../tools/postgres/postgres-list-publication-tables.md) + List publication tables in a PostgreSQL database. + +- [`postgres-list-tablespaces`](../tools/postgres/postgres-list-tablespaces.md) + List tablespaces in a PostgreSQL database. + +- [`postgres-list-pg-settings`](../tools/postgres/postgres-list-pg-settings.md) + List configuration parameters for the PostgreSQL server. + +- [`postgres-list-database-stats`](../tools/postgres/postgres-list-database-stats.md) + Lists the key performance and activity statistics for each database in the postgreSQL + server. + +- [`postgres-list-roles`](../tools/postgres/postgres-list-roles.md) + Lists all the user-created roles in PostgreSQL database.. + ### Pre-built Configurations - [PostgreSQL using MCP](https://googleapis.github.io/genai-toolbox/how-to/connect-ide/postgres_mcp/) diff --git a/docs/en/resources/sources/serverless-spark.md b/docs/en/resources/sources/serverless-spark.md index 0d137d36b7..1f2afc3cec 100644 --- a/docs/en/resources/sources/serverless-spark.md +++ b/docs/en/resources/sources/serverless-spark.md @@ -21,6 +21,10 @@ Apache Spark. Get a Serverless Spark batch. - [`serverless-spark-cancel-batch`](../tools/serverless-spark/serverless-spark-cancel-batch.md) Cancel a running Serverless Spark batch operation. +- [`serverless-spark-create-pyspark-batch`](../tools/serverless-spark/serverless-spark-create-pyspark-batch.md) + Create a Serverless Spark PySpark batch operation. +- [`serverless-spark-create-spark-batch`](../tools/serverless-spark/serverless-spark-create-spark-batch.md) + Create a Serverless Spark Java batch operation. ## Requirements diff --git a/docs/en/resources/tools/cloudgda/_index.md b/docs/en/resources/tools/cloudgda/_index.md new file mode 100644 index 0000000000..63e1189632 --- /dev/null +++ b/docs/en/resources/tools/cloudgda/_index.md @@ -0,0 +1,7 @@ +--- +title: "Gemini Data Analytics" +type: docs +weight: 1 +description: > + Tools for Gemini Data Analytics. 
+--- diff --git a/docs/en/resources/tools/cloudgda/cloud-gda-query.md b/docs/en/resources/tools/cloudgda/cloud-gda-query.md new file mode 100644 index 0000000000..faf119d6e6 --- /dev/null +++ b/docs/en/resources/tools/cloudgda/cloud-gda-query.md @@ -0,0 +1,92 @@ +--- +title: "Gemini Data Analytics QueryData" +type: docs +weight: 1 +description: > + A tool to convert natural language queries into SQL statements using the Gemini Data Analytics QueryData API. +aliases: + - /resources/tools/cloud-gemini-data-analytics-query +--- + +## About + +The `cloud-gemini-data-analytics-query` tool allows you to send natural language questions to the Gemini Data Analytics API and receive structured responses containing SQL queries, natural language answers, and explanations. For details on defining data agent context for database data sources, see the official [documentation](https://docs.cloud.google.com/gemini/docs/conversational-analytics-api/data-agent-authored-context-databases). + +## Example + +```yaml +tools: + my-gda-query-tool: + kind: cloud-gemini-data-analytics-query + source: my-gda-source + description: "Use this tool to send natural language queries to the Gemini Data Analytics API and receive SQL, natural language answers, and explanations." + location: ${your_database_location} + context: + datasourceReferences: + cloudSqlReference: + databaseReference: + projectId: "${your_project_id}" + region: "${your_database_instance_region}" + instanceId: "${your_database_instance_id}" + databaseId: "${your_database_name}" + engine: "POSTGRESQL" + agentContextReference: + contextSetId: "${your_context_set_id}" # E.g. projects/${project_id}/locations/${context_set_location}/contextSets/${context_set_id} + generationOptions: + generateQueryResult: true + generateNaturalLanguageAnswer: true + generateExplanation: true + generateDisambiguationQuestion: true +``` + +### Usage Flow + +When using this tool, a `prompt` parameter containing a natural language query is provided to the tool (typically by an agent). The tool then interacts with the Gemini Data Analytics API using the context defined in your configuration. + +The structure of the response depends on the `generationOptions` configured in your tool definition (e.g., enabling `generateQueryResult` will include the SQL query results). + +See [Data Analytics API REST documentation](https://clouddocs.devsite.corp.google.com/gemini/docs/conversational-analytics-api/reference/rest/v1alpha/projects.locations/queryData?rep_location=global) for details. + +**Example Input Prompt:** + +```text +How many accounts who have region in Prague are eligible for loans? A3 contains the data of region. +``` + +**Example API Response:** + +```json +{ + "generatedQuery": "SELECT COUNT(T1.account_id) FROM account AS T1 INNER JOIN loan AS T2 ON T1.account_id = T2.account_id INNER JOIN district AS T3 ON T1.district_id = T3.district_id WHERE T3.A3 = 'Prague'", + "intentExplanation": "I found a template that matches the user's question. The template asks about the number of accounts who have region in a given city and are eligible for loans. The question asks about the number of accounts who have region in Prague and are eligible for loans. The template's parameterized SQL is 'SELECT COUNT(T1.account_id) FROM account AS T1 INNER JOIN loan AS T2 ON T1.account_id = T2.account_id INNER JOIN district AS T3 ON T1.district_id = T3.district_id WHERE T3.A3 = ?'. I will replace the named parameter '?' 
with 'Prague'.", + "naturalLanguageAnswer": "There are 84 accounts from the Prague region that are eligible for loans.", + "queryResult": { + "columns": [ + { + "type": "INT64" + } + ], + "rows": [ + { + "values": [ + { + "value": "84" + } + ] + } + ], + "totalRowCount": "1" + } +} +``` + +## Reference + +| **field** | **type** | **required** | **description** | +| ----------------- | :------: | :----------: | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| kind | string | true | Must be "cloud-gemini-data-analytics-query". | +| source | string | true | The name of the `cloud-gemini-data-analytics` source to use. | +| description | string | true | A description of the tool's purpose. | +| location | string | true | The Google Cloud location of the target database resource (e.g., "us-central1"). This is used to construct the parent resource name in the API call. | +| context | object | true | The context for the query, including datasource references. See [QueryDataContext](https://github.com/googleapis/googleapis/blob/b32495a713a68dd0dff90cf0b24021debfca048a/google/cloud/geminidataanalytics/v1beta/data_chat_service.proto#L156) for details. | +| generationOptions | object | false | Options for generating the response. See [GenerationOptions](https://github.com/googleapis/googleapis/blob/b32495a713a68dd0dff90cf0b24021debfca048a/google/cloud/geminidataanalytics/v1beta/data_chat_service.proto#L135) for details. | diff --git a/docs/en/resources/tools/cloudsql/cloudsqlcloneinstance.md b/docs/en/resources/tools/cloudsql/cloudsqlcloneinstance.md new file mode 100644 index 0000000000..89fdb8d986 --- /dev/null +++ b/docs/en/resources/tools/cloudsql/cloudsqlcloneinstance.md @@ -0,0 +1,68 @@ +--- +title: cloud-sql-clone-instance +type: docs +weight: 10 +description: "Clone a Cloud SQL instance." +--- + +The `cloud-sql-clone-instance` tool clones a Cloud SQL instance using the Cloud SQL Admin API. + +{{< notice info dd>}} +This tool uses a `source` of kind `cloud-sql-admin`. +{{< /notice >}} + +## Examples + +Basic clone (current state) + +```yaml +tools: + clone-instance-basic: + kind: cloud-sql-clone-instance + source: cloud-sql-admin-source + description: "Creates an exact copy of a Cloud SQL instance. Supports configuring instance zones and high-availability setup through zone preferences." +``` + +Point-in-time recovery (PITR) clone + +```yaml +tools: + clone-instance-pitr: + kind: cloud-sql-clone-instance + source: cloud-sql-admin-source + description: "Creates an exact copy of a Cloud SQL instance at a specific point in time (PITR). Supports configuring instance zones and high-availability setup through zone preferences" +``` + +## Reference + +### Tool Configuration + +| **field** | **type** | **required** | **description** | +| -------------- | :------: | :----------: | ------------------------------------------------------------- | +| kind | string | true | Must be "cloud-sql-clone-instance". | +| source | string | true | The name of the `cloud-sql-admin` source to use. | +| description | string | false | A description of the tool. 
| + +### Tool Inputs + +| **parameter** | **type** | **required** | **description** | +| -------------------------- | :------: | :----------: | ------------------------------------------------------------------------------- | +| project | string | true | The project ID. | +| sourceInstanceName | string | true | The name of the source instance to clone. | +| destinationInstanceName | string | true | The name of the new (cloned) instance. | +| pointInTime | string | false | (Optional) The point in time for a PITR (Point-In-Time Recovery) clone. | +| preferredZone | string | false | (Optional) The preferred zone for the cloned instance. If not specified, defaults to the source instance's zone. | +| preferredSecondaryZone | string | false | (Optional) The preferred secondary zone for the cloned instance (for HA). | + +## Usage Notes + +- The tool supports both basic clone and point-in-time recovery (PITR) clone operations. +- For PITR, specify the `pointInTime` parameter in RFC3339 format (e.g., `2024-01-01T00:00:00Z`). +- The source must be a valid Cloud SQL Admin API source. +- You can optionally specify the `zone` parameter to set the zone for the cloned instance. If omitted, the zone of the source instance will be used. +- You can optionally specify the `preferredZone` and `preferredSecondaryZone` (only in REGIONAL instances) to set the preferred zones for the cloned instance. These are useful for high availability (HA) configurations. If omitted, defaults will be used based on the source instance. + +## See Also +- [Cloud SQL Admin API documentation](https://cloud.google.com/sql/docs/mysql/admin-api) +- [Toolbox Cloud SQL tools documentation](../cloudsql) +- [Cloud SQL Clone API documentation](https://cloud.google.com/sql/docs/mysql/clone-instance) \ No newline at end of file diff --git a/docs/en/resources/tools/looker/looker-add-dashboard-element.md b/docs/en/resources/tools/looker/looker-add-dashboard-element.md index 64f7cf39de..3c0a65f2d3 100644 --- a/docs/en/resources/tools/looker/looker-add-dashboard-element.md +++ b/docs/en/resources/tools/looker/looker-add-dashboard-element.md @@ -10,27 +10,18 @@ aliases: ## About -The `looker-add-dashboard-element` creates a dashboard element -in the given dashboard. +The `looker-add-dashboard-element` tool creates a new tile (element) within an existing Looker dashboard. +Tiles are added in the order this tool is called for a given `dashboard_id`. + +CRITICAL ORDER OF OPERATIONS: +1. Create the dashboard using `make_dashboard`. +2. Add any dashboard-level filters using `add_dashboard_filter`. +3. Then, add elements (tiles) using this tool. It's compatible with the following sources: - [looker](../../sources/looker.md) -`looker-add-dashboard-element` takes eleven parameters: - -1. the `model` -2. the `explore` -3. the `fields` list -4. an optional set of `filters` -5. an optional set of `pivots` -6. an optional set of `sorts` -7. an optional `limit` -8. an optional `tz` -9. an optional `vis_config` -10. the `title` -11. the `dashboard_id` - ## Example ```yaml @@ -39,24 +30,37 @@ tools: kind: looker-add-dashboard-element source: looker-source description: | - add_dashboard_element Tool + This tool creates a new tile (element) within an existing Looker dashboard. + Tiles are added in the order this tool is called for a given `dashboard_id`. - This tool creates a new tile in a Looker dashboard using - the query parameters and the vis_config specified. + CRITICAL ORDER OF OPERATIONS: + 1. Create the dashboard using `make_dashboard`. + 2. 
Add any dashboard-level filters using `add_dashboard_filter`. + 3. Then, add elements (tiles) using this tool. - Most of the parameters are the same as the query_url - tool. In addition, there is a title that may be provided. - The dashboard_id must be specified. That is obtained - from calling make_dashboard. + Required Parameters: + - dashboard_id: The ID of the target dashboard, obtained from `make_dashboard`. + - model_name, explore_name, fields: These query parameters are inherited + from the `query` tool and are required to define the data for the tile. - This tool can be called many times for one dashboard_id - and the resulting tiles will be added in order. + Optional Parameters: + - title: An optional title for the dashboard tile. + - pivots, filters, sorts, limit, query_timezone: These query parameters are + inherited from the `query` tool and can be used to customize the tile's query. + - vis_config: A JSON object defining the visualization settings for this tile. + The structure and options are the same as for the `query_url` tool's `vis_config`. + + Connecting to Dashboard Filters: + A dashboard element can be connected to one or more dashboard filters (created with + `add_dashboard_filter`). To do this, specify the `name` of the dashboard filter + and the `field` from the element's query that the filter should apply to. + The format for specifying the field is `view_name.field_name`. ``` ## Reference | **field** | **type** | **required** | **description** | -|-------------|:--------:|:------------:|----------------------------------------------------| -| kind | string | true | Must be "looker-add-dashboard-element" | -| source | string | true | Name of the source the SQL should execute on. | -| description | string | true | Description of the tool that is passed to the LLM. | +|:------------|:--------:|:------------:|----------------------------------------------------| +| kind | string | true | Must be "looker-add-dashboard-element". | +| source | string | true | Name of the source the SQL should execute on. | +| description | string | true | Description of the tool that is passed to the LLM. | \ No newline at end of file diff --git a/docs/en/resources/tools/looker/looker-add-dashboard-filter.md b/docs/en/resources/tools/looker/looker-add-dashboard-filter.md new file mode 100644 index 0000000000..e5cf5ba34d --- /dev/null +++ b/docs/en/resources/tools/looker/looker-add-dashboard-filter.md @@ -0,0 +1,75 @@ +--- +title: "looker-add-dashboard-filter" +type: docs +weight: 1 +description: > + The "looker-add-dashboard-filter" tool adds a filter to a specified dashboard. +aliases: +- /resources/tools/looker-add-dashboard-filter +--- + +## About + +The `looker-add-dashboard-filter` tool adds a filter to a specified Looker dashboard. + +CRITICAL ORDER OF OPERATIONS: +1. Create a dashboard using `make_dashboard`. +2. Add all desired filters using this tool (`add_dashboard_filter`). +3. Finally, add dashboard elements (tiles) using `add_dashboard_element`. + +It's compatible with the following sources: + +- [looker](../../sources/looker.md) + +## Parameters + +| **parameter** | **type** | **required** | **default** | **description** | +|:----------------------|:--------:|:-----------------:|:--------------:|-------------------------------------------------------------------------------------------------------------------------------| +| dashboard_id | string | true | none | The ID of the dashboard to add the filter to, obtained from `make_dashboard`. 
| +| name | string | true | none | A unique internal identifier for the filter. This name is used later in `add_dashboard_element` to bind tiles to this filter. | +| title | string | true | none | The label displayed to users in the Looker UI. | +| filter_type | string | true | `field_filter` | The type of filter. Can be `date_filter`, `number_filter`, `string_filter`, or `field_filter`. | +| default_value | string | false | none | The initial value for the filter. | +| model | string | if `field_filter` | none | The name of the LookML model, obtained from `get_models`. | +| explore | string | if `field_filter` | none | The name of the explore within the model, obtained from `get_explores`. | +| dimension | string | if `field_filter` | none | The name of the field (e.g., `view_name.field_name`) to base the filter on, obtained from `get_dimensions`. | +| allow_multiple_values | boolean | false | true | Whether the dashboard filter should allow multiple values. | +| required | boolean | false | false | Whether the dashboard filter must be set in order to run the dashboard. | + +## Example + +```yaml +tools: + add_dashboard_filter: + kind: looker-add-dashboard-filter + source: looker-source + description: | + This tool adds a filter to a Looker dashboard. + + CRITICAL ORDER OF OPERATIONS: + 1. Create a dashboard using `make_dashboard`. + 2. Add all desired filters using this tool (`add_dashboard_filter`). + 3. Finally, add dashboard elements (tiles) using `add_dashboard_element`. + + Parameters: + - dashboard_id (required): The ID from `make_dashboard`. + - name (required): A unique internal identifier for the filter. You will use this `name` later in `add_dashboard_element` to bind tiles to this filter. + - title (required): The label displayed to users in the UI. + - filter_type (required): One of `date_filter`, `number_filter`, `string_filter`, or `field_filter`. + - default_value (optional): The initial value for the filter. + + Field Filters (`filter_type: field_filter`): + If creating a field filter, you must also provide: + - model + - explore + - dimension + The filter will inherit suggestions and type information from this LookML field. +``` + +## Reference + +| **field** | **type** | **required** | **description** | +|-------------|:--------:|:------------:|----------------------------------------------------| +| kind | string | true | Must be "looker-add-dashboard-filter". | +| source | string | true | Name of the source the SQL should execute on. | +| description | string | true | Description of the tool that is passed to the LLM. | \ No newline at end of file diff --git a/docs/en/resources/tools/looker/looker-conversational-analytics.md b/docs/en/resources/tools/looker/looker-conversational-analytics.md index 0ce6aa90f9..150f347cf7 100644 --- a/docs/en/resources/tools/looker/looker-conversational-analytics.md +++ b/docs/en/resources/tools/looker/looker-conversational-analytics.md @@ -34,9 +34,10 @@ tools: kind: looker-conversational-analytics source: looker-source description: | - Use this tool to perform data analysis, get insights, - or answer complex questions about the contents of specific - Looker explores. + Use this tool to ask questions about your data using the Looker Conversational + Analytics API. You must provide a natural language query and a list of + 1 to 5 model and explore combinations (e.g. [{'model': 'the_model', 'explore': 'the_explore'}]). + Use the 'get_models' and 'get_explores' tools to discover available models and explores.
``` ## Reference diff --git a/docs/en/resources/tools/looker/looker-create-project-file.md b/docs/en/resources/tools/looker/looker-create-project-file.md index f6c2644046..826dda98e9 100644 --- a/docs/en/resources/tools/looker/looker-create-project-file.md +++ b/docs/en/resources/tools/looker/looker-create-project-file.md @@ -27,13 +27,18 @@ tools: kind: looker-create-project-file source: looker-source description: | - create_project_file Tool + This tool creates a new LookML file within a specified project, populating + it with the provided content. - Given a project_id and a file path within the project, as well as the content - of a LookML file, this tool will create a new file within the project. + Prerequisite: The Looker session must be in Development Mode. Use `dev_mode: true` first. - This tool must be called after the dev_mode tool has changed the session to - dev mode. + Parameters: + - project_id (required): The unique ID of the LookML project. + - file_path (required): The desired path and filename for the new file within the project. + - content (required): The full LookML content to write into the new file. + + Output: + A confirmation message upon successful file creation. ``` ## Reference diff --git a/docs/en/resources/tools/looker/looker-delete-project-file.md b/docs/en/resources/tools/looker/looker-delete-project-file.md index bb545003d7..e5bf06948d 100644 --- a/docs/en/resources/tools/looker/looker-delete-project-file.md +++ b/docs/en/resources/tools/looker/looker-delete-project-file.md @@ -26,13 +26,17 @@ tools: kind: looker-delete-project-file source: looker-source description: | - delete_project_file Tool + This tool permanently deletes a specified LookML file from within a project. + Use with caution, as this action cannot be undone through the API. - Given a project_id and a file path within the project, this tool will delete - the file from the project. + Prerequisite: The Looker session must be in Development Mode. Use `dev_mode: true` first. - This tool must be called after the dev_mode tool has changed the session to - dev mode. + Parameters: + - project_id (required): The unique ID of the LookML project. + - file_path (required): The exact path to the LookML file to delete within the project. + + Output: + A confirmation message upon successful file deletion. ``` ## Reference diff --git a/docs/en/resources/tools/looker/looker-dev-mode.md b/docs/en/resources/tools/looker/looker-dev-mode.md index ed354da07e..9f69343ad5 100644 --- a/docs/en/resources/tools/looker/looker-dev-mode.md +++ b/docs/en/resources/tools/looker/looker-dev-mode.md @@ -27,10 +27,13 @@ tools: kind: looker-dev-mode source: looker-source description: | - dev_mode Tool + This tool allows toggling the Looker IDE session between Development Mode and Production Mode. + Development Mode enables making and testing changes to LookML projects. - Passing true to this tool switches the session to dev mode. Passing false to this tool switches the - session to production mode. + Parameters: + - enable (required): A boolean value. + - `true`: Switches the current session to Development Mode. + - `false`: Switches the current session to Production Mode. 
``` ## Reference diff --git a/docs/en/resources/tools/looker/looker-generate-embed-url.md b/docs/en/resources/tools/looker/looker-generate-embed-url.md index abf7adef9c..1e136165da 100644 --- a/docs/en/resources/tools/looker/looker-generate-embed-url.md +++ b/docs/en/resources/tools/looker/looker-generate-embed-url.md @@ -36,11 +36,17 @@ tools: kind: looker-generate-embed-url source: looker-source description: | - generate_embed_url Tool + This tool generates a signed, private embed URL for specific Looker content, + allowing users to access it directly. - This tool generates an embeddable URL for Looker content. - You need to provide the type of content (e.g., 'dashboards', 'looks', 'query-visualization') - and the ID of the content. + Parameters: + - type (required): The type of content to embed. Common values include: + - `dashboards` + - `looks` + - `explore` + - id (required): The unique identifier for the content. + - For dashboards and looks, use the numeric ID (e.g., "123"). + - For explores, use the format "model_name/explore_name". ``` ## Reference diff --git a/docs/en/resources/tools/looker/looker-get-connection-databases.md b/docs/en/resources/tools/looker/looker-get-connection-databases.md index 60ab6999f2..23611fc2a1 100644 --- a/docs/en/resources/tools/looker/looker-get-connection-databases.md +++ b/docs/en/resources/tools/looker/looker-get-connection-databases.md @@ -26,10 +26,16 @@ tools: kind: looker-get-connection-databases source: looker-source description: | - get_connection_databases Tool + This tool retrieves a list of databases available through a specified Looker connection. + This is only applicable for connections that support multiple databases. + Use `get_connections` to check if a connection supports multiple databases. - This tool will list the databases available from a connection if the connection - supports multiple databases. + Parameters: + - connection_name (required): The name of the database connection, obtained from `get_connections`. + + Output: + A JSON array of strings, where each string is the name of an available database. + If the connection does not support multiple databases, an empty list or an error will be returned. ``` ## Reference diff --git a/docs/en/resources/tools/looker/looker-get-connection-schemas.md b/docs/en/resources/tools/looker/looker-get-connection-schemas.md index 972b93f0b6..0ef34015c3 100644 --- a/docs/en/resources/tools/looker/looker-get-connection-schemas.md +++ b/docs/en/resources/tools/looker/looker-get-connection-schemas.md @@ -26,10 +26,16 @@ tools: kind: looker-get-connection-schemas source: looker-source description: | - get_connection_schemas Tool + This tool retrieves a list of database schemas available through a specified + Looker connection. - This tool will list the schemas available from a connection, filtered by - an optional database name. + Parameters: + - connection_name (required): The name of the database connection, obtained from `get_connections`. + - database (optional): An optional database name to filter the schemas. + Only applicable for connections that support multiple databases. + + Output: + A JSON array of strings, where each string is the name of an available schema. 
``` ## Reference diff --git a/docs/en/resources/tools/looker/looker-get-connection-table-columns.md b/docs/en/resources/tools/looker/looker-get-connection-table-columns.md index 855006dc1a..f4db6445fe 100644 --- a/docs/en/resources/tools/looker/looker-get-connection-table-columns.md +++ b/docs/en/resources/tools/looker/looker-get-connection-table-columns.md @@ -26,11 +26,20 @@ tools: kind: looker-get-connection-table-columns source: looker-source description: | - get_connection_table_columns Tool + This tool retrieves a list of columns for one or more specified tables within a + given database schema and connection. - This tool will list the columns available from a connection, for all the tables - given in a comma separated list of table names, filtered by the - schema name and optional database name. + Parameters: + - connection_name (required): The name of the database connection, obtained from `get_connections`. + - schema (required): The name of the schema where the tables reside, obtained from `get_connection_schemas`. + - tables (required): A comma-separated string of table names for which to retrieve columns + (e.g., "users,orders,products"), obtained from `get_connection_tables`. + - database (optional): The name of the database to filter by. Only applicable for connections + that support multiple databases (check with `get_connections`). + + Output: + A JSON array of objects, where each object represents a column and contains details + such as `table_name`, `column_name`, `data_type`, and `is_nullable`. ``` ## Reference diff --git a/docs/en/resources/tools/looker/looker-get-connection-tables.md b/docs/en/resources/tools/looker/looker-get-connection-tables.md index 8844c4184e..86a2830cd9 100644 --- a/docs/en/resources/tools/looker/looker-get-connection-tables.md +++ b/docs/en/resources/tools/looker/looker-get-connection-tables.md @@ -27,10 +27,17 @@ tools: kind: looker-get-connection-tables source: looker-source description: | - get_connection_tables Tool + This tool retrieves a list of tables available within a specified database schema + through a Looker connection. - This tool will list the tables available from a connection, filtered by the - schema name and optional database name. + Parameters: + - connection_name (required): The name of the database connection, obtained from `get_connections`. + - schema (required): The name of the schema to list tables from, obtained from `get_connection_schemas`. + - database (optional): The name of the database to filter by. Only applicable for connections + that support multiple databases (check with `get_connections`). + + Output: + A JSON array of strings, where each string is the name of an available table. ``` ## Reference diff --git a/docs/en/resources/tools/looker/looker-get-connections.md b/docs/en/resources/tools/looker/looker-get-connections.md index 45936847e8..c6c0159789 100644 --- a/docs/en/resources/tools/looker/looker-get-connections.md +++ b/docs/en/resources/tools/looker/looker-get-connections.md @@ -26,11 +26,18 @@ tools: kind: looker-get-connections source: looker-source description: | - get_connections Tool + This tool retrieves a list of all database connections configured in the Looker system. - This tool will list all the connections available in the Looker system, as - well as the dialect name, the default schema, the database if applicable, - and whether the connection supports multiple databases. + Parameters: + This tool takes no parameters. 
+ + Output: + A JSON array of objects, each representing a database connection and including details such as: + - `name`: The connection's unique identifier. + - `dialect`: The database dialect (e.g., "mysql", "postgresql", "bigquery"). + - `default_schema`: The default schema for the connection. + - `database`: The associated database name (if applicable). + - `supports_multiple_databases`: A boolean indicating if the connection can access multiple databases. ``` ## Reference diff --git a/docs/en/resources/tools/looker/looker-get-dashboards.md b/docs/en/resources/tools/looker/looker-get-dashboards.md index 82b92c323e..db5c9e532e 100644 --- a/docs/en/resources/tools/looker/looker-get-dashboards.md +++ b/docs/en/resources/tools/looker/looker-get-dashboards.md @@ -29,25 +29,29 @@ default to 100 and 0. ```yaml tools: - get_dashboards: - kind: looker-get-dashboards - source: looker-source - description: | - get_dashboards Tool - - This tool is used to search for saved dashboards in a Looker instance. - String search params use case-insensitive matching. String search - params can contain % and '_' as SQL LIKE pattern match wildcard - expressions. example="dan%" will match "danger" and "Danzig" but - not "David" example="D_m%" will match "Damage" and "dump". - - Most search params can accept "IS NULL" and "NOT NULL" as special - expressions to match or exclude (respectively) rows where the - column is null. - - The limit and offset are used to paginate the results. - - The result of the get_dashboards tool is a list of json objects. + get_dashboards: + kind: looker-get-dashboards + source: looker-source + description: | + This tool searches for saved dashboards in a Looker instance. It returns a list of JSON objects, each representing a dashboard. + + Search Parameters: + - title (optional): Filter by dashboard title (supports wildcards). + - folder_id (optional): Filter by the ID of the folder where the dashboard is saved. + - user_id (optional): Filter by the ID of the user who created the dashboard. + - description (optional): Filter by description content (supports wildcards). + - id (optional): Filter by specific dashboard ID. + - limit (optional): Maximum number of results to return. Defaults to a system limit. + - offset (optional): Starting point for pagination. + + String Search Behavior: + - Case-insensitive matching. + - Supports SQL LIKE pattern match wildcards: + - `%`: Matches any sequence of zero or more characters. (e.g., `"finan%"` matches "financial", "finance") + - `_`: Matches any single character. (e.g., `"s_les"` matches "sales") + - Special expressions for null checks: + - `"IS NULL"`: Matches dashboards where the field is null. + - `"NOT NULL"`: Excludes dashboards where the field is null. ``` ## Reference diff --git a/docs/en/resources/tools/looker/looker-get-dimensions.md b/docs/en/resources/tools/looker/looker-get-dimensions.md index 66ab329902..17f3bb68f7 100644 --- a/docs/en/resources/tools/looker/looker-get-dimensions.md +++ b/docs/en/resources/tools/looker/looker-get-dimensions.md @@ -28,16 +28,20 @@ tools: kind: looker-get-dimensions source: looker-source description: | - The get_dimensions tool retrieves the list of dimensions defined in - an explore. + This tool retrieves a list of dimensions defined within a specific Looker explore. + Dimensions are non-aggregatable attributes or characteristics of your data + (e.g., product name, order date, customer city) that can be used for grouping, + filtering, or segmenting query results. 
- It takes two parameters, the model_name looked up from get_models and the - explore_name looked up from get_explores. + Parameters: + - model_name (required): The name of the LookML model, obtained from `get_models`. + - explore_name (required): The name of the explore within the model, obtained from `get_explores`. - If this returns a suggestions field for a dimension, the contents of suggestions - can be used as filters for this field. If this returns a suggest_explore and - suggest_dimension, a query against that explore and dimension can be used to find - valid filters for this field. + Output Details: + - If a dimension includes a `suggestions` field, its contents are valid values + that can be used directly as filters for that dimension. + - If a `suggest_explore` and `suggest_dimension` are provided, you can query + that specified explore and dimension to retrieve a list of valid filter values. ``` diff --git a/docs/en/resources/tools/looker/looker-get-explores.md b/docs/en/resources/tools/looker/looker-get-explores.md index 66d7e65eba..d92942de9d 100644 --- a/docs/en/resources/tools/looker/looker-get-explores.md +++ b/docs/en/resources/tools/looker/looker-get-explores.md @@ -40,10 +40,13 @@ tools: kind: looker-get-explores source: looker-source description: | - The get_explores tool retrieves the list of explores defined in a LookML model - in the Looker system. + This tool retrieves a list of explores defined within a specific LookML model. + Explores represent a curated view of your data, typically joining several + tables together to allow for focused analysis on a particular subject area. + The output provides details like the explore's `name` and `label`. - It takes one parameter, the model_name looked up from get_models. + Parameters: + - model_name (required): The name of the LookML model, obtained from `get_models`. ``` ## Reference diff --git a/docs/en/resources/tools/looker/looker-get-filters.md b/docs/en/resources/tools/looker/looker-get-filters.md index a1babf5572..2657936fd6 100644 --- a/docs/en/resources/tools/looker/looker-get-filters.md +++ b/docs/en/resources/tools/looker/looker-get-filters.md @@ -24,15 +24,22 @@ It's compatible with the following sources: ```yaml tools: - get_dimensions: + get_filters: kind: looker-get-filters source: looker-source description: | - The get_filters tool retrieves the list of filters defined in - an explore. + This tool retrieves a list of "filter-only fields" defined within a specific + Looker explore. These are special fields defined in LookML specifically to + create user-facing filter controls that do not directly affect the `GROUP BY` + clause of the SQL query. They are often used in conjunction with liquid templating + to create dynamic queries. - It takes two parameters, the model_name looked up from get_models and the - explore_name looked up from get_explores. + Note: Regular dimensions and measures can also be used as filters in a query. + This tool *only* returns fields explicitly defined as `filter:` in LookML. + + Parameters: + - model_name (required): The name of the LookML model, obtained from `get_models`. + - explore_name (required): The name of the explore within the model, obtained from `get_explores`. 
``` The response is a json array with the following elements: diff --git a/docs/en/resources/tools/looker/looker-get-looks.md b/docs/en/resources/tools/looker/looker-get-looks.md index f4b39474cc..06bc5f7856 100644 --- a/docs/en/resources/tools/looker/looker-get-looks.md +++ b/docs/en/resources/tools/looker/looker-get-looks.md @@ -34,21 +34,26 @@ tools: kind: looker-get-looks source: looker-source description: | - get_looks Tool + This tool searches for saved Looks (pre-defined queries and visualizations) + in a Looker instance. It returns a list of JSON objects, each representing a Look. - This tool is used to search for saved looks in a Looker instance. - String search params use case-insensitive matching. String search - params can contain % and '_' as SQL LIKE pattern match wildcard - expressions. example="dan%" will match "danger" and "Danzig" but - not "David" example="D_m%" will match "Damage" and "dump". + Search Parameters: + - title (optional): Filter by Look title (supports wildcards). + - folder_id (optional): Filter by the ID of the folder where the Look is saved. + - user_id (optional): Filter by the ID of the user who created the Look. + - description (optional): Filter by description content (supports wildcards). + - id (optional): Filter by specific Look ID. + - limit (optional): Maximum number of results to return. Defaults to a system limit. + - offset (optional): Starting point for pagination. - Most search params can accept "IS NULL" and "NOT NULL" as special - expressions to match or exclude (respectively) rows where the - column is null. - - The limit and offset are used to paginate the results. - - The result of the get_looks tool is a list of json objects. + String Search Behavior: + - Case-insensitive matching. + - Supports SQL LIKE pattern match wildcards: + - `%`: Matches any sequence of zero or more characters. (e.g., `"dan%"` matches "danger", "Danzig") + - `_`: Matches any single character. (e.g., `"D_m%"` matches "Damage", "dump") + - Special expressions for null checks: + - `"IS NULL"`: Matches Looks where the field is null. + - `"NOT NULL"`: Excludes Looks where the field is null. ``` ## Reference diff --git a/docs/en/resources/tools/looker/looker-get-measures.md b/docs/en/resources/tools/looker/looker-get-measures.md index a6c1a0a000..7304031855 100644 --- a/docs/en/resources/tools/looker/looker-get-measures.md +++ b/docs/en/resources/tools/looker/looker-get-measures.md @@ -28,16 +28,19 @@ tools: kind: looker-get-measures source: looker-source description: | - The get_measures tool retrieves the list of measures defined in - an explore. + This tool retrieves a list of measures defined within a specific Looker explore. + Measures are aggregatable metrics (e.g., total sales, average price, count of users) + that are used for calculations and quantitative analysis in your queries. - It takes two parameters, the model_name looked up from get_models and the - explore_name looked up from get_explores. + Parameters: + - model_name (required): The name of the LookML model, obtained from `get_models`. + - explore_name (required): The name of the explore within the model, obtained from `get_explores`. - If this returns a suggestions field for a measure, the contents of suggestions - can be used as filters for this field. If this returns a suggest_explore and - suggest_dimension, a query against that explore and dimension can be used to find - valid filters for this field. 
+ Output Details: + - If a measure includes a `suggestions` field, its contents are valid values + that can be used directly as filters for that measure. + - If a `suggest_explore` and `suggest_dimension` are provided, you can query + that specified explore and dimension to retrieve a list of valid filter values. ``` diff --git a/docs/en/resources/tools/looker/looker-get-models.md b/docs/en/resources/tools/looker/looker-get-models.md index b025ccc6d5..81002cf3a2 100644 --- a/docs/en/resources/tools/looker/looker-get-models.md +++ b/docs/en/resources/tools/looker/looker-get-models.md @@ -26,9 +26,12 @@ tools: kind: looker-get-models source: looker-source description: | - The get_models tool retrieves the list of LookML models in the Looker system. + This tool retrieves a list of available LookML models in the Looker instance. + LookML models define the data structure and relationships that users can query. + The output includes details like the model's `name` and `label`, which are + essential for subsequent calls to tools like `get_explores` or `query`. - It takes no parameters. + This tool takes no parameters. ``` ## Reference diff --git a/docs/en/resources/tools/looker/looker-get-parameters.md b/docs/en/resources/tools/looker/looker-get-parameters.md index 527b30a48d..f40398568d 100644 --- a/docs/en/resources/tools/looker/looker-get-parameters.md +++ b/docs/en/resources/tools/looker/looker-get-parameters.md @@ -28,11 +28,15 @@ tools: kind: looker-get-parameters source: looker-source description: | - The get_parameters tool retrieves the list of parameters defined in - an explore. + This tool retrieves a list of parameters defined within a specific Looker explore. + LookML parameters are dynamic input fields that allow users to influence query + behavior without directly modifying the underlying LookML. They are often used + with `liquid` templating to create flexible dashboards and reports, enabling + users to choose dimensions, measures, or other query components at runtime. - It takes two parameters, the model_name looked up from get_models and the - explore_name looked up from get_explores. + Parameters: + - model_name (required): The name of the LookML model, obtained from `get_models`. + - explore_name (required): The name of the explore within the model, obtained from `get_explores`. ``` The response is a json array with the following elements: diff --git a/docs/en/resources/tools/looker/looker-get-project-file.md b/docs/en/resources/tools/looker/looker-get-project-file.md index be0c1bd6b5..440615efa4 100644 --- a/docs/en/resources/tools/looker/looker-get-project-file.md +++ b/docs/en/resources/tools/looker/looker-get-project-file.md @@ -26,10 +26,15 @@ tools: kind: looker-get-project-file source: looker-source description: | - get_project_file Tool + This tool retrieves the raw content of a specific LookML file from within a project. - Given a project_id and a file path within the project, this tool returns - the contents of the LookML file. + Parameters: + - project_id (required): The unique ID of the LookML project, obtained from `get_projects`. + - file_path (required): The path to the LookML file within the project, + typically obtained from `get_project_files`. + + Output: + The raw text content of the specified LookML file. 
``` ## Reference diff --git a/docs/en/resources/tools/looker/looker-get-project-files.md b/docs/en/resources/tools/looker/looker-get-project-files.md index 1b668ae821..48ea273228 100644 --- a/docs/en/resources/tools/looker/looker-get-project-files.md +++ b/docs/en/resources/tools/looker/looker-get-project-files.md @@ -26,10 +26,15 @@ tools: kind: looker-get-project-files source: looker-source description: | - get_project_files Tool + This tool retrieves a list of all LookML files within a specified project, + providing details about each file. - Given a project_id this tool returns the details about - the LookML files that make up that project. + Parameters: + - project_id (required): The unique ID of the LookML project, obtained from `get_projects`. + + Output: + A JSON array of objects, each representing a LookML file and containing + details such as `path`, `id`, `type`, and `git_status`. ``` ## Reference diff --git a/docs/en/resources/tools/looker/looker-get-projects.md b/docs/en/resources/tools/looker/looker-get-projects.md index 3618753e4e..7c582eeae0 100644 --- a/docs/en/resources/tools/looker/looker-get-projects.md +++ b/docs/en/resources/tools/looker/looker-get-projects.md @@ -26,10 +26,16 @@ tools: kind: looker-get-projects source: looker-source description: | - get_projects Tool + This tool retrieves a list of all LookML projects available on the Looker instance. + It is useful for identifying projects before performing actions like retrieving + project files or making modifications. - This tool returns the project_id and project_name for - all the LookML projects on the looker instance. + Parameters: + This tool takes no parameters. + + Output: + A JSON array of objects, each containing the `project_id` and `project_name` + for a LookML project. ``` ## Reference diff --git a/docs/en/resources/tools/looker/looker-health-analyze.md b/docs/en/resources/tools/looker/looker-health-analyze.md index a3e9e7a7f6..bc44d3f301 100644 --- a/docs/en/resources/tools/looker/looker-health-analyze.md +++ b/docs/en/resources/tools/looker/looker-health-analyze.md @@ -42,17 +42,18 @@ tools: kind: looker-health-analyze source: looker-source description: | - health-analyze Tool + This tool calculates the usage statistics for Looker projects, models, and explores. - This tool calculates the usage of projects, models and explores. + Parameters: + - action (required): The type of resource to analyze. Can be `"projects"`, `"models"`, or `"explores"`. + - project (optional): The specific project ID to analyze. + - model (optional): The specific model name to analyze. Requires `project` if used without `explore`. + - explore (optional): The specific explore name to analyze. Requires `model` if used. + - timeframe (optional): The lookback period in days for usage data. Defaults to `90` days. + - min_queries (optional): The minimum number of queries for a resource to be considered active. Defaults to `1`. - It accepts 6 parameters: - 1. `action`: can be "projects", "models", or "explores" - 2. `project`: the project to analyze (optional) - 3. `model`: the model to analyze (optional) - 4. `explore`: the explore to analyze (optional) - 5. `timeframe`: the lookback period in days, default is 90 - 6. `min_queries`: the minimum number of queries to consider a resource as active, default is 1 + Output: + The result is a JSON object containing usage metrics for the specified resources. 
``` ## Reference diff --git a/docs/en/resources/tools/looker/looker-health-pulse.md b/docs/en/resources/tools/looker/looker-health-pulse.md index cf11e7a74f..ccbc05be34 100644 --- a/docs/en/resources/tools/looker/looker-health-pulse.md +++ b/docs/en/resources/tools/looker/looker-health-pulse.md @@ -49,20 +49,22 @@ tools: kind: looker-health-pulse source: looker-source description: | - health-pulse Tool + This tool performs various health checks on a Looker instance. - This tool takes the pulse of a Looker instance by taking - one of the following actions: - 1. `check_db_connections`, - 2. `check_dashboard_performance`, - 3. `check_dashboard_errors`, - 4. `check_explore_performance`, - 5. `check_schedule_failures`, or - 6. `check_legacy_features` - - The `check_legacy_features` action is only available in Looker Core. If - it is called on a Looker Core instance, you will get a notice. That notice - should not be reported as an error. + Parameters: + - action (required): Specifies the type of health check to perform. + Choose one of the following: + - `check_db_connections`: Verifies database connectivity. + - `check_dashboard_performance`: Assesses dashboard loading performance. + - `check_dashboard_errors`: Identifies errors within dashboards. + - `check_explore_performance`: Evaluates explore query performance. + - `check_schedule_failures`: Reports on failed scheduled deliveries. + - `check_legacy_features`: Checks for the usage of legacy features. + + Note on `check_legacy_features`: + This action is exclusively available in Looker Core instances. If invoked + on a non-Looker Core instance, it will return a notice rather than an error. + This notice should be considered normal behavior and not an indication of an issue. ``` ## Reference diff --git a/docs/en/resources/tools/looker/looker-health-vacuum.md b/docs/en/resources/tools/looker/looker-health-vacuum.md index 38a36a50b8..f4d635ccc5 100644 --- a/docs/en/resources/tools/looker/looker-health-vacuum.md +++ b/docs/en/resources/tools/looker/looker-health-vacuum.md @@ -39,20 +39,19 @@ tools: kind: looker-health-vacuum source: looker-source description: | - health-vacuum Tool + This tool identifies and suggests LookML models or explores that can be + safely removed due to inactivity or low usage. - This tool suggests models or explores that can removed - because they are unused. + Parameters: + - action (required): The type of resource to analyze for removal candidates. Can be `"models"` or `"explores"`. + - project (optional): The specific project ID to consider. + - model (optional): The specific model name to consider. Requires `project` if used without `explore`. + - explore (optional): The specific explore name to consider. Requires `model` if used. + - timeframe (optional): The lookback period in days to assess usage. Defaults to `90` days. + - min_queries (optional): The minimum number of queries for a resource to be considered active. Defaults to `1`. - It accepts 6 parameters: - 1. `action`: can be "models" or "explores" - 2. `project`: the project to vacuum (optional) - 3. `model`: the model to vacuum (optional) - 4. `explore`: the explore to vacuum (optional) - 5. `timeframe`: the lookback period in days, default is 90 - 6. `min_queries`: the minimum number of queries to consider a resource as active, default is 1 - - The result is a list of objects that are candidates for deletion. + Output: + A JSON array of objects, each representing a model or explore that is a candidate for deletion due to low usage. 
``` | **field** | **type** | **required** | **description** | diff --git a/docs/en/resources/tools/looker/looker-make-dashboard.md b/docs/en/resources/tools/looker/looker-make-dashboard.md index dcaa45f137..048d42bef0 100644 --- a/docs/en/resources/tools/looker/looker-make-dashboard.md +++ b/docs/en/resources/tools/looker/looker-make-dashboard.md @@ -30,18 +30,19 @@ tools: kind: looker-make-dashboard source: looker-source description: | - make_dashboard Tool + This tool creates a new, empty dashboard in Looker. Dashboards are stored + in the user's personal folder, and the dashboard name must be unique. + After creation, use `add_dashboard_filter` to add filters and + `add_dashboard_element` to add content tiles. - This tool creates a new dashboard in Looker. The dashboard is - initially empty and the add_dashboard_element tool is used to - add content to the dashboard. + Required Parameters: + - title (required): A unique title for the new dashboard. + - description (required): A brief description of the dashboard's purpose. - The newly created dashboard will be created in the user's - personal folder in looker. The dashboard name must be unique. - - The result is a json document with a link to the newly - created dashboard and the id of the dashboard. Use the id - when calling add_dashboard_element. + Output: + A JSON object containing a link (`url`) to the newly created dashboard and + its unique `id`. This `dashboard_id` is crucial for subsequent calls to + `add_dashboard_filter` and `add_dashboard_element`. ``` ## Reference diff --git a/docs/en/resources/tools/looker/looker-make-look.md b/docs/en/resources/tools/looker/looker-make-look.md index ce266b03f6..148f245532 100644 --- a/docs/en/resources/tools/looker/looker-make-look.md +++ b/docs/en/resources/tools/looker/looker-make-look.md @@ -40,20 +40,24 @@ tools: kind: looker-make-look source: looker-source description: | - make_look Tool + This tool creates a new Look (saved query with visualization) in Looker. + The Look will be saved in the user's personal folder, and its name must be unique. - This tool creates a new look in Looker, using the query - parameters and the vis_config specified. + Required Parameters: + - title: A unique title for the new Look. + - description: A brief description of the Look's purpose. + - model_name: The name of the LookML model (from `get_models`). + - explore_name: The name of the explore (from `get_explores`). + - fields: A list of field names (dimensions, measures, filters, or parameters) to include in the query. - Most of the parameters are the same as the query_url - tool. In addition, there is a title and a description - that must be provided. + Optional Parameters: + - pivots, filters, sorts, limit, query_timezone: These parameters are identical + to those described for the `query` tool. + - vis_config: A JSON object defining the visualization settings for the Look. + The structure and options are the same as for the `query_url` tool's `vis_config`. - The newly created look will be created in the user's - personal folder in looker. The look name must be unique. - - The result is a json document with a link to the newly - created look. + Output: + A JSON object containing a link (`url`) to the newly created Look, along with its `id` and `slug`. 
``` ## Reference diff --git a/docs/en/resources/tools/looker/looker-query-sql.md b/docs/en/resources/tools/looker/looker-query-sql.md index 06492e1ccd..064464ea5d 100644 --- a/docs/en/resources/tools/looker/looker-query-sql.md +++ b/docs/en/resources/tools/looker/looker-query-sql.md @@ -41,38 +41,17 @@ tools: kind: looker-query-sql source: looker-source description: | - Query SQL Tool + This tool generates the underlying SQL query that Looker would execute + against the database for a given set of parameters. It is useful for + understanding how Looker translates a request into SQL. - This tool is used to generate a sql query against the LookML model. The - model, explore, and fields list must be specified. Pivots, - filters and sorts are optional. + Parameters: + All parameters for this tool are identical to those of the `query` tool. + This includes `model_name`, `explore_name`, `fields` (required), + and optional parameters like `pivots`, `filters`, `sorts`, `limit`, and `query_timezone`. - The model can be found from the get_models tool. The explore - can be found from the get_explores tool passing in the model. - The fields can be found from the get_dimensions, get_measures, - get_filters, and get_parameters tools, passing in the model - and the explore. - - Provide a model_id and explore_name, then a list - of fields. Optionally a list of pivots can be provided. - The pivots must also be included in the fields list. - - Filters are provided as a map of {"field.id": "condition", - "field.id2": "condition2", ...}. Do not put the field.id in - quotes. Filter expressions can be found at - https://cloud.google.com/looker/docs/filter-expressions. - - Sorts can be specified like [ "field.id desc 0" ]. - - An optional row limit can be added. If not provided the limit - will default to 500. "-1" can be specified for unlimited. - - An optional query timezone can be added. The query_timezone to - will default to that of the workstation where this MCP server - is running, or Etc/UTC if that can't be determined. Not all - models support custom timezones. - - The result of the query tool is the sql string. + Output: + The result of this tool is the raw SQL text. ``` ## Reference diff --git a/docs/en/resources/tools/looker/looker-query-url.md b/docs/en/resources/tools/looker/looker-query-url.md index c93d9d6ee3..af1f138509 100644 --- a/docs/en/resources/tools/looker/looker-query-url.md +++ b/docs/en/resources/tools/looker/looker-query-url.md @@ -37,17 +37,21 @@ tools: kind: looker-query-url source: looker-source description: | - Query URL Tool + This tool generates a shareable URL for a Looker query, allowing users to + explore the query further within the Looker UI. It returns the generated URL, + along with the `query_id` and `slug`. - This tool is used to generate the URL of a query in Looker. - The user can then explore the query further inside Looker. - The tool also returns the query_id and slug. The parameters - are the same as the query tool with an additional vis_config - parameter. + Parameters: + All query parameters (e.g., `model_name`, `explore_name`, `fields`, `pivots`, + `filters`, `sorts`, `limit`, `query_timezone`) are the same as the `query` tool. - The vis_config is optional. If provided, it will be used to - control the default visualization for the query. Here are - some notes on making visualizations. 
+ Additionally, it accepts an optional `vis_config` parameter: + - vis_config (optional): A JSON object that controls the default visualization + settings for the generated query. + + vis_config Details: + The `vis_config` object supports a wide range of properties for various chart types. + Here are some notes on making visualizations. ### Cartesian Charts (Area, Bar, Column, Line, Scatter) diff --git a/docs/en/resources/tools/looker/looker-query.md b/docs/en/resources/tools/looker/looker-query.md index 7b13f53fe0..7ba3292763 100644 --- a/docs/en/resources/tools/looker/looker-query.md +++ b/docs/en/resources/tools/looker/looker-query.md @@ -41,38 +41,24 @@ tools: kind: looker-query source: looker-source description: | - Query Tool + This tool runs a query against a LookML model and returns the results in JSON format. - This tool is used to run a query against the LookML model. The - model, explore, and fields list must be specified. Pivots, - filters and sorts are optional. + Required Parameters: + - model_name: The name of the LookML model (from `get_models`). + - explore_name: The name of the explore (from `get_explores`). + - fields: A list of field names (dimensions, measures, filters, or parameters) to include in the query. - The model can be found from the get_models tool. The explore - can be found from the get_explores tool passing in the model. - The fields can be found from the get_dimensions, get_measures, - get_filters, and get_parameters tools, passing in the model - and the explore. + Optional Parameters: + - pivots: A list of fields to pivot the results by. These fields must also be included in the `fields` list. + - filters: A map of filter expressions, e.g., `{"view.field": "value", "view.date": "7 days"}`. + - Do not quote field names. + - Use `not null` instead of `-NULL`. + - If a value contains a comma, enclose it in single quotes (e.g., "'New York, NY'"). + - sorts: A list of fields to sort by, optionally including direction (e.g., `["view.field desc"]`). + - limit: Row limit (default 500). Use "-1" for unlimited. + - query_timezone: specific timezone for the query (e.g. `America/Los_Angeles`). - Provide a model_id and explore_name, then a list - of fields. Optionally a list of pivots can be provided. - The pivots must also be included in the fields list. - - Filters are provided as a map of {"field.id": "condition", - "field.id2": "condition2", ...}. Do not put the field.id in - quotes. Filter expressions can be found at - https://cloud.google.com/looker/docs/filter-expressions. - If the condition is a string that contains a comma, use a second - set of quotes. For example, {"user.city": "'New York, NY'"}. - - Sorts can be specified like [ "field.id desc 0" ]. - - An optional row limit can be added. If not provided the limit - will default to 500. "-1" can be specified for unlimited. - - An optional query timezone can be added. The query_timezone to - will default to that of the workstation where this MCP server - is running, or Etc/UTC if that can't be determined. Not all - models support custom timezones. + Note: Use `get_dimensions`, `get_measures`, `get_filters`, and `get_parameters` to find valid fields. 
The result of the query tool is JSON ``` diff --git a/docs/en/resources/tools/looker/looker-run-dashboard.md b/docs/en/resources/tools/looker/looker-run-dashboard.md index df4a504bd9..cc2c2072df 100644 --- a/docs/en/resources/tools/looker/looker-run-dashboard.md +++ b/docs/en/resources/tools/looker/looker-run-dashboard.md @@ -27,11 +27,15 @@ tools: kind: looker-run-dashboard source: looker-source description: | - run_dashboard Tool + This tool executes the queries associated with each tile in a specified dashboard + and returns the aggregated data in a JSON structure. - This tools runs the query associated with each tile in a dashboard - and returns the data in a JSON structure. It accepts the dashboard_id - as the parameter. + Parameters: + - dashboard_id (required): The unique identifier of the dashboard to run, + typically obtained from the `get_dashboards` tool. + + Output: + The data from all dashboard tiles is returned as a JSON object. ``` ## Reference diff --git a/docs/en/resources/tools/looker/looker-run-look.md b/docs/en/resources/tools/looker/looker-run-look.md index 1a1c512e51..eb2f57eedb 100644 --- a/docs/en/resources/tools/looker/looker-run-look.md +++ b/docs/en/resources/tools/looker/looker-run-look.md @@ -27,11 +27,15 @@ tools: kind: looker-run-look source: looker-source description: | - run_look Tool + This tool executes the query associated with a saved Look and + returns the resulting data in a JSON structure. - This tool runs the query associated with a look and returns - the data in a JSON structure. It accepts the look_id as the - parameter. + Parameters: + - look_id (required): The unique identifier of the Look to run, + typically obtained from the `get_looks` tool. + + Output: + The query results are returned as a JSON object. ``` ## Reference diff --git a/docs/en/resources/tools/looker/looker-update-project-file.md b/docs/en/resources/tools/looker/looker-update-project-file.md index df007e7caf..af8cabd81b 100644 --- a/docs/en/resources/tools/looker/looker-update-project-file.md +++ b/docs/en/resources/tools/looker/looker-update-project-file.md @@ -27,13 +27,17 @@ tools: kind: looker-update-project-file source: looker-source description: | - update_project_file Tool + This tool modifies the content of an existing LookML file within a specified project. - Given a project_id and a file path within the project, as well as the content - of a LookML file, this tool will modify the file within the project. + Prerequisite: The Looker session must be in Development Mode. Use `dev_mode: true` first. - This tool must be called after the dev_mode tool has changed the session to - dev mode. + Parameters: + - project_id (required): The unique ID of the LookML project. + - file_path (required): The exact path to the LookML file to modify within the project. + - content (required): The new, complete LookML content to overwrite the existing file. + + Output: + A confirmation message upon successful file modification. ``` ## Reference diff --git a/docs/en/resources/tools/mindsdb/mindsdb-sql.md b/docs/en/resources/tools/mindsdb/mindsdb-sql.md index e129cbe71e..b0cfc189dc 100644 --- a/docs/en/resources/tools/mindsdb/mindsdb-sql.md +++ b/docs/en/resources/tools/mindsdb/mindsdb-sql.md @@ -169,5 +169,5 @@ tools: | source | string | true | Name of the source the SQL should execute on. | | description | string | true | Description of the tool that is passed to the LLM. | | statement | string | true | SQL statement to execute on. 
| -| parameters | [parameters](_index#specifying-parameters) | false | List of [parameters](_index#specifying-parameters) that will be inserted into the SQL statement. | -| templateParameters | [templateParameters](_index#template-parameters) | false | List of [templateParameters](_index#template-parameters) that will be inserted into the SQL statement before executing prepared statement. | +| parameters | [parameters](../#specifying-parameters) | false | List of [parameters](../#specifying-parameters) that will be inserted into the SQL statement. | +| templateParameters | [templateParameters](../#template-parameters) | false | List of [templateParameters](../#template-parameters) that will be inserted into the SQL statement before executing prepared statement. | diff --git a/docs/en/resources/tools/mongodb/mongodb-find-one.md b/docs/en/resources/tools/mongodb/mongodb-find-one.md index 67e0cca1ce..395262d91a 100644 --- a/docs/en/resources/tools/mongodb/mongodb-find-one.md +++ b/docs/en/resources/tools/mongodb/mongodb-find-one.md @@ -64,5 +64,3 @@ tools: | filterParams | list | false | A list of parameter objects that define the variables used in the `filterPayload`. | | projectPayload | string | false | An optional MongoDB projection document to specify which fields to include (1) or exclude (0) in the result. | | projectParams | list | false | A list of parameter objects for the `projectPayload`. | -| sortPayload | string | false | An optional MongoDB sort document. Useful for selecting which document to return if the filter matches multiple (e.g., get the most recent). | -| sortParams | list | false | A list of parameter objects for the `sortPayload`. | diff --git a/docs/en/resources/tools/mongodb/mongodb-insert-many.md b/docs/en/resources/tools/mongodb/mongodb-insert-many.md index a126f9da29..cc6c385375 100644 --- a/docs/en/resources/tools/mongodb/mongodb-insert-many.md +++ b/docs/en/resources/tools/mongodb/mongodb-insert-many.md @@ -48,11 +48,11 @@ in the `data` parameter, like this: ## Reference -| **field** | **type** | **required** | **description** | -|:------------|:---------|:-------------|:---------------------------------------------------------------------------------------------------| -| kind | string | true | Must be `mongodb-insert-many`. | -| source | string | true | The name of the `mongodb` source to use. | -| description | string | true | A description of the tool that is passed to the LLM. | -| database | string | true | The name of the MongoDB database containing the collection. | -| collection | string | true | The name of the MongoDB collection into which the documents will be inserted. | -| canonical | bool | true | Determines if the data string is parsed using MongoDB's Canonical or Relaxed Extended JSON format. | +| **field** | **type** | **required** | **description** | +|:------------|:---------|:-------------|:------------------------------------------------------------------------------------------------------------------------| +| kind | string | true | Must be `mongodb-insert-many`. | +| source | string | true | The name of the `mongodb` source to use. | +| description | string | true | A description of the tool that is passed to the LLM. | +| database | string | true | The name of the MongoDB database containing the collection. | +| collection | string | true | The name of the MongoDB collection into which the documents will be inserted. 
| +| canonical | bool | false | Determines if the data string is parsed using MongoDB's Canonical or Relaxed Extended JSON format. Defaults to `false`. | diff --git a/docs/en/resources/tools/mongodb/mongodb-insert-one.md b/docs/en/resources/tools/mongodb/mongodb-insert-one.md index d82a332f38..7214f2b4d6 100644 --- a/docs/en/resources/tools/mongodb/mongodb-insert-one.md +++ b/docs/en/resources/tools/mongodb/mongodb-insert-one.md @@ -43,11 +43,11 @@ An LLM would call this tool by providing the document as a JSON string in the ## Reference -| **field** | **type** | **required** | **description** | -|:------------|:---------|:-------------|:---------------------------------------------------------------------------------------------------| -| kind | string | true | Must be `mongodb-insert-one`. | -| source | string | true | The name of the `mongodb` source to use. | -| description | string | true | A description of the tool that is passed to the LLM. | -| database | string | true | The name of the MongoDB database containing the collection. | -| collection | string | true | The name of the MongoDB collection into which the document will be inserted. | -| canonical | bool | true | Determines if the data string is parsed using MongoDB's Canonical or Relaxed Extended JSON format. | +| **field** | **type** | **required** | **description** | +|:------------|:---------|:-------------|:------------------------------------------------------------------------------------------------------------------------| +| kind | string | true | Must be `mongodb-insert-one`. | +| source | string | true | The name of the `mongodb` source to use. | +| description | string | true | A description of the tool that is passed to the LLM. | +| database | string | true | The name of the MongoDB database containing the collection. | +| collection | string | true | The name of the MongoDB collection into which the document will be inserted. | +| canonical | bool | false | Determines if the data string is parsed using MongoDB's Canonical or Relaxed Extended JSON format. Defaults to `false`. | diff --git a/docs/en/resources/tools/mongodb/mongodb-update-many.md b/docs/en/resources/tools/mongodb/mongodb-update-many.md index 2ed8a1fb6d..ef3b144364 100644 --- a/docs/en/resources/tools/mongodb/mongodb-update-many.md +++ b/docs/en/resources/tools/mongodb/mongodb-update-many.md @@ -57,16 +57,16 @@ tools: ## Reference -| **field** | **type** | **required** | **description** | -|:--------------|:---------|:-------------|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| kind | string | true | Must be `mongodb-update-many`. | -| source | string | true | The name of the `mongodb` source to use. | -| description | string | true | A description of the tool that is passed to the LLM. | -| database | string | true | The name of the MongoDB database containing the collection. | -| collection | string | true | The name of the MongoDB collection in which to update documents. | -| filterPayload | string | true | The MongoDB query filter document to select the documents for updating. It's written as a Go template, using `{{json .param_name}}` to insert parameters. | -| filterParams | list | false | A list of parameter objects that define the variables used in the `filterPayload`. 
| -| updatePayload | string | true | The MongoDB update document, It's written as a Go template, using `{{json .param_name}}` to insert parameters. | -| updateParams | list | true | A list of parameter objects that define the variables used in the `updatePayload`. | -| canonical | bool | true | Determines if the `filterPayload` and `updatePayload` strings are parsed using MongoDB's Canonical or Relaxed Extended JSON format. **Canonical** is stricter about type representation, while **Relaxed** is more lenient. | -| upsert | bool | false | If `true`, a new document is created if no document matches the `filterPayload`. Defaults to `false`. | +| **field** | **type** | **required** | **description** | +|:--------------|:---------|:-------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| kind | string | true | Must be `mongodb-update-many`. | +| source | string | true | The name of the `mongodb` source to use. | +| description | string | true | A description of the tool that is passed to the LLM. | +| database | string | true | The name of the MongoDB database containing the collection. | +| collection | string | true | The name of the MongoDB collection in which to update documents. | +| filterPayload | string | true | The MongoDB query filter document to select the documents for updating. It's written as a Go template, using `{{json .param_name}}` to insert parameters. | +| filterParams | list | false | A list of parameter objects that define the variables used in the `filterPayload`. | +| updatePayload | string | true | The MongoDB update document, It's written as a Go template, using `{{json .param_name}}` to insert parameters. | +| updateParams | list | true | A list of parameter objects that define the variables used in the `updatePayload`. | +| canonical | bool | false | Determines if the `filterPayload` and `updatePayload` strings are parsed using MongoDB's Canonical or Relaxed Extended JSON format. **Canonical** is stricter about type representation, while **Relaxed** is more lenient. Defaults to `false`. | +| upsert | bool | false | If `true`, a new document is created if no document matches the `filterPayload`. Defaults to `false`. | diff --git a/docs/en/resources/tools/mongodb/mongodb-update-one.md b/docs/en/resources/tools/mongodb/mongodb-update-one.md index 7ecf0662aa..063ea0b192 100644 --- a/docs/en/resources/tools/mongodb/mongodb-update-one.md +++ b/docs/en/resources/tools/mongodb/mongodb-update-one.md @@ -57,16 +57,16 @@ tools: ## Reference -| **field** | **type** | **required** | **description** | -|:--------------|:---------|:-------------|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| kind | string | true | Must be `mongodb-update-one`. | -| source | string | true | The name of the `mongodb` source to use. | -| description | string | true | A description of the tool that is passed to the LLM. | -| database | string | true | The name of the MongoDB database containing the collection. | -| collection | string | true | The name of the MongoDB collection to update a document in. 
| -| filterPayload | string | true | The MongoDB query filter document to select the document for updating. It's written as a Go template, using `{{json .param_name}}` to insert parameters. | -| filterParams | list | false | A list of parameter objects that define the variables used in the `filterPayload`. | -| updatePayload | string | true | The MongoDB update document, which specifies the modifications. This often uses update operators like `$set`. It's written as a Go template, using `{{json .param_name}}` to insert parameters. | -| updateParams | list | true | A list of parameter objects that define the variables used in the `updatePayload`. | -| canonical | bool | true | Determines if the `updatePayload` string is parsed using MongoDB's Canonical or Relaxed Extended JSON format. **Canonical** is stricter about type representation (e.g., `{"$numberInt": "42"}`), while **Relaxed** is more lenient (e.g., `42`). | -| upsert | bool | false | If `true`, a new document is created if no document matches the `filterPayload`. Defaults to `false`. | +| **field** | **type** | **required** | **description** | +|:--------------|:---------|:-------------|:-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| kind | string | true | Must be `mongodb-update-one`. | +| source | string | true | The name of the `mongodb` source to use. | +| description | string | true | A description of the tool that is passed to the LLM. | +| database | string | true | The name of the MongoDB database containing the collection. | +| collection | string | true | The name of the MongoDB collection to update a document in. | +| filterPayload | string | true | The MongoDB query filter document to select the document for updating. It's written as a Go template, using `{{json .param_name}}` to insert parameters. | +| filterParams | list | false | A list of parameter objects that define the variables used in the `filterPayload`. | +| updatePayload | string | true | The MongoDB update document, which specifies the modifications. This often uses update operators like `$set`. It's written as a Go template, using `{{json .param_name}}` to insert parameters. | +| updateParams | list | true | A list of parameter objects that define the variables used in the `updatePayload`. | +| canonical | bool | false | Determines if the `updatePayload` string is parsed using MongoDB's Canonical or Relaxed Extended JSON format. **Canonical** is stricter about type representation (e.g., `{"$numberInt": "42"}`), while **Relaxed** is more lenient (e.g., `42`). Defaults to `false`. | +| upsert | bool | false | If `true`, a new document is created if no document matches the `filterPayload`. Defaults to `false`. | diff --git a/docs/en/resources/tools/mysql/mysql-get-query-plan.md b/docs/en/resources/tools/mysql/mysql-get-query-plan.md new file mode 100644 index 0000000000..d77b81e097 --- /dev/null +++ b/docs/en/resources/tools/mysql/mysql-get-query-plan.md @@ -0,0 +1,39 @@ +--- +title: "mysql-get-query-plan" +type: docs +weight: 1 +description: > + A "mysql-get-query-plan" tool gets the execution plan for a SQL statement against a MySQL + database. +aliases: +- /resources/tools/mysql-get-query-plan +--- + +## About + +A `mysql-get-query-plan` tool gets the execution plan for a SQL statement against a MySQL +database. 
It's compatible with any of the following sources:
+
+- [cloud-sql-mysql](../../sources/cloud-sql-mysql.md)
+- [mysql](../../sources/mysql.md)
+
+`mysql-get-query-plan` takes one input parameter `sql_statement` and gets the execution plan for the SQL
+statement against the `source`.
+
+## Example
+
+```yaml
+tools:
+  get_query_plan_tool:
+    kind: mysql-get-query-plan
+    source: my-mysql-instance
+    description: Use this tool to get the execution plan for a SQL statement.
+```
+
+## Reference
+
+| **field** | **type** | **required** | **description** |
+|-------------|:------------------------------------------:|:------------:|--------------------------------------------------------------------------------------------------|
+| kind | string | true | Must be "mysql-get-query-plan". |
+| source | string | true | Name of the source the SQL should execute on. |
+| description | string | true | Description of the tool that is passed to the LLM. |
diff --git a/docs/en/resources/tools/postgres/postgres-list-database-stats.md b/docs/en/resources/tools/postgres/postgres-list-database-stats.md
new file mode 100644
index 0000000000..01537bcfa7
--- /dev/null
+++ b/docs/en/resources/tools/postgres/postgres-list-database-stats.md
@@ -0,0 +1,95 @@
+---
+title: "postgres-list-database-stats"
+type: docs
+weight: 1
+description: >
+  The "postgres-list-database-stats" tool lists key performance and activity statistics of PostgreSQL databases.
+aliases:
+- /resources/tools/postgres-list-database-stats
+---
+
+## About
+
+The `postgres-list-database-stats` tool lists the key performance and activity statistics for each PostgreSQL database in the instance, offering insights into cache efficiency, transaction throughput, row-level activity, temporary file usage, and contention. It's compatible with
+any of the following sources:
+
+- [alloydb-postgres](../../sources/alloydb-pg.md)
+- [cloud-sql-postgres](../../sources/cloud-sql-pg.md)
+- [postgres](../../sources/postgres.md)
+
+`postgres-list-database-stats` lists detailed information as JSON for each database. The tool
+takes the following input parameters:
+
+- `database_name` (optional): A text to filter results by database name. Default: `""`
+- `include_templates` (optional): Boolean, set to `true` to include template databases in the results. Default: `false`
+- `database_owner` (optional): A text to filter results by database owner. Default: `""`
+- `default_tablespace` (optional): A text to filter results by the default tablespace name. Default: `""`
+- `order_by` (optional): Specifies the sorting order. Valid values are `'size'` (descending) or `'commit'` (descending). Default: `database_name` ascending.
+- `limit` (optional): The maximum number of databases to return. Default: `10`
+
+## Example
+
+```yaml
+tools:
+  list_database_stats:
+    kind: postgres-list-database-stats
+    source: postgres-source
+    description: |
+      Lists the key performance and activity statistics for each PostgreSQL
+      database in the instance, offering insights into cache efficiency,
+      transaction throughput, row-level activity, temporary file usage, and
+      contention.
It returns: the database name, whether the database is + connectable, database owner, default tablespace name, the percentage of + data blocks found in the buffer cache rather than being read from disk + (a higher value indicates better cache performance), the total number of + disk blocks read from disk, the total number of times disk blocks were + found already in the cache; the total number of committed transactions, + the total number of rolled back transactions, the percentage of rolled + back transactions compared to the total number of completed + transactions, the total number of rows returned by queries, the total + number of live rows fetched by scans, the total number of rows inserted, + the total number of rows updated, the total number of rows deleted, the + number of temporary files created by queries, the total size of + temporary files used by queries in bytes, the number of query + cancellations due to conflicts with recovery, the number of deadlocks + detected, the current number of active backend connections, the + timestamp when the database statistics were last reset, and the total + database size in bytes. +``` + +The response is a json array with the following elements: + +```json +{ + "database_name": "Name of the database", + "is_connectable": "Boolean indicating Whether the database allows connections", + "database_owner": "Username of the database owner", + "default_tablespace": "Name of the default tablespace for the database", + "cache_hit_ratio_percent": "The percentage of data blocks found in the buffer cache rather than being read from disk", + "blocks_read_from_disk": "The total number of disk blocks read for this database", + "blocks_hit_in_cache": "The total number of times disk blocks were found already in the cache.", + "xact_commit": "The total number of committed transactions", + "xact_rollback": "The total number of rolled back transactions", + "rollback_ratio_percent": "The percentage of rolled back transactions compared to the total number of completed transactions", + "rows_returned_by_queries": "The total number of rows returned by queries", + "rows_fetched_by_scans": "The total number of live rows fetched by scans", + "tup_inserted": "The total number of rows inserted", + "tup_updated": "The total number of rows updated", + "tup_deleted": "The total number of rows deleted", + "temp_files": "The number of temporary files created by queries", + "temp_size_bytes": "The total size of temporary files used by queries in bytes", + "conflicts": "Number of query cancellations due to conflicts", + "deadlocks": "Number of deadlocks detected", + "active_connections": "The current number of active backend connections", + "statistics_last_reset": "The timestamp when the database statistics were last reset", + "database_size_bytes": "The total disk size of the database in bytes" +} +``` + +## Reference + +| **field** | **type** | **required** | **description** | +|-------------|:--------:|:------------:|------------------------------------------------------| +| kind | string | true | Must be "postgres-list-database-stats". | +| source | string | true | Name of the source the SQL should execute on. | +| description | string | false | Description of the tool that is passed to the agent. 
|
diff --git a/docs/en/resources/tools/postgres/postgres-list-indexes.md b/docs/en/resources/tools/postgres/postgres-list-indexes.md
index 3c2cf63e28..f528ae97ec 100644
--- a/docs/en/resources/tools/postgres/postgres-list-indexes.md
+++ b/docs/en/resources/tools/postgres/postgres-list-indexes.md
@@ -21,12 +21,10 @@ any of the following sources:
 `postgres-list-indexes` lists detailed information as JSON for indexes. The tool
 takes the following input parameters:
 
-- `table_name` (optional): A text to filter results by table name. The input is
-  used within a LIKE clause. Default: `""`
-- `index_name` (optional): A text to filter results by index name. The input is
-  used within a LIKE clause. Default: `""`
-- `schema_name` (optional): A text to filter results by schema name. The input
-  is used within a LIKE clause. Default: `""`
+- `table_name` (optional): A text to filter results by table name. Default: `""`
+- `index_name` (optional): A text to filter results by index name. Default: `""`
+- `schema_name` (optional): A text to filter results by schema name. Default: `""`
+- `only_unused` (optional): If `true`, returns only indexes that have never been used.
 - `limit` (optional): The maximum number of rows to return. Default: `50`.
 
 ## Example
diff --git a/docs/en/resources/tools/postgres/postgres-list-pg-settings.md b/docs/en/resources/tools/postgres/postgres-list-pg-settings.md
new file mode 100644
index 0000000000..23d5e28e92
--- /dev/null
+++ b/docs/en/resources/tools/postgres/postgres-list-pg-settings.md
@@ -0,0 +1,59 @@
+---
+title: "postgres-list-pg-settings"
+type: docs
+weight: 1
+description: >
+  The "postgres-list-pg-settings" tool lists PostgreSQL run-time configuration settings.
+aliases:
+- /resources/tools/postgres-list-pg-settings
+---
+
+## About
+
+The `postgres-list-pg-settings` tool lists the configuration parameters for the postgres server, their current values, and related information. It's compatible with any of the following sources:
+
+- [alloydb-postgres](../../sources/alloydb-pg.md)
+- [cloud-sql-postgres](../../sources/cloud-sql-pg.md)
+- [postgres](../../sources/postgres.md)
+
+`postgres-list-pg-settings` lists detailed information as JSON for each setting. The tool
+takes the following input parameters:
+
+- `setting_name` (optional): A text to filter results by setting name. Default: `""`
+- `limit` (optional): The maximum number of rows to return. Default: `50`.
+
+## Example
+
+```yaml
+tools:
+  list_pg_settings:
+    kind: postgres-list-pg-settings
+    source: postgres-source
+    description: |
+      Lists configuration parameters for the postgres server ordered lexicographically,
+      with a default limit of 50 rows. It returns the parameter name, its current setting,
+      unit of measurement, a short description, the source of the current setting (e.g.,
+      default, configuration file, session), and whether a restart is required when the
+      parameter value is changed.
+```
+
+The response is a json array with the following elements:
+
+```json
+{
+  "name": "Setting name",
+  "current_value": "Current value of the setting",
+  "unit": "Unit of the setting",
+  "short_desc": "Short description of the setting",
+  "source": "Source of the current value (e.g., default, configuration file, session)",
+  "requires_restart": "Indicates if a server restart is required to apply a change ('Yes', 'No', or 'No (Reload sufficient)')"
+}
+```
+
+## Reference
+
+| **field** | **type** | **required** | **description** |
+|-------------|:--------:|:------------:|------------------------------------------------------|
+| kind | string | true | Must be "postgres-list-pg-settings". |
+| source | string | true | Name of the source the SQL should execute on. |
+| description | string | false | Description of the tool that is passed to the agent. |
diff --git a/docs/en/resources/tools/postgres/postgres-list-publication-tables.md b/docs/en/resources/tools/postgres/postgres-list-publication-tables.md
new file mode 100644
index 0000000000..a437d11783
--- /dev/null
+++ b/docs/en/resources/tools/postgres/postgres-list-publication-tables.md
@@ -0,0 +1,66 @@
+---
+title: "postgres-list-publication-tables"
+type: docs
+weight: 1
+description: >
+  The "postgres-list-publication-tables" tool lists publication tables in a Postgres database.
+aliases:
+- /resources/tools/postgres-list-publication-tables
+---
+
+## About
+
+The `postgres-list-publication-tables` tool lists all publication tables in the database. It's compatible with any of the following sources:
+
+- [alloydb-postgres](../../sources/alloydb-pg.md)
+- [cloud-sql-postgres](../../sources/cloud-sql-pg.md)
+- [postgres](../../sources/postgres.md)
+
+`postgres-list-publication-tables` lists detailed information as JSON for publication tables. A publication table in PostgreSQL is a
+table that is explicitly included as a source for replication within a publication (a set of changes generated from a table or group
+of tables) as part of the logical replication feature. The tool takes the following input parameters:
+
+- `table_names` (optional): Filters by a comma-separated list of table names. Default: `""`
+- `publication_names` (optional): Filters by a comma-separated list of publication names. Default: `""`
+- `schema_names` (optional): Filters by a comma-separated list of schema names. Default: `""`
+- `limit` (optional): The maximum number of rows to return. Default: `50`
+
+## Example
+
+```yaml
+tools:
+  list_publication_tables:
+    kind: postgres-list-publication-tables
+    source: postgres-source
+    description: |
+      Lists all tables that are explicitly part of a publication in the database.
+      Tables that are part of a publication via 'FOR ALL TABLES' are not included,
+      unless they are also explicitly added to the publication.
+      Returns the publication name, schema name, and table name, along with
+      definition details indicating if it publishes all tables, whether it
+      replicates inserts, updates, deletes, or truncates, and the publication
+      owner.
+```
+
+The response is a JSON array with the following elements:
+```json
+{
+  "publication_name": "Name of the publication",
+  "schema_name": "Name of the schema the table belongs to",
+  "table_name": "Name of the table",
+  "publishes_all_tables": "boolean indicating if the publication was created with FOR ALL TABLES",
+  "publishes_inserts": "boolean indicating if INSERT operations are replicated",
+  "publishes_updates": "boolean indicating if UPDATE operations are replicated",
+  "publishes_deletes": "boolean indicating if DELETE operations are replicated",
+  "publishes_truncates": "boolean indicating if TRUNCATE operations are replicated",
+  "publication_owner": "Username of the database role that owns the publication"
+}
+```
+
+## Reference
+
+| **field** | **type** | **required** | **description** |
+|-------------|:--------:|:------------:|------------------------------------------------------|
+| kind | string | true | Must be "postgres-list-publication-tables". |
+| source | string | true | Name of the source the SQL should execute on. |
+| description | string | false | Description of the tool that is passed to the agent. |
diff --git a/docs/en/resources/tools/postgres/postgres-list-roles.md b/docs/en/resources/tools/postgres/postgres-list-roles.md
new file mode 100644
index 0000000000..d3de6abdfb
--- /dev/null
+++ b/docs/en/resources/tools/postgres/postgres-list-roles.md
@@ -0,0 +1,70 @@
+---
+title: "postgres-list-roles"
+type: docs
+weight: 1
+description: >
+  The "postgres-list-roles" tool lists user-created roles in a Postgres database.
+aliases:
+- /resources/tools/postgres-list-roles
+---
+
+## About
+
+The `postgres-list-roles` tool lists all the user-created roles in the instance, excluding system roles (like `cloudsql%` or `pg_%`). It provides details about each role's attributes and memberships. It's compatible with
+any of the following sources:
+
+- [alloydb-postgres](../../sources/alloydb-pg.md)
+- [cloud-sql-postgres](../../sources/cloud-sql-pg.md)
+- [postgres](../../sources/postgres.md)
+
+`postgres-list-roles` lists detailed information as JSON for each role. The tool
+takes the following input parameters:
+
+- `role_name` (optional): A text to filter results by role name. Default: `""`
+- `limit` (optional): The maximum number of roles to return. Default: `50`
+
+## Example
+
+```yaml
+tools:
+  list_roles:
+    kind: postgres-list-roles
+    source: postgres-source
+    description: |
+      Lists all the user-created roles in the instance. It returns the role name,
+      Object ID, the maximum number of concurrent connections the role can make,
+      along with boolean indicators for: superuser status, privilege inheritance
+      from member roles, ability to create roles, ability to create databases,
+      ability to log in, replication privilege, and the ability to bypass
+      row-level security, the password expiration timestamp, a list of direct
+      members belonging to this role, and a list of other roles/groups that this
+      role is a member of.
+``` + +The response is a json array with the following elements: + +```json +{ + "role_name": "Name of the role", + "oid": "Object ID of the role", + "connection_limit": "Maximum concurrent connections allowed (-1 for no limit)", + "is_superuser": "Boolean, true if the role is a superuser", + "inherits_privileges": "Boolean, true if the role inherits privileges of roles it is a member of", + "can_create_roles": "Boolean, true if the role can create other roles", + "can_create_db": "Boolean, true if the role can create databases", + "can_login": "Boolean, true if the role can log in", + "is_replication_role": "Boolean, true if this is a replication role", + "bypass_rls": "Boolean, true if the role bypasses row-level security policies", + "valid_until": "Timestamp until the password is valid (null if forever)", + "direct_members": ["Array of role names that are direct members of this role"], + "member_of": ["Array of role names that this role is a member of"] +} +``` + +## Reference + +| **field** | **type** | **required** | **description** | +|-------------|:--------:|:------------:|------------------------------------------------------| +| kind | string | true | Must be "postgres-list-roles". | +| source | string | true | Name of the source the SQL should execute on. | +| description | string | false | Description of the tool that is passed to the agent. | diff --git a/docs/en/resources/tools/postgres/postgres-list-schemas.md b/docs/en/resources/tools/postgres/postgres-list-schemas.md index 9147ce50ef..6c0eb9e82a 100644 --- a/docs/en/resources/tools/postgres/postgres-list-schemas.md +++ b/docs/en/resources/tools/postgres/postgres-list-schemas.md @@ -21,9 +21,9 @@ the following sources: `postgres-list-schemas` lists detailed information as JSON for each schema. The tool takes the following input parameters: -- `schema_name` (optional): A pattern to filter schema names using SQL LIKE - operator. - If omitted, all user-defined schemas are returned. +- `schema_name` (optional): A text to filter results by schema name. Default: `""` +- `owner` (optional): A text to filter results by owner name. Default: `""` +- `limit` (optional): The maximum number of rows to return. Default: `50`. ## Example diff --git a/docs/en/resources/tools/postgres/postgres-list-sequences.md b/docs/en/resources/tools/postgres/postgres-list-sequences.md index 98679b8302..e51915525a 100644 --- a/docs/en/resources/tools/postgres/postgres-list-sequences.md +++ b/docs/en/resources/tools/postgres/postgres-list-sequences.md @@ -20,9 +20,9 @@ Postgres database. It's compatible with any of the following sources: `postgres-list-sequences` lists detailed information as JSON for all sequences. The tool takes the following input parameters: -- `sequencename` (optional): A text to filter results by sequence name. The +- `sequence_name` (optional): A text to filter results by sequence name. The input is used within a LIKE clause. Default: `""` -- `schemaname` (optional): A text to filter results by schema name. The input is +- `schema_name` (optional): A text to filter results by schema name. The input is used within a LIKE clause. Default: `""` - `limit` (optional): The maximum number of rows to return. Default: `50`. 
@@ -45,9 +45,9 @@ The response is a json array with the following elements: ```json { - "sequencename": "sequence name", - "schemaname": "schema name", - "sequenceowner": "owner of the sequence", + "sequence_name": "sequence name", + "schema_name": "schema name", + "sequence_owner": "owner of the sequence", "data_type": "data type of the sequence", "start_value": "starting value of the sequence", "min_value": "minimum value of the sequence", diff --git a/docs/en/resources/tools/postgres/postgres-list-table-stats.md b/docs/en/resources/tools/postgres/postgres-list-table-stats.md new file mode 100644 index 0000000000..666a126aca --- /dev/null +++ b/docs/en/resources/tools/postgres/postgres-list-table-stats.md @@ -0,0 +1,171 @@ +--- +title: "postgres-list-table-stats" +type: docs +weight: 1 +description: > + The "postgres-list-table-stats" tool reports table statistics including size, scan metrics, and bloat indicators for PostgreSQL tables. +aliases: +- /resources/tools/postgres-list-table-stats +--- + +## About + +The `postgres-list-table-stats` tool queries `pg_stat_all_tables` to provide comprehensive statistics about tables in the database. It calculates useful metrics like index scan ratio and dead row ratio to help identify performance issues and table bloat. + +Compatible sources: + +- [alloydb-postgres](../../sources/alloydb-pg.md) +- [cloud-sql-postgres](../../sources/cloud-sql-pg.md) +- [postgres](../../sources/postgres.md) + +The tool returns a JSON array where each element represents statistics for a table, including scan metrics, row counts, and vacuum history. Results are sorted by sequential scans by default and limited to 50 rows. + +## Example + +```yaml +tools: + list_table_stats: + kind: postgres-list-table-stats + source: postgres-source + description: "Lists table statistics including size, scans, and bloat metrics." 
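+    # The filters documented under "Parameters" below (schema_name, table_name,
+    # owner, sort_by, limit) are supplied by the caller at invocation time, as in
+    # the "Example Requests" section; they are not configured in tools.yaml.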
+``` + +### Example Requests + +**List default tables in public schema:** +```json +{} +``` + +**Filter by specific table name:** +```json +{ + "table_name": "users" +} +``` + +**Filter by owner and sort by size:** +```json +{ + "owner": "app_user", + "sort_by": "size", + "limit": 10 +} +``` + +**Find tables with high dead row ratio:** +```json +{ + "sort_by": "dead_rows", + "limit": 20 +} +``` + +### Example Response + +```json +[ + { + "schema_name": "public", + "table_name": "users", + "owner": "postgres", + "total_size_bytes": 8388608, + "seq_scan": 150, + "idx_scan": 450, + "idx_scan_ratio_percent": 75.0, + "live_rows": 50000, + "dead_rows": 1200, + "dead_row_ratio_percent": 2.34, + "n_tup_ins": 52000, + "n_tup_upd": 12500, + "n_tup_del": 800, + "last_vacuum": "2025-11-27T10:30:00Z", + "last_autovacuum": "2025-11-27T09:15:00Z", + "last_autoanalyze": "2025-11-27T09:16:00Z" + }, + { + "schema_name": "public", + "table_name": "orders", + "owner": "postgres", + "total_size_bytes": 16777216, + "seq_scan": 50, + "idx_scan": 1200, + "idx_scan_ratio_percent": 96.0, + "live_rows": 100000, + "dead_rows": 5000, + "dead_row_ratio_percent": 4.76, + "n_tup_ins": 120000, + "n_tup_upd": 45000, + "n_tup_del": 15000, + "last_vacuum": "2025-11-26T14:22:00Z", + "last_autovacuum": "2025-11-27T02:30:00Z", + "last_autoanalyze": "2025-11-27T02:31:00Z" + } +] +``` + +## Parameters + +| parameter | type | required | default | description | +|-------------|---------|----------|---------|-------------| +| schema_name | string | false | "public" | Optional: A specific schema name to filter by (supports partial matching) | +| table_name | string | false | null | Optional: A specific table name to filter by (supports partial matching) | +| owner | string | false | null | Optional: A specific owner to filter by (supports partial matching) | +| sort_by | string | false | null | Optional: The column to sort by. Valid values: `size`, `dead_rows`, `seq_scan`, `idx_scan` (defaults to `seq_scan`) | +| limit | integer | false | 50 | Optional: The maximum number of results to return | + +## Output Fields Reference + +| field | type | description | +|------------------------|-----------|-------------| +| schema_name | string | Name of the schema containing the table. | +| table_name | string | Name of the table. | +| owner | string | PostgreSQL user who owns the table. | +| total_size_bytes | integer | Total size of the table including all indexes in bytes. | +| seq_scan | integer | Number of sequential (full table) scans performed on this table. | +| idx_scan | integer | Number of index scans performed on this table. | +| idx_scan_ratio_percent | decimal | Percentage of total scans (seq_scan + idx_scan) that used an index. A low ratio may indicate missing or ineffective indexes. | +| live_rows | integer | Number of live (non-deleted) rows in the table. | +| dead_rows | integer | Number of dead (deleted but not yet vacuumed) rows in the table. | +| dead_row_ratio_percent | decimal | Percentage of dead rows relative to total rows. High values indicate potential table bloat. | +| n_tup_ins | integer | Total number of rows inserted into this table. | +| n_tup_upd | integer | Total number of rows updated in this table. | +| n_tup_del | integer | Total number of rows deleted from this table. | +| last_vacuum | timestamp | Timestamp of the last manual VACUUM operation on this table (null if never manually vacuumed). | +| last_autovacuum | timestamp | Timestamp of the last automatic vacuum operation on this table. 
| +| last_autoanalyze | timestamp | Timestamp of the last automatic analyze operation on this table. | + +## Interpretation Guide + +### Index Scan Ratio (`idx_scan_ratio_percent`) + +- **High ratio (> 80%)**: Table queries are efficiently using indexes. This is typically desirable. +- **Low ratio (< 20%)**: Many sequential scans indicate missing indexes or queries that cannot use existing indexes effectively. Consider adding indexes to frequently searched columns. +- **0%**: No index scans performed; all queries performed sequential scans. May warrant index investigation. + +### Dead Row Ratio (`dead_row_ratio_percent`) + +- **< 2%**: Healthy table with minimal bloat. +- **2-5%**: Moderate bloat; consider running VACUUM if not recent. +- **> 5%**: High bloat; may benefit from manual VACUUM or VACUUM FULL. + +### Vacuum History + +- **Null `last_vacuum`**: Table has never been manually vacuumed; relies on autovacuum. +- **Recent `last_autovacuum`**: Autovacuum is actively managing the table. +- **Stale timestamps**: Consider running manual VACUUM and ANALYZE if maintenance windows exist. + +## Performance Considerations + +- Statistics are collected from `pg_stat_all_tables`, which resets on PostgreSQL restart. +- Run `ANALYZE` on tables to update statistics for accurate query planning. +- The tool defaults to limiting results to 50 rows; adjust the `limit` parameter for larger result sets. +- Filtering by schema, table name, or owner uses `LIKE` pattern matching (supports partial matches). + +## Use Cases + +- **Finding ineffective indexes**: Identify tables with low `idx_scan_ratio_percent` to evaluate index strategy. +- **Detecting table bloat**: Sort by `dead_rows` to find tables needing VACUUM. +- **Monitoring growth**: Track `total_size_bytes` over time for capacity planning. +- **Audit maintenance**: Check `last_autovacuum` and `last_autoanalyze` timestamps to ensure maintenance tasks are running. +- **Understanding workload**: Examine `seq_scan` vs `idx_scan` ratios to understand query patterns. \ No newline at end of file diff --git a/docs/en/resources/tools/postgres/postgres-list-tablespaces.md b/docs/en/resources/tools/postgres/postgres-list-tablespaces.md new file mode 100644 index 0000000000..bf63f61b8f --- /dev/null +++ b/docs/en/resources/tools/postgres/postgres-list-tablespaces.md @@ -0,0 +1,56 @@ +--- +title: "postgres-list-tablespaces" +type: docs +weight: 1 +description: > + The "postgres-list-tablespaces" tool lists tablespaces in a Postgres database. +aliases: +- /resources/tools/postgres-list-tablespaces +--- + +## About + +The `postgres-list-tablespaces` tool lists available tablespaces in the database. It's compatible with any of the following sources: + +- [alloydb-postgres](../../sources/alloydb-pg.md) +- [cloud-sql-postgres](../../sources/cloud-sql-pg.md) +- [postgres](../../sources/postgres.md) + +`postgres-list-tablespaces` lists detailed information as JSON for tablespaces. The tool takes the following input parameters: + +- `tablespace_name` (optional): A text to filter results by tablespace name. Default: `""` +- `limit` (optional): The maximum number of tablespaces to return. Default: `50` + +## Example + +```yaml +tools: + list_tablespaces: + kind: postgres-list-tablespaces + source: postgres-source + description: | + Lists all tablespaces in the database. 
Returns the tablespace name, + owner name, size in bytes(if the current user has CREATE privileges on + the tablespace, otherwise NULL), internal object ID, the access control + list regarding permissions, and any specific tablespace options. +``` +The response is a json array with the following elements: + +```json +{ + "tablespace_name": "name of the tablespace", + "owner_username": "owner of the tablespace", + "size_in_bytes": "size in bytes if the current user has CREATE privileges on the tablespace, otherwise NULL", + "oid": "Object ID of the tablespace", + "spcacl": "Access privileges", + "spcoptions": "Tablespace-level options (e.g., seq_page_cost, random_page_cost)" +} +``` + +## Reference + +| **field** | **type** | **required** | **description** | +|-------------|:--------:|:-------------:|------------------------------------------------------| +| kind | string | true | Must be "postgres-list-tablespaces". | +| source | string | true | Name of the source the SQL should execute on. | +| description | string | false | Description of the tool that is passed to the agent. | diff --git a/docs/en/resources/tools/postgres/postgres-list-views.md b/docs/en/resources/tools/postgres/postgres-list-views.md index 7823252abc..b6f1687f33 100644 --- a/docs/en/resources/tools/postgres/postgres-list-views.md +++ b/docs/en/resources/tools/postgres/postgres-list-views.md @@ -19,11 +19,11 @@ a Postgres database, excluding those in system schemas (`pg_catalog`, - [postgres](../../sources/postgres.md) `postgres-list-views` lists detailed view information (schemaname, viewname, -ownername) as JSON for views in a database. The tool takes the following input +ownername, definition) as JSON for views in a database. The tool takes the following input parameters: -- `viewname` (optional): A string pattern to filter view names. The search uses - SQL LIKE operator to filter the views. Default: `""` +- `view_name` (optional): A string pattern to filter view names. Default: `""` +- `schema_name` (optional): A string pattern to filter schema names. Default: `""` - `limit` (optional): The maximum number of rows to return. Default: `50`. ## Example diff --git a/docs/en/resources/tools/serverless-spark/_index.md b/docs/en/resources/tools/serverless-spark/_index.md index 4974a07b19..fab49745e6 100644 --- a/docs/en/resources/tools/serverless-spark/_index.md +++ b/docs/en/resources/tools/serverless-spark/_index.md @@ -9,3 +9,5 @@ description: > - [serverless-spark-get-batch](./serverless-spark-get-batch.md) - [serverless-spark-list-batches](./serverless-spark-list-batches.md) - [serverless-spark-cancel-batch](./serverless-spark-cancel-batch.md) +- [serverless-spark-create-pyspark-batch](./serverless-spark-create-pyspark-batch.md) +- [serverless-spark-create-spark-batch](./serverless-spark-create-spark-batch.md) diff --git a/docs/en/resources/tools/serverless-spark/serverless-spark-create-pyspark-batch.md b/docs/en/resources/tools/serverless-spark/serverless-spark-create-pyspark-batch.md new file mode 100644 index 0000000000..b94d386b2d --- /dev/null +++ b/docs/en/resources/tools/serverless-spark/serverless-spark-create-pyspark-batch.md @@ -0,0 +1,97 @@ +--- +title: "serverless-spark-create-pyspark-batch" +type: docs +weight: 2 +description: > + A "serverless-spark-create-pyspark-batch" tool submits a Spark batch to run asynchronously. 
+aliases: + - /resources/tools/serverless-spark-create-pyspark-batch +--- + +## About + +A `serverless-spark-create-pyspark-batch` tool submits a Spark batch to a Google +Cloud Serverless for Apache Spark source. The workload executes asynchronously +and takes around a minute to begin executing; status can be polled using the +[get batch](serverless-spark-get-batch.md) tool. + +It's compatible with the following sources: + +- [serverless-spark](../../sources/serverless-spark.md) + +`serverless-spark-create-pyspark-batch` accepts the following parameters: + +- **`mainFile`**: The path to the main Python file, as a gs://... URI. +- **`args`** Optional. A list of arguments passed to the main file. +- **`version`** Optional. The Serverless [runtime + version](https://docs.cloud.google.com/dataproc-serverless/docs/concepts/versions/dataproc-serverless-versions) + to execute with. + +## Custom Configuration + +This tool supports custom +[`runtimeConfig`](https://docs.cloud.google.com/dataproc-serverless/docs/reference/rest/v1/RuntimeConfig) +and +[`environmentConfig`](https://docs.cloud.google.com/dataproc-serverless/docs/reference/rest/v1/EnvironmentConfig) +settings, which can be specified in a `tools.yaml` file. These configurations +are parsed as YAML and passed to the Dataproc API. + +**Note:** If your project requires custom runtime or environment configuration, +you must write a custom `tools.yaml`, you cannot use the `serverless-spark` +prebuilt config. + +### Example `tools.yaml` + +```yaml +tools: + - name: "serverless-spark-create-pyspark-batch" + kind: "serverless-spark-create-pyspark-batch" + source: "my-serverless-spark-source" + runtimeConfig: + properties: + spark.driver.memory: "1024m" + environmentConfig: + executionConfig: + networkUri: "my-network" +``` + +## Response Format + +The response contains the +[operation](https://docs.cloud.google.com/dataproc-serverless/docs/reference/rest/v1/projects.locations.operations#resource:-operation) +metadata JSON object corresponding to [batch operation +metadata](https://pkg.go.dev/cloud.google.com/go/dataproc/v2/apiv1/dataprocpb#BatchOperationMetadata), +plus additional fields `consoleUrl` and `logsUrl` where a human can go for more +detailed information. + +```json +{ + "opMetadata": { + "batch": "projects/myproject/locations/us-central1/batches/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee", + "batchUuid": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee", + "createTime": "2025-11-19T16:36:47.607119Z", + "description": "Batch", + "labels": { + "goog-dataproc-batch-uuid": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee", + "goog-dataproc-location": "us-central1" + }, + "operationType": "BATCH", + "warnings": [ + "No runtime version specified. Using the default runtime version." + ] + }, + "consoleUrl": "https://console.cloud.google.com/dataproc/batches/...", + "logsUrl": "https://console.cloud.google.com/logs/viewer?..." +} +``` + +## Reference + +| **field** | **type** | **required** | **description** | +| ----------------- | :------: | :----------: | -------------------------------------------------------------------------------------------------------------------------------------------------------- | +| kind | string | true | Must be "serverless-spark-create-pyspark-batch". | +| source | string | true | Name of the source the tool should use. | +| description | string | false | Description of the tool that is passed to the LLM. 
| +| runtimeConfig | map | false | [Runtime config](https://docs.cloud.google.com/dataproc-serverless/docs/reference/rest/v1/RuntimeConfig) for all batches created with this tool. | +| environmentConfig | map | false | [Environment config](https://docs.cloud.google.com/dataproc-serverless/docs/reference/rest/v1/EnvironmentConfig) for all batches created with this tool. | +| authRequired | string[] | false | List of auth services required to invoke this tool. | diff --git a/docs/en/resources/tools/serverless-spark/serverless-spark-create-spark-batch.md b/docs/en/resources/tools/serverless-spark/serverless-spark-create-spark-batch.md new file mode 100644 index 0000000000..8264be00b0 --- /dev/null +++ b/docs/en/resources/tools/serverless-spark/serverless-spark-create-spark-batch.md @@ -0,0 +1,102 @@ +--- +title: "serverless-spark-create-spark-batch" +type: docs +weight: 2 +description: > + A "serverless-spark-create-spark-batch" tool submits a Spark batch to run asynchronously. +aliases: + - /resources/tools/serverless-spark-create-spark-batch +--- + +## About + +A `serverless-spark-create-spark-batch` tool submits a Java Spark batch to a +Google Cloud Serverless for Apache Spark source. The workload executes +asynchronously and takes around a minute to begin executing; status can be +polled using the [get batch](serverless-spark-get-batch.md) tool. + +It's compatible with the following sources: + +- [serverless-spark](../../sources/serverless-spark.md) + +`serverless-spark-create-spark-batch` accepts the following parameters: + +- **`mainJarFile`**: Optional. The gs:// URI of the jar file that contains the + main class. Exactly one of mainJarFile or mainClass must be specified. +- **`mainClass`**: Optional. The name of the driver's main class. Exactly one of + mainJarFile or mainClass must be specified. +- **`jarFiles`**: Optional. A list of gs:// URIs of jar files to add to the CLASSPATHs of + the Spark driver and tasks. +- **`args`** Optional. A list of arguments passed to the driver. +- **`version`** Optional. The Serverless [runtime + version](https://docs.cloud.google.com/dataproc-serverless/docs/concepts/versions/dataproc-serverless-versions) + to execute with. + +## Custom Configuration + +This tool supports custom +[`runtimeConfig`](https://docs.cloud.google.com/dataproc-serverless/docs/reference/rest/v1/RuntimeConfig) +and +[`environmentConfig`](https://docs.cloud.google.com/dataproc-serverless/docs/reference/rest/v1/EnvironmentConfig) +settings, which can be specified in a `tools.yaml` file. These configurations +are parsed as YAML and passed to the Dataproc API. + +**Note:** If your project requires custom runtime or environment configuration, +you must write a custom `tools.yaml`, you cannot use the `serverless-spark` +prebuilt config. 
+ +### Example `tools.yaml` + +```yaml +tools: + - name: "serverless-spark-create-spark-batch" + kind: "serverless-spark-create-spark-batch" + source: "my-serverless-spark-source" + runtimeConfig: + properties: + spark.driver.memory: "1024m" + environmentConfig: + executionConfig: + networkUri: "my-network" +``` + +## Response Format + +The response contains the +[operation](https://docs.cloud.google.com/dataproc-serverless/docs/reference/rest/v1/projects.locations.operations#resource:-operation) +metadata JSON object corresponding to [batch operation +metadata](https://pkg.go.dev/cloud.google.com/go/dataproc/v2/apiv1/dataprocpb#BatchOperationMetadata), +plus additional fields `consoleUrl` and `logsUrl` where a human can go for more +detailed information. + +```json +{ + "opMetadata": { + "batch": "projects/myproject/locations/us-central1/batches/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee", + "batchUuid": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee", + "createTime": "2025-11-19T16:36:47.607119Z", + "description": "Batch", + "labels": { + "goog-dataproc-batch-uuid": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee", + "goog-dataproc-location": "us-central1" + }, + "operationType": "BATCH", + "warnings": [ + "No runtime version specified. Using the default runtime version." + ] + }, + "consoleUrl": "https://console.cloud.google.com/dataproc/batches/...", + "logsUrl": "https://console.cloud.google.com/logs/viewer?..." +} +``` + +## Reference + +| **field** | **type** | **required** | **description** | +| ----------------- | :------: | :----------: | -------------------------------------------------------------------------------------------------------------------------------------------------------- | +| kind | string | true | Must be "serverless-spark-create-spark-batch". | +| source | string | true | Name of the source the tool should use. | +| description | string | false | Description of the tool that is passed to the LLM. | +| runtimeConfig | map | false | [Runtime config](https://docs.cloud.google.com/dataproc-serverless/docs/reference/rest/v1/RuntimeConfig) for all batches created with this tool. | +| environmentConfig | map | false | [Environment config](https://docs.cloud.google.com/dataproc-serverless/docs/reference/rest/v1/EnvironmentConfig) for all batches created with this tool. | +| authRequired | string[] | false | List of auth services required to invoke this tool. | diff --git a/docs/en/resources/tools/serverless-spark/serverless-spark-get-batch.md b/docs/en/resources/tools/serverless-spark/serverless-spark-get-batch.md index 532af65344..754aab9fd9 100644 --- a/docs/en/resources/tools/serverless-spark/serverless-spark-get-batch.md +++ b/docs/en/resources/tools/serverless-spark/serverless-spark-get-batch.md @@ -34,43 +34,50 @@ tools: ## Response Format -The response is a full Batch JSON object as defined in the [API -spec](https://cloud.google.com/dataproc-serverless/docs/reference/rest/v1/projects.locations.batches#Batch). -Example with a reduced set of fields: +The response contains the full Batch object as defined in the [API +spec](https://cloud.google.com/dataproc-serverless/docs/reference/rest/v1/projects.locations.batches#Batch), +plus additional fields `consoleUrl` and `logsUrl` where a human can go for more +detailed information. 
```json { - "createTime": "2025-10-10T15:15:21.303146Z", - "creator": "alice@example.com", - "labels": { - "goog-dataproc-batch-uuid": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee", - "goog-dataproc-location": "us-central1" - }, - "name": "projects/google.com:hadoop-cloud-dev/locations/us-central1/batches/alice-20251010-abcd", - "operation": "projects/google.com:hadoop-cloud-dev/regions/us-central1/operations/11111111-2222-3333-4444-555555555555", - "runtimeConfig": { - "properties": { - "spark:spark.driver.cores": "4", - "spark:spark.driver.memory": "12200m" - } - }, - "sparkBatch": { - "jarFileUris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"], - "mainClass": "org.apache.spark.examples.SparkPi" - }, - "state": "SUCCEEDED", - "stateHistory": [ - { - "state": "PENDING", - "stateStartTime": "2025-10-10T15:15:21.303146Z" + "batch": { + "createTime": "2025-10-10T15:15:21.303146Z", + "creator": "alice@example.com", + "labels": { + "goog-dataproc-batch-uuid": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee", + "goog-dataproc-location": "us-central1" }, - { - "state": "RUNNING", - "stateStartTime": "2025-10-10T15:16:41.291747Z" - } - ], - "stateTime": "2025-10-10T15:17:21.265493Z", - "uuid": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + "name": "projects/google.com:hadoop-cloud-dev/locations/us-central1/batches/alice-20251010-abcd", + "operation": "projects/google.com:hadoop-cloud-dev/regions/us-central1/operations/11111111-2222-3333-4444-555555555555", + "runtimeConfig": { + "properties": { + "spark:spark.driver.cores": "4", + "spark:spark.driver.memory": "12200m" + } + }, + "sparkBatch": { + "jarFileUris": [ + "file:///usr/lib/spark/examples/jars/spark-examples.jar" + ], + "mainClass": "org.apache.spark.examples.SparkPi" + }, + "state": "SUCCEEDED", + "stateHistory": [ + { + "state": "PENDING", + "stateStartTime": "2025-10-10T15:15:21.303146Z" + }, + { + "state": "RUNNING", + "stateStartTime": "2025-10-10T15:16:41.291747Z" + } + ], + "stateTime": "2025-10-10T15:17:21.265493Z", + "uuid": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + }, + "consoleUrl": "https://console.cloud.google.com/dataproc/batches/...", + "logsUrl": "https://console.cloud.google.com/logs/viewer?..." 
} ``` diff --git a/docs/en/resources/tools/serverless-spark/serverless-spark-list-batches.md b/docs/en/resources/tools/serverless-spark/serverless-spark-list-batches.md index 54d68eaa2e..9f0e5f0e7c 100644 --- a/docs/en/resources/tools/serverless-spark/serverless-spark-list-batches.md +++ b/docs/en/resources/tools/serverless-spark/serverless-spark-list-batches.md @@ -50,14 +50,18 @@ tools: "uuid": "a1b2c3d4-e5f6-7890-1234-567890abcdef", "state": "SUCCEEDED", "creator": "alice@example.com", - "createTime": "2023-10-27T10:00:00Z" + "createTime": "2023-10-27T10:00:00Z", + "consoleUrl": "https://console.cloud.google.com/dataproc/batches/us-central1/batch-abc-123/summary?project=my-project", + "logsUrl": "https://console.cloud.google.com/logs/viewer?advancedFilter=resource.type%3D%22cloud_dataproc_batch%22%0Aresource.labels.project_id%3D%22my-project%22%0Aresource.labels.location%3D%22us-central1%22%0Aresource.labels.batch_id%3D%22batch-abc-123%22%0Atimestamp%3E%3D%222023-10-27T09%3A59%3A00Z%22%0Atimestamp%3C%3D%222023-10-27T10%3A10%3A00Z%22&project=my-project&resource=cloud_dataproc_batch%2Fbatch_id%2Fbatch-abc-123" }, { "name": "projects/my-project/locations/us-central1/batches/batch-def-456", "uuid": "b2c3d4e5-f6a7-8901-2345-678901bcdefa", "state": "FAILED", "creator": "alice@example.com", - "createTime": "2023-10-27T11:30:00Z" + "createTime": "2023-10-27T11:30:00Z", + "consoleUrl": "https://console.cloud.google.com/dataproc/batches/us-central1/batch-def-456/summary?project=my-project", + "logsUrl": "https://console.cloud.google.com/logs/viewer?advancedFilter=resource.type%3D%22cloud_dataproc_batch%22%0Aresource.labels.project_id%3D%22my-project%22%0Aresource.labels.location%3D%22us-central1%22%0Aresource.labels.batch_id%3D%22batch-def-456%22%0Atimestamp%3E%3D%222023-10-27T11%3A29%3A00Z%22%0Atimestamp%3C%3D%222023-10-27T11%3A40%3A00Z%22&project=my-project&resource=cloud_dataproc_batch%2Fbatch_id%2Fbatch-def-456" } ], "nextPageToken": "abcd1234" diff --git a/docs/en/samples/alloydb/ai-nl/alloydb_ai_nl.ipynb b/docs/en/samples/alloydb/ai-nl/alloydb_ai_nl.ipynb index 10f1ce2e2d..fc8e5300b1 100644 --- a/docs/en/samples/alloydb/ai-nl/alloydb_ai_nl.ipynb +++ b/docs/en/samples/alloydb/ai-nl/alloydb_ai_nl.ipynb @@ -771,7 +771,7 @@ }, "outputs": [], "source": [ - "version = \"0.21.0\" # x-release-please-version\n", + "version = \"0.24.0\" # x-release-please-version\n", "! curl -L -o /content/toolbox https://storage.googleapis.com/genai-toolbox/v{version}/linux/amd64/toolbox\n", "\n", "# Make the binary executable\n", diff --git a/docs/en/samples/alloydb/mcp_quickstart.md b/docs/en/samples/alloydb/mcp_quickstart.md index 80e1e57a0c..c047416428 100644 --- a/docs/en/samples/alloydb/mcp_quickstart.md +++ b/docs/en/samples/alloydb/mcp_quickstart.md @@ -123,7 +123,7 @@ In this section, we will download and install the Toolbox binary. ```bash export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64 - export VERSION="0.21.0" + export VERSION="0.24.0" curl -O https://storage.googleapis.com/genai-toolbox/v$VERSION/$OS/toolbox ``` diff --git a/docs/en/samples/bigquery/colab_quickstart_bigquery.ipynb b/docs/en/samples/bigquery/colab_quickstart_bigquery.ipynb index 292da7be83..eb551ca015 100644 --- a/docs/en/samples/bigquery/colab_quickstart_bigquery.ipynb +++ b/docs/en/samples/bigquery/colab_quickstart_bigquery.ipynb @@ -220,7 +220,7 @@ }, "outputs": [], "source": [ - "version = \"0.21.0\" # x-release-please-version\n", + "version = \"0.24.0\" # x-release-please-version\n", "! 
curl -O https://storage.googleapis.com/genai-toolbox/v{version}/linux/amd64/toolbox\n", "\n", "# Make the binary executable\n", diff --git a/docs/en/samples/bigquery/local_quickstart.md b/docs/en/samples/bigquery/local_quickstart.md index 2971fbdaeb..badda3f75e 100644 --- a/docs/en/samples/bigquery/local_quickstart.md +++ b/docs/en/samples/bigquery/local_quickstart.md @@ -179,7 +179,7 @@ to use BigQuery, and then run the Toolbox server. ```bash export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64 - curl -O https://storage.googleapis.com/genai-toolbox/v0.21.0/$OS/toolbox + curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/$OS/toolbox ``` diff --git a/docs/en/samples/bigquery/mcp_quickstart/_index.md b/docs/en/samples/bigquery/mcp_quickstart/_index.md index ac1556abaa..6f0b44d18b 100644 --- a/docs/en/samples/bigquery/mcp_quickstart/_index.md +++ b/docs/en/samples/bigquery/mcp_quickstart/_index.md @@ -98,7 +98,7 @@ In this section, we will download Toolbox, configure our tools in a ```bash export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64 - curl -O https://storage.googleapis.com/genai-toolbox/v0.21.0/$OS/toolbox + curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/$OS/toolbox ``` diff --git a/docs/en/samples/looker/looker_gemini.md b/docs/en/samples/looker/looker_gemini.md index 2b7e67bc1a..0fc81afc32 100644 --- a/docs/en/samples/looker/looker_gemini.md +++ b/docs/en/samples/looker/looker_gemini.md @@ -34,7 +34,7 @@ In this section, we will download Toolbox and run the Toolbox server. ```bash export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64 - curl -O https://storage.googleapis.com/genai-toolbox/v0.21.0/$OS/toolbox + curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/$OS/toolbox ``` diff --git a/docs/en/samples/looker/looker_gemini_oauth/_index.md b/docs/en/samples/looker/looker_gemini_oauth/_index.md index 537ab38ef7..6eb730ceee 100644 --- a/docs/en/samples/looker/looker_gemini_oauth/_index.md +++ b/docs/en/samples/looker/looker_gemini_oauth/_index.md @@ -48,7 +48,7 @@ In this section, we will download Toolbox and run the Toolbox server. ```bash export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64 - curl -O https://storage.googleapis.com/genai-toolbox/v0.21.0/$OS/toolbox + curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/$OS/toolbox ``` diff --git a/docs/en/samples/looker/looker_mcp_inspector/_index.md b/docs/en/samples/looker/looker_mcp_inspector/_index.md index 100ef9081a..ef3a51c4e9 100644 --- a/docs/en/samples/looker/looker_mcp_inspector/_index.md +++ b/docs/en/samples/looker/looker_mcp_inspector/_index.md @@ -34,7 +34,7 @@ In this section, we will download Toolbox and run the Toolbox server. 
```bash export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64 - curl -O https://storage.googleapis.com/genai-toolbox/v0.21.0/$OS/toolbox + curl -O https://storage.googleapis.com/genai-toolbox/v0.24.0/$OS/toolbox ``` diff --git a/gemini-extension.json b/gemini-extension.json index 70d2524eb8..b068279cd6 100644 --- a/gemini-extension.json +++ b/gemini-extension.json @@ -1,6 +1,6 @@ { "name": "mcp-toolbox-for-databases", - "version": "0.21.0", + "version": "0.24.0", "description": "MCP Toolbox for Databases is an open-source MCP server for more than 30 different datasources.", "contextFileName": "MCP-TOOLBOX-EXTENSION.md" } \ No newline at end of file diff --git a/go.mod b/go.mod index 7b6adf3074..e0ed921ac5 100644 --- a/go.mod +++ b/go.mod @@ -2,7 +2,7 @@ module github.com/googleapis/genai-toolbox go 1.24.7 -toolchain go1.25.3 +toolchain go1.25.5 require ( cloud.google.com/go/alloydbconn v1.15.5 @@ -12,7 +12,7 @@ require ( cloud.google.com/go/dataplex v1.28.0 cloud.google.com/go/dataproc/v2 v2.15.0 cloud.google.com/go/firestore v1.20.0 - cloud.google.com/go/geminidataanalytics v0.2.1 + cloud.google.com/go/geminidataanalytics v0.3.0 cloud.google.com/go/longrunning v0.7.0 cloud.google.com/go/spanner v1.86.1 github.com/ClickHouse/clickhouse-go/v2 v2.40.3 @@ -22,25 +22,27 @@ require ( github.com/cenkalti/backoff/v5 v5.0.3 github.com/couchbase/gocb/v2 v2.11.1 github.com/couchbase/tools-common/http v1.0.9 - github.com/elastic/elastic-transport-go/v8 v8.7.0 + github.com/elastic/elastic-transport-go/v8 v8.8.0 github.com/elastic/go-elasticsearch/v9 v9.2.0 github.com/fsnotify/fsnotify v1.9.0 github.com/go-chi/chi/v5 v5.2.3 + github.com/go-chi/cors v1.2.2 github.com/go-chi/httplog/v2 v2.1.1 github.com/go-chi/render v1.0.3 github.com/go-goquery/goquery v1.0.1 github.com/go-playground/validator/v10 v10.28.0 github.com/go-sql-driver/mysql v1.9.3 github.com/goccy/go-yaml v1.18.0 + github.com/godror/godror v0.49.6 github.com/google/go-cmp v0.7.0 github.com/google/uuid v1.6.0 github.com/jackc/pgx/v5 v5.7.6 github.com/json-iterator/go v1.1.12 - github.com/looker-open-source/sdk-codegen/go v0.25.18 + github.com/looker-open-source/sdk-codegen/go v0.25.21 github.com/microsoft/go-mssqldb v1.9.3 github.com/nakagami/firebirdsql v0.9.15 github.com/neo4j/neo4j-go-driver/v5 v5.28.4 - github.com/redis/go-redis/v9 v9.16.0 + github.com/redis/go-redis/v9 v9.17.2 github.com/sijms/go-ora/v2 v2.9.0 github.com/spf13/cobra v1.10.1 github.com/thlib/go-timezone-local v0.0.7 @@ -90,6 +92,7 @@ require ( github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.29.0 // indirect github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.54.0 // indirect github.com/PuerkitoBio/goquery v1.10.3 // indirect + github.com/VictoriaMetrics/easyproto v0.1.4 // indirect github.com/ajg/form v1.5.1 // indirect github.com/apache/arrow/go/v15 v15.0.2 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect @@ -106,11 +109,13 @@ require ( github.com/felixge/httpsnoop v1.0.4 // indirect github.com/gabriel-vasile/mimetype v1.4.10 // indirect github.com/go-jose/go-jose/v4 v4.1.2 // indirect + github.com/go-logfmt/logfmt v0.6.0 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/goccy/go-json v0.10.5 // indirect + github.com/godror/knownpb v0.3.0 // indirect github.com/golang-sql/civil 
v0.0.0-20220223132316-b832511892a9 // indirect github.com/golang-sql/sqlexp v0.1.0 // indirect github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect @@ -180,7 +185,7 @@ require ( golang.org/x/time v0.14.0 // indirect golang.org/x/tools v0.38.0 // indirect golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20251014184007-4626949a642f // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20251111163417-95abcf5c77ba // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20251103181224-f26f9409b101 // indirect google.golang.org/grpc v1.76.0 // indirect gopkg.in/inf.v0 v0.9.1 // indirect diff --git a/go.sum b/go.sum index 0440488508..eeac2b4fd4 100644 --- a/go.sum +++ b/go.sum @@ -311,8 +311,8 @@ cloud.google.com/go/gaming v1.6.0/go.mod h1:YMU1GEvA39Qt3zWGyAVA9bpYz/yAhTvaQ1t2 cloud.google.com/go/gaming v1.7.0/go.mod h1:LrB8U7MHdGgFG851iHAfqUdLcKBdQ55hzXy9xBJz0+w= cloud.google.com/go/gaming v1.8.0/go.mod h1:xAqjS8b7jAVW0KFYeRUxngo9My3f33kFmua++Pi+ggM= cloud.google.com/go/gaming v1.9.0/go.mod h1:Fc7kEmCObylSWLO334NcO+O9QMDyz+TKC4v1D7X+Bc0= -cloud.google.com/go/geminidataanalytics v0.2.1 h1:gtG/9VlUJpL67yukFen/twkAEHliYvW7610Rlnn5rpQ= -cloud.google.com/go/geminidataanalytics v0.2.1/go.mod h1:gIsj/ELDCzVbw24185zwjXgbzYiqdGe7TSSK2HrdtA0= +cloud.google.com/go/geminidataanalytics v0.3.0 h1:2Wi/kqFb5OLuEGH7q+/miE19VTqK1MYHjBEHENap9HI= +cloud.google.com/go/geminidataanalytics v0.3.0/go.mod h1:QRc0b6ywyc3Z7S3etFgslz7hippkW/jRvtops5rKqIg= cloud.google.com/go/gkebackup v0.2.0/go.mod h1:XKvv/4LfG829/B8B7xRkk8zRrOEbKtEam6yNfuQNH60= cloud.google.com/go/gkebackup v0.3.0/go.mod h1:n/E671i1aOQvUxT541aTkCwExO/bTer2HDlj4TsBRAo= cloud.google.com/go/gkebackup v0.4.0/go.mod h1:byAyBGUwYGEEww7xsbnUTBHIYcOPy/PgUWUtOeRm9Vg= @@ -683,6 +683,10 @@ github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5/go.mod h1:lmUJ/7eu/Q8 github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/PuerkitoBio/goquery v1.10.3 h1:pFYcNSqHxBD06Fpj/KsbStFRsgRATgnf3LeXiUkhzPo= github.com/PuerkitoBio/goquery v1.10.3/go.mod h1:tMUX0zDMHXYlAQk6p35XxQMqMweEKB7iK7iLNd4RH4Y= +github.com/UNO-SOFT/zlog v0.8.1 h1:TEFkGJHtUfTRgMkLZiAjLSHALjwSBdw6/zByMC5GJt4= +github.com/UNO-SOFT/zlog v0.8.1/go.mod h1:yqFOjn3OhvJ4j7ArJqQNA+9V+u6t9zSAyIZdWdMweWc= +github.com/VictoriaMetrics/easyproto v0.1.4 h1:r8cNvo8o6sR4QShBXQd1bKw/VVLSQma/V2KhTBPf+Sc= +github.com/VictoriaMetrics/easyproto v0.1.4/go.mod h1:QlGlzaJnDfFd8Lk6Ci/fuLxfTo3/GThPs2KH23mv710= github.com/ahmetb/dlog v0.0.0-20170105205344-4fb5f8204f26 h1:3YVZUqkoev4mL+aCwVOSWV4M7pN+NURHL38Z2zq5JKA= github.com/ahmetb/dlog v0.0.0-20170105205344-4fb5f8204f26/go.mod h1:ymXt5bw5uSNu4jveerFxE0vNYxF8ncqbptntMaFMg3k= github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU= @@ -818,8 +822,8 @@ github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3 github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= -github.com/elastic/elastic-transport-go/v8 v8.7.0 h1:OgTneVuXP2uip4BA658Xi6Hfw+PeIOod2rY3GVMGoVE= -github.com/elastic/elastic-transport-go/v8 v8.7.0/go.mod h1:YLHer5cj0csTzNFXoNQ8qhtGY1GTvSqPnKWKaqQE3Hk= +github.com/elastic/elastic-transport-go/v8 v8.8.0 h1:7k1Ua+qluFr6p1jfJjGDl97ssJS/P7cHNInzfxgBQAo= 
+github.com/elastic/elastic-transport-go/v8 v8.8.0/go.mod h1:YLHer5cj0csTzNFXoNQ8qhtGY1GTvSqPnKWKaqQE3Hk= github.com/elastic/go-elasticsearch/v9 v9.2.0 h1:COeL/g20+ixnUbffe4Wfbu88emrHjAq/LhVfmrjqRQs= github.com/elastic/go-elasticsearch/v9 v9.2.0/go.mod h1:2PB5YQPpY5tWbF65MRqzEXA31PZOdXCkloQSOZtU14I= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= @@ -858,6 +862,8 @@ github.com/gabriel-vasile/mimetype v1.4.10/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9t github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/go-chi/chi/v5 v5.2.3 h1:WQIt9uxdsAbgIYgid+BpYc+liqQZGMHRaUwp0JUcvdE= github.com/go-chi/chi/v5 v5.2.3/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops= +github.com/go-chi/cors v1.2.2 h1:Jmey33TE+b+rB7fT8MUy1u0I4L+NARQlK6LhzKPSyQE= +github.com/go-chi/cors v1.2.2/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58= github.com/go-chi/httplog/v2 v2.1.1 h1:ojojiu4PIaoeJ/qAO4GWUxJqvYUTobeo7zmuHQJAxRk= github.com/go-chi/httplog/v2 v2.1.1/go.mod h1:/XXdxicJsp4BA5fapgIC3VuTD+z0Z/VzukoB3VDc1YE= github.com/go-chi/render v1.0.3 h1:AsXqd2a1/INaIfUSKq3G5uA8weYx20FOsM7uSoCyyt4= @@ -882,6 +888,8 @@ github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vb github.com/go-latex/latex v0.0.0-20210118124228-b3d85cf34e07/go.mod h1:CO1AlKB2CSIqUrmQPqA0gdRIlnLEY0gK5JGjh37zN5U= github.com/go-latex/latex v0.0.0-20210823091927-c0d11ff05a81/go.mod h1:SX0U8uGpxhq9o2S/CELCSUxEWWAuoCUcVCQWv7G2OCk= github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= +github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi4= +github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= @@ -907,6 +915,10 @@ github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw= github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= +github.com/godror/godror v0.49.6 h1:ts4ZGw8uLJ42e1D7aXmVuSrld0/lzUzmIUjuUuQOgGM= +github.com/godror/godror v0.49.6/go.mod h1:kTMcxZzRw73RT5kn9v3JkBK4kHI6dqowHotqV72ebU8= +github.com/godror/knownpb v0.3.0 h1:+caUdy8hTtl7X05aPl3tdL540TvCcaQA6woZQroLZMw= +github.com/godror/knownpb v0.3.0/go.mod h1:PpTyfJwiOEAzQl7NtVCM8kdPCnp3uhxsZYIzZ5PV4zU= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= @@ -1132,8 +1144,8 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0 github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= -github.com/looker-open-source/sdk-codegen/go v0.25.18 h1:me1JBFRnOBCrDWwpoSUVDVDFcFmcYMR2ijbx6ATtwTs= -github.com/looker-open-source/sdk-codegen/go v0.25.18/go.mod 
h1:Br1ntSiruDJ/4nYNjpYyWyCbqJ7+GQceWbIgn0hYims= +github.com/looker-open-source/sdk-codegen/go v0.25.21 h1:nlZ1nz22SKluBNkzplrMHBPEVgJO3zVLF6aAws1rrRA= +github.com/looker-open-source/sdk-codegen/go v0.25.21/go.mod h1:Br1ntSiruDJ/4nYNjpYyWyCbqJ7+GQceWbIgn0hYims= github.com/lyft/protoc-gen-star v0.6.0/go.mod h1:TGAoBVkt8w7MPG72TrKIu85MIdXwDuzJYeZuUPFPNwA= github.com/lyft/protoc-gen-star v0.6.1/go.mod h1:TGAoBVkt8w7MPG72TrKIu85MIdXwDuzJYeZuUPFPNwA= github.com/lyft/protoc-gen-star/v2 v2.0.1/go.mod h1:RcCdONR2ScXaYnQC5tUzxzlpA3WVYF7/opLeUgcQs/o= @@ -1170,6 +1182,8 @@ github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdh github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/neo4j/neo4j-go-driver/v5 v5.28.4 h1:7toxehVcYkZbyxV4W3Ib9VcnyRBQPucF+VwNNmtSXi4= github.com/neo4j/neo4j-go-driver/v5 v5.28.4/go.mod h1:Vff8OwT7QpLm7L2yYr85XNWe9Rbqlbeb9asNXJTHO4k= +github.com/oklog/ulid/v2 v2.0.2 h1:r4fFzBm+bv0wNKNh5eXTwU7i85y5x+uwkxCUTNVQqLc= +github.com/oklog/ulid/v2 v2.0.2/go.mod h1:mtBL0Qe/0HAx6/a4Z30qxVIAL1eQDweXq5lxOEiwQ68= github.com/onsi/gomega v1.36.2 h1:koNYke6TVk6ZmnyHrCXba/T/MoLBXFjeC1PtvYgw0A8= github.com/onsi/gomega v1.36.2/go.mod h1:DdwyADRjrc825LhMEkD76cHR5+pUnjhUN8GlHlRPHzY= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= @@ -1208,8 +1222,8 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.3.0/go.mod h1:LDGWKZIo7rky3hgvBe+caln+Dr3dPggB5dvjtD7w9+w= -github.com/redis/go-redis/v9 v9.16.0 h1:OotgqgLSRCmzfqChbQyG1PHC3tLNR89DG4jdOERSEP4= -github.com/redis/go-redis/v9 v9.16.0/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370= +github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4ViluI= +github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370= github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= @@ -1669,6 +1683,8 @@ golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM= +golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU= +golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -1988,8 +2004,8 @@ google.golang.org/genproto v0.0.0-20230331144136-dcfb400f0633/go.mod h1:UUQDJDOl google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1/go.mod h1:nKE/iIaLqn2bQwXBg8f1g2Ylh6r5MN5CmZvuzZCgsCU= google.golang.org/genproto 
v0.0.0-20251022142026-3a174f9686a8 h1:a12a2/BiVRxRWIqBbfqoSK6tgq8cyUgMnEI81QlPge0= google.golang.org/genproto v0.0.0-20251022142026-3a174f9686a8/go.mod h1:1Ic78BnpzY8OaTCmzxJDP4qC9INZPbGZl+54RKjtyeI= -google.golang.org/genproto/googleapis/api v0.0.0-20251014184007-4626949a642f h1:OiFuztEyBivVKDvguQJYWq1yDcfAHIID/FVrPR4oiI0= -google.golang.org/genproto/googleapis/api v0.0.0-20251014184007-4626949a642f/go.mod h1:kprOiu9Tr0JYyD6DORrc4Hfyk3RFXqkQ3ctHEum3ZbM= +google.golang.org/genproto/googleapis/api v0.0.0-20251111163417-95abcf5c77ba h1:B14OtaXuMaCQsl2deSvNkyPKIzq3BjfxQp8d00QyWx4= +google.golang.org/genproto/googleapis/api v0.0.0-20251111163417-95abcf5c77ba/go.mod h1:G5IanEx8/PgI9w6CFcYQf7jMtHQhZruvfM1i3qOqk5U= google.golang.org/genproto/googleapis/rpc v0.0.0-20251103181224-f26f9409b101 h1:tRPGkdGHuewF4UisLzzHHr1spKw92qLM98nIzxbC0wY= google.golang.org/genproto/googleapis/rpc v0.0.0-20251103181224-f26f9409b101/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= diff --git a/internal/prebuiltconfigs/tools/alloydb-postgres.yaml b/internal/prebuiltconfigs/tools/alloydb-postgres.yaml index 482c19512d..45b7785208 100644 --- a/internal/prebuiltconfigs/tools/alloydb-postgres.yaml +++ b/internal/prebuiltconfigs/tools/alloydb-postgres.yaml @@ -175,7 +175,7 @@ tools: list_schemas: kind: postgres-list-schemas source: alloydb-pg-source - + list_indexes: kind: postgres-list-indexes source: alloydb-pg-source @@ -200,6 +200,30 @@ tools: kind: postgres-get-column-cardinality source: alloydb-pg-source + list_table_stats: + kind: postgres-list-table-stats + source: alloydb-pg-source + + list_publication_tables: + kind: postgres-list-publication-tables + source: alloydb-pg-source + + list_tablespaces: + kind: postgres-list-tablespaces + source: alloydb-pg-source + + list_pg_settings: + kind: postgres-list-pg-settings + source: alloydb-pg-source + + list_database_stats: + kind: postgres-list-database-stats + source: alloydb-pg-source + + list_roles: + kind: postgres-list-roles + source: alloydb-pg-source + toolsets: alloydb_postgres_database_tools: - execute_sql @@ -224,3 +248,9 @@ toolsets: - replication_stats - list_query_stats - get_column_cardinality + - list_publication_tables + - list_tablespaces + - list_pg_settings + - list_database_stats + - list_roles + - list_table_stats diff --git a/internal/prebuiltconfigs/tools/cloud-sql-mssql-admin.yaml b/internal/prebuiltconfigs/tools/cloud-sql-mssql-admin.yaml index 7e488fc99a..7830ea45cb 100644 --- a/internal/prebuiltconfigs/tools/cloud-sql-mssql-admin.yaml +++ b/internal/prebuiltconfigs/tools/cloud-sql-mssql-admin.yaml @@ -39,6 +39,10 @@ tools: wait_for_operation: kind: cloud-sql-wait-for-operation source: cloud-sql-admin-source + multiplier: 4 + clone_instance: + kind: cloud-sql-clone-instance + source: cloud-sql-admin-source toolsets: cloud_sql_mssql_admin_tools: @@ -49,3 +53,4 @@ toolsets: - list_databases - create_user - wait_for_operation + - clone_instance diff --git a/internal/prebuiltconfigs/tools/cloud-sql-mysql-admin.yaml b/internal/prebuiltconfigs/tools/cloud-sql-mysql-admin.yaml index 3262638eef..145f1cbc33 100644 --- a/internal/prebuiltconfigs/tools/cloud-sql-mysql-admin.yaml +++ b/internal/prebuiltconfigs/tools/cloud-sql-mysql-admin.yaml @@ -39,6 +39,10 @@ tools: wait_for_operation: kind: cloud-sql-wait-for-operation source: cloud-sql-admin-source + multiplier: 4 + clone_instance: + kind: cloud-sql-clone-instance + source: cloud-sql-admin-source toolsets: 
cloud_sql_mysql_admin_tools: @@ -49,3 +53,4 @@ toolsets: - list_databases - create_user - wait_for_operation + - clone_instance diff --git a/internal/prebuiltconfigs/tools/cloud-sql-mysql.yaml b/internal/prebuiltconfigs/tools/cloud-sql-mysql.yaml index 0a6008eadc..63a73730b7 100644 --- a/internal/prebuiltconfigs/tools/cloud-sql-mysql.yaml +++ b/internal/prebuiltconfigs/tools/cloud-sql-mysql.yaml @@ -32,16 +32,9 @@ tools: source: cloud-sql-mysql-source description: Lists top N (default 10) ongoing queries from processlist and innodb_trx, ordered by execution time in descending order. Returns detailed information of those queries in json format, including process id, query, transaction duration, transaction wait duration, process time, transaction state, process state, username with host, transaction rows locked, transaction rows modified, and db schema. get_query_plan: - kind: mysql-sql + kind: mysql-get-query-plan source: cloud-sql-mysql-source description: "Provide information about how MySQL executes a SQL statement. Common use cases include: 1) analyze query plan to improve its performance, and 2) determine effectiveness of existing indexes and evalueate new ones." - statement: | - EXPLAIN FORMAT=JSON {{.sql_statement}}; - templateParameters: - - name: sql_statement - type: string - description: "the SQL statement to explain" - required: true list_tables: kind: mysql-list-tables source: cloud-sql-mysql-source diff --git a/internal/prebuiltconfigs/tools/cloud-sql-postgres-admin.yaml b/internal/prebuiltconfigs/tools/cloud-sql-postgres-admin.yaml index cb4c503877..dffac3dc1b 100644 --- a/internal/prebuiltconfigs/tools/cloud-sql-postgres-admin.yaml +++ b/internal/prebuiltconfigs/tools/cloud-sql-postgres-admin.yaml @@ -39,6 +39,10 @@ tools: wait_for_operation: kind: cloud-sql-wait-for-operation source: cloud-sql-admin-source + multiplier: 4 + clone_instance: + kind: cloud-sql-clone-instance + source: cloud-sql-admin-source postgres_upgrade_precheck: kind: postgres-upgrade-precheck source: cloud-sql-admin-source @@ -53,3 +57,4 @@ toolsets: - create_user - wait_for_operation - postgres_upgrade_precheck + - clone_instance diff --git a/internal/prebuiltconfigs/tools/cloud-sql-postgres.yaml b/internal/prebuiltconfigs/tools/cloud-sql-postgres.yaml index 01e057166f..bd4cb759a4 100644 --- a/internal/prebuiltconfigs/tools/cloud-sql-postgres.yaml +++ b/internal/prebuiltconfigs/tools/cloud-sql-postgres.yaml @@ -177,7 +177,7 @@ tools: list_schemas: kind: postgres-list-schemas source: cloudsql-pg-source - + database_overview: kind: postgres-database-overview source: cloudsql-pg-source @@ -201,6 +201,30 @@ tools: get_column_cardinality: kind: postgres-get-column-cardinality source: cloudsql-pg-source + + list_table_stats: + kind: postgres-list-table-stats + source: cloudsql-pg-source + + list_publication_tables: + kind: postgres-list-publication-tables + source: cloudsql-pg-source + + list_tablespaces: + kind: postgres-list-tablespaces + source: cloudsql-pg-source + + list_pg_settings: + kind: postgres-list-pg-settings + source: cloudsql-pg-source + + list_database_stats: + kind: postgres-list-database-stats + source: cloudsql-pg-source + + list_roles: + kind: postgres-list-roles + source: cloudsql-pg-source toolsets: cloud_sql_postgres_database_tools: @@ -226,3 +250,9 @@ toolsets: - replication_stats - list_query_stats - get_column_cardinality + - list_publication_tables + - list_tablespaces + - list_pg_settings + - list_database_stats + - list_roles + - list_table_stats diff --git 
a/internal/prebuiltconfigs/tools/looker-conversational-analytics.yaml b/internal/prebuiltconfigs/tools/looker-conversational-analytics.yaml index ede07d3279..4a51cf38fd 100644 --- a/internal/prebuiltconfigs/tools/looker-conversational-analytics.yaml +++ b/internal/prebuiltconfigs/tools/looker-conversational-analytics.yaml @@ -28,33 +28,38 @@ tools: ask_data_insights: kind: looker-conversational-analytics source: looker-source - annotations: - readOnlyHint: true description: | - Use this tool to perform data analysis, get insights, - or answer complex questions about the contents of specific - Looker explores. + Use this tool to ask questions about your data using the Looker Conversational + Analytics API. You must provide a natural language query and a list of + 1 to 5 model and explore combinations (e.g. [{'model': 'the_model', 'explore': 'the_explore'}]). + Use the 'get_models' and 'get_explores' tools to discover available models and explores. get_models: kind: looker-get-models source: looker-source - annotations: - readOnlyHint: true description: | - The get_models tool retrieves the list of LookML models in the Looker system. + get_models Tool - It takes no parameters. + This tool retrieves a list of available LookML models in the Looker instance. + LookML models define the data structure and relationships that users can query. + The output includes details like the model's `name` and `label`, which are + essential for subsequent calls to tools like `get_explores` or `query`. + + This tool takes no parameters. get_explores: kind: looker-get-explores source: looker-source - annotations: - readOnlyHint: true description: | - The get_explores tool retrieves the list of explores defined in a LookML model - in the Looker system. + get_explores Tool - It takes one parameter, the model_name looked up from get_models. + This tool retrieves a list of explores defined within a specific LookML model. + Explores represent a curated view of your data, typically joining several + tables together to allow for focused analysis on a particular subject area. + The output provides details like the explore's `name` and `label`. + + Parameters: + - model_name (required): The name of the LookML model, obtained from `get_models`. toolsets: looker_conversational_analytics_tools: diff --git a/internal/prebuiltconfigs/tools/looker.yaml b/internal/prebuiltconfigs/tools/looker.yaml index b46fa26769..442cd11106 100644 --- a/internal/prebuiltconfigs/tools/looker.yaml +++ b/internal/prebuiltconfigs/tools/looker.yaml @@ -29,155 +29,152 @@ tools: get_models: kind: looker-get-models source: looker-source - annotations: - readOnlyHint: true description: | - The get_models tool retrieves the list of LookML models in the Looker system. + This tool retrieves a list of available LookML models in the Looker instance. + LookML models define the data structure and relationships that users can query. + The output includes details like the model's `name` and `label`, which are + essential for subsequent calls to tools like `get_explores` or `query`. - It takes no parameters. + This tool takes no parameters. get_explores: kind: looker-get-explores source: looker-source - annotations: - readOnlyHint: true description: | - The get_explores tool retrieves the list of explores defined in a LookML model - in the Looker system. + This tool retrieves a list of explores defined within a specific LookML model. 
+ Explores represent a curated view of your data, typically joining several + tables together to allow for focused analysis on a particular subject area. + The output provides details like the explore's `name` and `label`. - It takes one parameter, the model_name looked up from get_models. + Parameters: + - model_name (required): The name of the LookML model, obtained from `get_models`. get_dimensions: kind: looker-get-dimensions source: looker-source - annotations: - readOnlyHint: true description: | - The get_dimensions tool retrieves the list of dimensions defined in - an explore. + This tool retrieves a list of dimensions defined within a specific Looker explore. + Dimensions are non-aggregatable attributes or characteristics of your data + (e.g., product name, order date, customer city) that can be used for grouping, + filtering, or segmenting query results. - It takes two parameters, the model_name looked up from get_models and the - explore_name looked up from get_explores. + Parameters: + - model_name (required): The name of the LookML model, obtained from `get_models`. + - explore_name (required): The name of the explore within the model, obtained from `get_explores`. - If this returns a suggestions field for a dimension, the contents of suggestions - can be used as filters for this field. If this returns a suggest_explore and - suggest_dimension, a query against that explore and dimension can be used to find - valid filters for this field. + Output Details: + - If a dimension includes a `suggestions` field, its contents are valid values + that can be used directly as filters for that dimension. + - If a `suggest_explore` and `suggest_dimension` are provided, you can query + that specified explore and dimension to retrieve a list of valid filter values. get_measures: kind: looker-get-measures source: looker-source - annotations: - readOnlyHint: true description: | - The get_measures tool retrieves the list of measures defined in - an explore. + This tool retrieves a list of measures defined within a specific Looker explore. + Measures are aggregatable metrics (e.g., total sales, average price, count of users) + that are used for calculations and quantitative analysis in your queries. - It takes two parameters, the model_name looked up from get_models and the - explore_name looked up from get_explores. + Parameters: + - model_name (required): The name of the LookML model, obtained from `get_models`. + - explore_name (required): The name of the explore within the model, obtained from `get_explores`. - If this returns a suggestions field for a measure, the contents of suggestions - can be used as filters for this field. If this returns a suggest_explore and - suggest_dimension, a query against that explore and dimension can be used to find - valid filters for this field. + Output Details: + - If a measure includes a `suggestions` field, its contents are valid values + that can be used directly as filters for that measure. + - If a `suggest_explore` and `suggest_dimension` are provided, you can query + that specified explore and dimension to retrieve a list of valid filter values. get_filters: kind: looker-get-filters source: looker-source - annotations: - readOnlyHint: true description: | - The get_filters tool retrieves the list of filters defined in - an explore. + This tool retrieves a list of "filter-only fields" defined within a specific + Looker explore. 
These are special fields defined in LookML specifically to + create user-facing filter controls that do not directly affect the `GROUP BY` + clause of the SQL query. They are often used in conjunction with liquid templating + to create dynamic queries. - It takes two parameters, the model_name looked up from get_models and the - explore_name looked up from get_explores. + Note: Regular dimensions and measures can also be used as filters in a query. + This tool *only* returns fields explicitly defined as `filter:` in LookML. + + Parameters: + - model_name (required): The name of the LookML model, obtained from `get_models`. + - explore_name (required): The name of the explore within the model, obtained from `get_explores`. get_parameters: kind: looker-get-parameters source: looker-source - annotations: - readOnlyHint: true description: | - The get_parameters tool retrieves the list of parameters defined in - an explore. + This tool retrieves a list of parameters defined within a specific Looker explore. + LookML parameters are dynamic input fields that allow users to influence query + behavior without directly modifying the underlying LookML. They are often used + with `liquid` templating to create flexible dashboards and reports, enabling + users to choose dimensions, measures, or other query components at runtime. - It takes two parameters, the model_name looked up from get_models and the - explore_name looked up from get_explores. + Parameters: + - model_name (required): The name of the LookML model, obtained from `get_models`. + - explore_name (required): The name of the explore within the model, obtained from `get_explores`. query: kind: looker-query source: looker-source - annotations: - readOnlyHint: true description: | - Query Tool + This tool runs a query against a LookML model and returns the results in JSON format. - This tool is used to run a query against the LookML model. The - model, explore, and fields list must be specified. Pivots, - filters and sorts are optional. + Required Parameters: + - model_name: The name of the LookML model (from `get_models`). + - explore_name: The name of the explore (from `get_explores`). + - fields: A list of field names (dimensions, measures, filters, or parameters) to include in the query. - The model can be found from the get_models tool. The explore - can be found from the get_explores tool passing in the model. - The fields can be found from the get_dimensions, get_measures, - get_filters, and get_parameters tools, passing in the model - and the explore. + Optional Parameters: + - pivots: A list of fields to pivot the results by. These fields must also be included in the `fields` list. + - filters: A map of filter expressions, e.g., `{"view.field": "value", "view.date": "7 days"}`. + - Do not quote field names. + - Use `not null` instead of `-NULL`. + - If a value contains a comma, enclose it in single quotes (e.g., "'New York, NY'"). + - sorts: A list of fields to sort by, optionally including direction (e.g., `["view.field desc"]`). + - limit: Row limit (default 500). Use "-1" for unlimited. + - query_timezone: specific timezone for the query (e.g. `America/Los_Angeles`). - Provide a model_id and explore_name, then a list - of fields. Optionally a list of pivots can be provided. - The pivots must also be included in the fields list. - - Filters are provided as a map of {"field.id": "condition", - "field.id2": "condition2", ...}. Do not put the field.id in - quotes. 
Filter expressions can be found at - https://cloud.google.com/looker/docs/filter-expressions. There - is one mistake in that, however, Use `not null` instead of `-NULL`. - If the condition is a string that contains a comma, use a second - set of quotes. For example, {"user.city": "'New York, NY'"}. - - Sorts can be specified like [ "field.id desc 0" ]. - - An optional row limit can be added. If not provided the limit - will default to 500. "-1" can be specified for unlimited. - - An optional query timezone can be added. The query_timezone to - will default to that of the workstation where this MCP server - is running, or Etc/UTC if that can't be determined. Not all - models support custom timezones. - - The result of the query tool is JSON + Note: Use `get_dimensions`, `get_measures`, `get_filters`, and `get_parameters` to find valid fields. query_sql: kind: looker-query-sql source: looker-source - annotations: - readOnlyHint: true description: | - Query SQL Tool + This tool generates the underlying SQL query that Looker would execute + against the database for a given set of parameters. It is useful for + understanding how Looker translates a request into SQL. - This tool is used to generate the SQL that Looker would - run against the underlying database. The parameters are - the same as the query tool. + Parameters: + All parameters for this tool are identical to those of the `query` tool. + This includes `model_name`, `explore_name`, `fields` (required), + and optional parameters like `pivots`, `filters`, `sorts`, `limit`, and `query_timezone`. - The result of the query sql tool is SQL text. + Output: + The result of this tool is the raw SQL text. query_url: kind: looker-query-url source: looker-source - annotations: - readOnlyHint: true description: | - Query URL Tool + This tool generates a shareable URL for a Looker query, allowing users to + explore the query further within the Looker UI. It returns the generated URL, + along with the `query_id` and `slug`. - This tool is used to generate the URL of a query in Looker. - The user can then explore the query further inside Looker. - The tool also returns the query_id and slug. The parameters - are the same as the query tool with an additional vis_config - parameter. + Parameters: + All query parameters (e.g., `model_name`, `explore_name`, `fields`, `pivots`, + `filters`, `sorts`, `limit`, `query_timezone`) are the same as the `query` tool. - The vis_config is optional. If provided, it will be used to - control the default visualization for the query. Here are - some notes on making visualizations. + Additionally, it accepts an optional `vis_config` parameter: + - vis_config (optional): A JSON object that controls the default visualization + settings for the generated query. + + vis_config Details: + The `vis_config` object supports a wide range of properties for various chart types. + Here are some notes on making visualizations. ### Cartesian Charts (Area, Bar, Column, Line, Scatter) @@ -616,333 +613,433 @@ tools: get_looks: kind: looker-get-looks source: looker-source - annotations: - readOnlyHint: true description: | - get_looks Tool + This tool searches for saved Looks (pre-defined queries and visualizations) + in a Looker instance. It returns a list of JSON objects, each representing a Look. - This tool is used to search for saved looks in a Looker instance. - String search params use case-insensitive matching. String search - params can contain % and '_' as SQL LIKE pattern match wildcard - expressions. 
example="dan%" will match "danger" and "Danzig" but - not "David" example="D_m%" will match "Damage" and "dump". + Search Parameters: + - title (optional): Filter by Look title (supports wildcards). + - folder_id (optional): Filter by the ID of the folder where the Look is saved. + - user_id (optional): Filter by the ID of the user who created the Look. + - description (optional): Filter by description content (supports wildcards). + - id (optional): Filter by specific Look ID. + - limit (optional): Maximum number of results to return. Defaults to a system limit. + - offset (optional): Starting point for pagination. - Most search params can accept "IS NULL" and "NOT NULL" as special - expressions to match or exclude (respectively) rows where the - column is null. - - The limit and offset are used to paginate the results. - - The result of the get_looks tool is a list of json objects. + String Search Behavior: + - Case-insensitive matching. + - Supports SQL LIKE pattern match wildcards: + - `%`: Matches any sequence of zero or more characters. (e.g., `"dan%"` matches "danger", "Danzig") + - `_`: Matches any single character. (e.g., `"D_m%"` matches "Damage", "dump") + - Special expressions for null checks: + - `"IS NULL"`: Matches Looks where the field is null. + - `"NOT NULL"`: Excludes Looks where the field is null. run_look: kind: looker-run-look source: looker-source - annotations: - readOnlyHint: true description: | - run_look Tool + This tool executes the query associated with a saved Look and + returns the resulting data in a JSON structure. - This tool runs the query associated with a look and returns - the data in a JSON structure. It accepts the look_id as the - parameter. + Parameters: + - look_id (required): The unique identifier of the Look to run, + typically obtained from the `get_looks` tool. + + Output: + The query results are returned as a JSON object. make_look: kind: looker-make-look source: looker-source - annotations: - readOnlyHint: false description: | - make_look Tool + This tool creates a new Look (saved query with visualization) in Looker. + The Look will be saved in the user's personal folder, and its name must be unique. - This tool creates a new look in Looker, using the query - parameters and the vis_config specified. + Required Parameters: + - title: A unique title for the new Look. + - description: A brief description of the Look's purpose. + - model_name: The name of the LookML model (from `get_models`). + - explore_name: The name of the explore (from `get_explores`). + - fields: A list of field names (dimensions, measures, filters, or parameters) to include in the query. - Most of the parameters are the same as the query_url - tool. In addition, there is a title and a description - that must be provided. + Optional Parameters: + - pivots, filters, sorts, limit, query_timezone: These parameters are identical + to those described for the `query` tool. + - vis_config: A JSON object defining the visualization settings for the Look. + The structure and options are the same as for the `query_url` tool's `vis_config`. - The newly created look will be created in the user's - personal folder in looker. The look name must be unique. - - The result is a json document with a link to the newly - created look. + Output: + A JSON object containing a link (`url`) to the newly created Look, along with its `id` and `slug`. 
get_dashboards: kind: looker-get-dashboards source: looker-source - annotations: - readOnlyHint: true description: | - get_dashboards Tool + This tool searches for saved dashboards in a Looker instance. It returns a list of JSON objects, each representing a dashboard. - This tool is used to search for saved dashboards in a Looker instance. - String search params use case-insensitive matching. String search - params can contain % and '_' as SQL LIKE pattern match wildcard - expressions. example="dan%" will match "danger" and "Danzig" but - not "David" example="D_m%" will match "Damage" and "dump". - Most search params can accept "IS NULL" and "NOT NULL" as special - expressions to match or exclude (respectively) rows where the - column is null. + Search Parameters: + - title (optional): Filter by dashboard title (supports wildcards). + - folder_id (optional): Filter by the ID of the folder where the dashboard is saved. + - user_id (optional): Filter by the ID of the user who created the dashboard. + - description (optional): Filter by description content (supports wildcards). + - id (optional): Filter by specific dashboard ID. + - limit (optional): Maximum number of results to return. Defaults to a system limit. + - offset (optional): Starting point for pagination. - The limit and offset are used to paginate the results. - - The result of the get_dashboards tool is a list of json objects. + String Search Behavior: + - Case-insensitive matching. + - Supports SQL LIKE pattern match wildcards: + - `%`: Matches any sequence of zero or more characters. (e.g., `"finan%"` matches "financial", "finance") + - `_`: Matches any single character. (e.g., `"s_les"` matches "sales") + - Special expressions for null checks: + - `"IS NULL"`: Matches dashboards where the field is null. + - `"NOT NULL"`: Excludes dashboards where the field is null. run_dashboard: kind: looker-run-dashboard source: looker-source - annotations: - readOnlyHint: true description: | - run_dashboard Tool + This tool executes the queries associated with each tile in a specified dashboard + and returns the aggregated data in a JSON structure. - This tools runs the query associated with each tile in a dashboard - and returns the data in a JSON structure. It accepts the dashboard_id - as the parameter. + Parameters: + - dashboard_id (required): The unique identifier of the dashboard to run, + typically obtained from the `get_dashboards` tool. + + Output: + The data from all dashboard tiles is returned as a JSON object. make_dashboard: kind: looker-make-dashboard source: looker-source - annotations: - readOnlyHint: false description: | - make_dashboard Tool + This tool creates a new, empty dashboard in Looker. Dashboards are stored + in the user's personal folder, and the dashboard name must be unique. + After creation, use `add_dashboard_filter` to add filters and + `add_dashboard_element` to add content tiles. - This tool creates a new dashboard in Looker. The dashboard is - initially empty and the add_dashboard_element tool is used to - add content to the dashboard. + Required Parameters: + - title (required): A unique title for the new dashboard. + - description (required): A brief description of the dashboard's purpose. - The newly created dashboard will be created in the user's - personal folder in looker. The dashboard name must be unique. - - The result is a json document with a link to the newly - created dashboard and the id of the dashboard. Use the id - when calling add_dashboard_element. 
+ Output: + A JSON object containing a link (`url`) to the newly created dashboard and + its unique `id`. This `dashboard_id` is crucial for subsequent calls to + `add_dashboard_filter` and `add_dashboard_element`. add_dashboard_element: kind: looker-add-dashboard-element source: looker-source - annotations: - readOnlyHint: false description: | - add_dashboard_element Tool + This tool creates a new tile (element) within an existing Looker dashboard. + Tiles are added in the order this tool is called for a given `dashboard_id`. - This tool creates a new tile in a Looker dashboard using - the query parameters and the vis_config specified. + CRITICAL ORDER OF OPERATIONS: + 1. Create the dashboard using `make_dashboard`. + 2. Add any dashboard-level filters using `add_dashboard_filter`. + 3. Then, add elements (tiles) using this tool. - Most of the parameters are the same as the query_url - tool. In addition, there is a title that may be provided. - The dashboard_id must be specified. That is obtained - from calling make_dashboard. + Required Parameters: + - dashboard_id: The ID of the target dashboard, obtained from `make_dashboard`. + - model_name, explore_name, fields: These query parameters are inherited + from the `query` tool and are required to define the data for the tile. - This tool can be called many times for one dashboard_id - and the resulting tiles will be added in order. + Optional Parameters: + - title: An optional title for the dashboard tile. + - pivots, filters, sorts, limit, query_timezone: These query parameters are + inherited from the `query` tool and can be used to customize the tile's query. + - vis_config: A JSON object defining the visualization settings for this tile. + The structure and options are the same as for the `query_url` tool's `vis_config`. + Connecting to Dashboard Filters: + A dashboard element can be connected to one or more dashboard filters (created with + `add_dashboard_filter`). To do this, specify the `name` of the dashboard filter + and the `field` from the element's query that the filter should apply to. + The format for specifying the field is `view_name.field_name`. + + add_dashboard_filter: + kind: looker-add-dashboard-filter + source: looker-source + description: | + This tool adds a filter to a Looker dashboard. + + CRITICAL ORDER OF OPERATIONS: + 1. Create a dashboard using `make_dashboard`. + 2. Add all desired filters using this tool (`add_dashboard_filter`). + 3. Finally, add dashboard elements (tiles) using `add_dashboard_element`. + + Parameters: + - dashboard_id (required): The ID from `make_dashboard`. + - name (required): A unique internal identifier for the filter. You will use this `name` later in `add_dashboard_element` to bind tiles to this filter. + - title (required): The label displayed to users in the UI. + - filter_type (required): One of `date_filter`, `number_filter`, `string_filter`, or `field_filter`. + - default_value (optional): The initial value for the filter. + + Field Filters (`filter_type: field_filter`): + If creating a field filter, you must also provide: + - model + - explore + - dimension + The filter will inherit suggestions and type information from this LookML field. + + generate_embed_url: + kind: looker-generate-embed-url + source: looker-source + description: | + This tool generates a signed, private embed URL for specific Looker content, + allowing users to access it directly. + + Parameters: + - type (required): The type of content to embed.
Common values include: + - `dashboards` + - `looks` + - `explore` + - id (required): The unique identifier for the content. + - For dashboards and looks, use the numeric ID (e.g., "123"). + - For explores, use the format "model_name/explore_name". + health_pulse: kind: looker-health-pulse source: looker-source - annotations: - readOnlyHint: true description: | - health-pulse Tool + This tool performs various health checks on a Looker instance. - This tool takes the pulse of a Looker instance by taking - one of the following actions: - 1. `check_db_connections`, - 2. `check_dashboard_performance`, - 3. `check_dashboard_errors`, - 4. `check_explore_performance`, - 5. `check_schedule_failures`, or - 6. `check_legacy_features` - - The `check_legacy_features` action is only available in Looker Core. If - it is called on a Looker Core instance, you will get a notice. That notice - should not be reported as an error. + Parameters: + - action (required): Specifies the type of health check to perform. + Choose one of the following: + - `check_db_connections`: Verifies database connectivity. + - `check_dashboard_performance`: Assesses dashboard loading performance. + - `check_dashboard_errors`: Identifies errors within dashboards. + - `check_explore_performance`: Evaluates explore query performance. + - `check_schedule_failures`: Reports on failed scheduled deliveries. + - `check_legacy_features`: Checks for the usage of legacy features. + + Note on `check_legacy_features`: + This action is exclusively available in Looker Core instances. If invoked + on a non-Looker Core instance, it will return a notice rather than an error. + This notice should be considered normal behavior and not an indication of an issue. health_analyze: kind: looker-health-analyze source: looker-source - annotations: - readOnlyHint: true description: | - health-analyze Tool + This tool calculates the usage statistics for Looker projects, models, and explores. - This tool calculates the usage of projects, models and explores. + Parameters: + - action (required): The type of resource to analyze. Can be `"projects"`, `"models"`, or `"explores"`. + - project (optional): The specific project ID to analyze. + - model (optional): The specific model name to analyze. Requires `project` if used without `explore`. + - explore (optional): The specific explore name to analyze. Requires `model` if used. + - timeframe (optional): The lookback period in days for usage data. Defaults to `90` days. + - min_queries (optional): The minimum number of queries for a resource to be considered active. Defaults to `1`. - It accepts 6 parameters: - 1. `action`: can be "projects", "models", or "explores" - 2. `project`: the project to analyze (optional) - 3. `model`: the model to analyze (optional) - 4. `explore`: the explore to analyze (optional) - 5. `timeframe`: the lookback period in days, default is 90 - 6. `min_queries`: the minimum number of queries to consider a resource as active, default is 1 + Output: + The result is a JSON object containing usage metrics for the specified resources. health_vacuum: kind: looker-health-vacuum source: looker-source - annotations: - readOnlyHint: true description: | - health-vacuum Tool + This tool identifies and suggests LookML models or explores that can be + safely removed due to inactivity or low usage. - This tool suggests models or explores that can removed - because they are unused. + Parameters: + - action (required): The type of resource to analyze for removal candidates. 
Can be `"models"` or `"explores"`. + - project (optional): The specific project ID to consider. + - model (optional): The specific model name to consider. Requires `project` if used without `explore`. + - explore (optional): The specific explore name to consider. Requires `model` if used. + - timeframe (optional): The lookback period in days to assess usage. Defaults to `90` days. + - min_queries (optional): The minimum number of queries for a resource to be considered active. Defaults to `1`. - It accepts 6 parameters: - 1. `action`: can be "models" or "explores" - 2. `project`: the project to vacuum (optional) - 3. `model`: the model to vacuum (optional) - 4. `explore`: the explore to vacuum (optional) - 5. `timeframe`: the lookback period in days, default is 90 - 6. `min_queries`: the minimum number of queries to consider a resource as active, default is 1 - - The result is a list of objects that are candidates for deletion. + Output: + A JSON array of objects, each representing a model or explore that is a candidate for deletion due to low usage. dev_mode: kind: looker-dev-mode source: looker-source - annotations: - readOnlyHint: true description: | - dev_mode Tool + This tool allows toggling the Looker IDE session between Development Mode and Production Mode. + Development Mode enables making and testing changes to LookML projects. - Passing true to this tool switches the session to dev mode. Passing false to this tool switches the - session to production mode. + Parameters: + - enable (required): A boolean value. + - `true`: Switches the current session to Development Mode. + - `false`: Switches the current session to Production Mode. get_projects: kind: looker-get-projects source: looker-source - annotations: - readOnlyHint: true description: | - get_projects Tool + This tool retrieves a list of all LookML projects available on the Looker instance. + It is useful for identifying projects before performing actions like retrieving + project files or making modifications. - This tool returns the project_id and project_name for - all the LookML projects on the looker instance. + Parameters: + This tool takes no parameters. + + Output: + A JSON array of objects, each containing the `project_id` and `project_name` + for a LookML project. get_project_files: kind: looker-get-project-files source: looker-source - annotations: - readOnlyHint: true description: | - get_project_files Tool + This tool retrieves a list of all LookML files within a specified project, + providing details about each file. - Given a project_id this tool returns the details about - the LookML files that make up that project. + Parameters: + - project_id (required): The unique ID of the LookML project, obtained from `get_projects`. + + Output: + A JSON array of objects, each representing a LookML file and containing + details such as `path`, `id`, `type`, and `git_status`. get_project_file: kind: looker-get-project-file source: looker-source - annotations: - readOnlyHint: true description: | - get_project_file Tool + This tool retrieves the raw content of a specific LookML file from within a project. - Given a project_id and a file path within the project, this tool returns - the contents of the LookML file. + Parameters: + - project_id (required): The unique ID of the LookML project, obtained from `get_projects`. + - file_path (required): The path to the LookML file within the project, + typically obtained from `get_project_files`. + + Output: + The raw text content of the specified LookML file. 
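Taken together, the three descriptions above form a discovery chain: `get_projects` lists project IDs, `get_project_files` lists the LookML files in one project, and `get_project_file` returns a single file's contents. A minimal sketch of that chain, assuming a locally running Toolbox server exposing the `/api/tool/<name>/invoke` route and using placeholder project and file names:

```bash
# Hypothetical walk through the LookML project tools; IDs and paths are placeholders.
BASE=http://127.0.0.1:5000/api/tool

# 1. Discover projects (takes no parameters).
curl -X POST "$BASE/get_projects/invoke" \
  --header "Content-Type: application/json" --data '{}'

# 2. List the files in one project.
curl -X POST "$BASE/get_project_files/invoke" \
  --header "Content-Type: application/json" \
  --data '{"project_id": "my_project"}'

# 3. Fetch the content of a single LookML file.
curl -X POST "$BASE/get_project_file/invoke" \
  --header "Content-Type: application/json" \
  --data '{"project_id": "my_project", "file_path": "views/orders.view.lkml"}'
```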
create_project_file: kind: looker-create-project-file source: looker-source - annotations: - readOnlyHint: false description: | - create_project_file Tool + This tool creates a new LookML file within a specified project, populating + it with the provided content. - Given a project_id and a file path within the project, as well as the content - of a LookML file, this tool will create a new file within the project. + Prerequisite: The Looker session must be in Development Mode. Use `dev_mode: true` first. - This tool must be called after the dev_mode tool has changed the session to - dev mode. + Parameters: + - project_id (required): The unique ID of the LookML project. + - file_path (required): The desired path and filename for the new file within the project. + - content (required): The full LookML content to write into the new file. + + Output: + A confirmation message upon successful file creation. update_project_file: kind: looker-update-project-file source: looker-source - annotations: - destructiveHint: true - readOnlyHint: false description: | - update_project_file Tool + This tool modifies the content of an existing LookML file within a specified project. - Given a project_id and a file path within the project, as well as the content - of a LookML file, this tool will modify the file within the project. + Prerequisite: The Looker session must be in Development Mode. Use `dev_mode: true` first. - This tool must be called after the dev_mode tool has changed the session to - dev mode. + Parameters: + - project_id (required): The unique ID of the LookML project. + - file_path (required): The exact path to the LookML file to modify within the project. + - content (required): The new, complete LookML content to overwrite the existing file. + + Output: + A confirmation message upon successful file modification. delete_project_file: kind: looker-delete-project-file source: looker-source - annotations: - destructiveHint: true - readOnlyHint: false description: | - delete_project_file Tool + This tool permanently deletes a specified LookML file from within a project. + Use with caution, as this action cannot be undone through the API. - Given a project_id and a file path within the project, this tool will delete - the file from the project. + Prerequisite: The Looker session must be in Development Mode. Use `dev_mode: true` first. - This tool must be called after the dev_mode tool has changed the session to - dev mode. + Parameters: + - project_id (required): The unique ID of the LookML project. + - file_path (required): The exact path to the LookML file to delete within the project. + + Output: + A confirmation message upon successful file deletion. get_connections: kind: looker-get-connections source: looker-source - annotations: - readOnlyHint: true description: | - get_connections Tool + This tool retrieves a list of all database connections configured in the Looker system. - This tool will list all the connections available in the Looker system, as - well as the dialect name, the default schema, the database if applicable, - and whether the connection supports multiple databases. + Parameters: + This tool takes no parameters. + + Output: + A JSON array of objects, each representing a database connection and including details such as: + - `name`: The connection's unique identifier. + - `dialect`: The database dialect (e.g., "mysql", "postgresql", "bigquery"). + - `default_schema`: The default schema for the connection. + - `database`: The associated database name (if applicable). 
+ - `supports_multiple_databases`: A boolean indicating if the connection can access multiple databases. get_connection_schemas: kind: looker-get-connection-schemas source: looker-source - annotations: - readOnlyHint: true description: | - get_connection_schemas Tool + This tool retrieves a list of database schemas available through a specified + Looker connection. - This tool will list the schemas available from a connection, filtered by - an optional database name. + Parameters: + - connection_name (required): The name of the database connection, obtained from `get_connections`. + - database (optional): An optional database name to filter the schemas. + Only applicable for connections that support multiple databases. + + Output: + A JSON array of strings, where each string is the name of an available schema. get_connection_databases: kind: looker-get-connection-databases source: looker-source - annotations: - readOnlyHint: true description: | - get_connection_databases Tool + This tool retrieves a list of databases available through a specified Looker connection. + This is only applicable for connections that support multiple databases. + Use `get_connections` to check if a connection supports multiple databases. - This tool will list the databases available from a connection if the connection - supports multiple databases. + Parameters: + - connection_name (required): The name of the database connection, obtained from `get_connections`. + + Output: + A JSON array of strings, where each string is the name of an available database. + If the connection does not support multiple databases, an empty list or an error will be returned. get_connection_tables: kind: looker-get-connection-tables source: looker-source - annotations: - readOnlyHint: true description: | - get_connection_tables Tool + This tool retrieves a list of tables available within a specified database schema + through a Looker connection. - This tool will list the tables available from a connection, filtered by the - schema name and optional database name. + Parameters: + - connection_name (required): The name of the database connection, obtained from `get_connections`. + - schema (required): The name of the schema to list tables from, obtained from `get_connection_schemas`. + - database (optional): The name of the database to filter by. Only applicable for connections + that support multiple databases (check with `get_connections`). + + Output: + A JSON array of strings, where each string is the name of an available table. get_connection_table_columns: kind: looker-get-connection-table-columns source: looker-source - annotations: - readOnlyHint: true description: | - get_connection_table_columns Tool + This tool retrieves a list of columns for one or more specified tables within a + given database schema and connection. - This tool will list the columns available from a connection, for all the tables - given in a comma separated list of table names, filtered by the - schema name and optional database name. + Parameters: + - connection_name (required): The name of the database connection, obtained from `get_connections`. + - schema (required): The name of the schema where the tables reside, obtained from `get_connection_schemas`. + - tables (required): A comma-separated string of table names for which to retrieve columns + (e.g., "users,orders,products"), obtained from `get_connection_tables`. + - database (optional): The name of the database to filter by. 
Only applicable for connections + that support multiple databases (check with `get_connections`). + + Output: + A JSON array of objects, where each object represents a column and contains details + such as `table_name`, `column_name`, `data_type`, and `is_nullable`. toolsets: @@ -963,6 +1060,8 @@ toolsets: - run_dashboard - make_dashboard - add_dashboard_element + - add_dashboard_filter + - generate_embed_url - health_pulse - health_analyze - health_vacuum diff --git a/internal/prebuiltconfigs/tools/mssql.yaml b/internal/prebuiltconfigs/tools/mssql.yaml index 350219af36..f604899596 100644 --- a/internal/prebuiltconfigs/tools/mssql.yaml +++ b/internal/prebuiltconfigs/tools/mssql.yaml @@ -15,8 +15,8 @@ sources: mssql-source: kind: mssql - host: ${MSSQL_HOST} - port: ${MSSQL_PORT} + host: ${MSSQL_HOST:localhost} + port: ${MSSQL_PORT:1433} database: ${MSSQL_DATABASE} user: ${MSSQL_USER} password: ${MSSQL_PASSWORD} diff --git a/internal/prebuiltconfigs/tools/mysql.yaml b/internal/prebuiltconfigs/tools/mysql.yaml index 9f85de3642..d3068550eb 100644 --- a/internal/prebuiltconfigs/tools/mysql.yaml +++ b/internal/prebuiltconfigs/tools/mysql.yaml @@ -36,16 +36,9 @@ tools: source: mysql-source description: Lists top N (default 10) ongoing queries from processlist and innodb_trx, ordered by execution time in descending order. Returns detailed information of those queries in json format, including process id, query, transaction duration, transaction wait duration, process time, transaction state, process state, username with host, transaction rows locked, transaction rows modified, and db schema. get_query_plan: - kind: mysql-sql + kind: mysql-get-query-plan source: mysql-source description: "Provide information about how MySQL executes a SQL statement. Common use cases include: 1) analyze query plan to improve its performance, and 2) determine effectiveness of existing indexes and evalueate new ones." 
- statement: | - EXPLAIN FORMAT=JSON {{.sql_statement}}; - templateParameters: - - name: sql_statement - type: string - description: "the SQL statement to explain" - required: true list_tables: kind: mysql-list-tables source: mysql-source diff --git a/internal/prebuiltconfigs/tools/postgres.yaml b/internal/prebuiltconfigs/tools/postgres.yaml index 38bb02bb8c..91360a5b7e 100644 --- a/internal/prebuiltconfigs/tools/postgres.yaml +++ b/internal/prebuiltconfigs/tools/postgres.yaml @@ -176,7 +176,7 @@ tools: list_schemas: kind: postgres-list-schemas source: postgresql-source - + database_overview: kind: postgres-database-overview source: postgresql-source @@ -201,6 +201,30 @@ tools: kind: postgres-get-column-cardinality source: postgresql-source + list_table_stats: + kind: postgres-list-table-stats + source: postgresql-source + + list_publication_tables: + kind: postgres-list-publication-tables + source: postgresql-source + + list_tablespaces: + kind: postgres-list-tablespaces + source: postgresql-source + + list_pg_settings: + kind: postgres-list-pg-settings + source: postgresql-source + + list_database_stats: + kind: postgres-list-database-stats + source: postgresql-source + + list_roles: + kind: postgres-list-roles + source: postgresql-source + toolsets: postgres_database_tools: - execute_sql @@ -225,4 +249,9 @@ toolsets: - replication_stats - list_query_stats - get_column_cardinality - + - list_publication_tables + - list_tablespaces + - list_pg_settings + - list_database_stats + - list_roles + - list_table_stats diff --git a/internal/prebuiltconfigs/tools/serverless-spark.yaml b/internal/prebuiltconfigs/tools/serverless-spark.yaml index 7d78b18a95..ed67bb3911 100644 --- a/internal/prebuiltconfigs/tools/serverless-spark.yaml +++ b/internal/prebuiltconfigs/tools/serverless-spark.yaml @@ -28,9 +28,17 @@ tools: cancel_batch: kind: serverless-spark-cancel-batch source: serverless-spark-source + create_pyspark_batch: + kind: serverless-spark-create-pyspark-batch + source: serverless-spark-source + create_spark_batch: + kind: serverless-spark-create-spark-batch + source: serverless-spark-source toolsets: serverless_spark_tools: - list_batches - get_batch - cancel_batch + - create_pyspark_batch + - create_spark_batch diff --git a/internal/prebuiltconfigs/tools/spanner.yaml b/internal/prebuiltconfigs/tools/spanner.yaml index db15412fd8..68839a09dd 100644 --- a/internal/prebuiltconfigs/tools/spanner.yaml +++ b/internal/prebuiltconfigs/tools/spanner.yaml @@ -35,10 +35,16 @@ tools: list_tables: kind: spanner-list-tables source: spanner-source - description: "Lists detailed schema information (object type, columns, constraints, indexes) as JSON for user-created tables (ordinary or partitioned). Filters by a comma-separated list of names. If names are omitted, lists all tables in user schemas." + description: "Lists detailed schema information (object type, columns, constraints, indexes) as JSON for user-created tables (ordinary or partitioned). Filters by a comma-separated list of names. If names are omitted, lists all tables in user schemas. The output can be 'simple' (table names only) or 'detailed' (full schema)." + + list_graphs: + kind: spanner-list-graphs + source: spanner-source + description: "Lists detailed graph schema information (node tables, edge tables, labels and property declarations) as JSON for user-created graphs. Filters by a comma-separated list of graph names. If names are omitted, lists all graphs. The output can be 'simple' (graph names only) or 'detailed' (full schema)." 
toolsets: spanner-database-tools: - execute_sql - execute_sql_dql - list_tables + - list_graphs diff --git a/internal/prompts/promptsets.go b/internal/prompts/promptsets.go index 14b14083c1..aee131ffe7 100644 --- a/internal/prompts/promptsets.go +++ b/internal/prompts/promptsets.go @@ -46,7 +46,7 @@ func (t PromptsetConfig) Initialize(serverVersion string, promptsMap map[string] var promptset Promptset promptset.Name = t.Name if !tools.IsValidName(promptset.Name) { - return promptset, fmt.Errorf("invalid promptset name: %s", t) + return promptset, fmt.Errorf("invalid promptset name: %s", promptset.Name) } promptset.Prompts = make([]*Prompt, 0, len(t.PromptNames)) promptset.McpManifest = make([]McpManifest, 0, len(t.PromptNames)) @@ -57,7 +57,7 @@ func (t PromptsetConfig) Initialize(serverVersion string, promptsMap map[string] for _, promptName := range t.PromptNames { prompt, ok := promptsMap[promptName] if !ok { - return promptset, fmt.Errorf("prompt does not exist: %s", t) + return promptset, fmt.Errorf("prompt does not exist: %s", promptName) } promptset.Prompts = append(promptset.Prompts, &prompt) promptset.Manifest.PromptsManifest[promptName] = prompt.Manifest() diff --git a/internal/server/api.go b/internal/server/api.go index 22068e3015..c03a214168 100644 --- a/internal/server/api.go +++ b/internal/server/api.go @@ -172,7 +172,14 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) { accessToken := tools.AccessToken(r.Header.Get("Authorization")) // Check if this specific tool requires the standard authorization header - if tool.RequiresClientAuthorization() { + clientAuth, err := tool.RequiresClientAuthorization(s.ResourceMgr) + if err != nil { + errMsg := fmt.Errorf("error during invocation: %w", err) + s.logger.DebugContext(ctx, errMsg.Error()) + _ = render.Render(w, r, newErrResponse(errMsg, http.StatusNotFound)) + return + } + if clientAuth { if accessToken == "" { err = fmt.Errorf("tool requires client authorization but access token is missing from the request header") s.logger.DebugContext(ctx, err.Error()) @@ -239,7 +246,7 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) { } s.logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params)) - res, err := tool.Invoke(ctx, params, accessToken) + res, err := tool.Invoke(ctx, s.ResourceMgr, params, accessToken) // Determine what error to return to the users. if err != nil { @@ -255,7 +262,7 @@ func toolInvokeHandler(s *Server, w http.ResponseWriter, r *http.Request) { } if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden { - if tool.RequiresClientAuthorization() { + if clientAuth { // Propagate the original 401/403 error. s.logger.DebugContext(ctx, fmt.Sprintf("error invoking tool. 
Client credentials lack authorization to the source: %v", err)) _ = render.Render(w, r, newErrResponse(err, statusCode)) diff --git a/internal/server/common_test.go b/internal/server/common_test.go index 2722da61fb..3953e1c7bc 100644 --- a/internal/server/common_test.go +++ b/internal/server/common_test.go @@ -26,6 +26,7 @@ import ( "github.com/go-chi/chi/v5" "github.com/googleapis/genai-toolbox/internal/log" "github.com/googleapis/genai-toolbox/internal/prompts" + "github.com/googleapis/genai-toolbox/internal/server/resources" "github.com/googleapis/genai-toolbox/internal/telemetry" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -49,7 +50,7 @@ type MockTool struct { requiresClientAuthrorization bool } -func (t MockTool) Invoke(context.Context, parameters.ParamValues, tools.AccessToken) (any, error) { +func (t MockTool) Invoke(context.Context, tools.SourceProvider, parameters.ParamValues, tools.AccessToken) (any, error) { mock := []any{t.Name} return mock, nil } @@ -76,9 +77,9 @@ func (t MockTool) Authorized(verifiedAuthServices []string) bool { return !t.unauthorized } -func (t MockTool) RequiresClientAuthorization() bool { +func (t MockTool) RequiresClientAuthorization(tools.SourceProvider) (bool, error) { // defaulted to false - return t.requiresClientAuthrorization + return t.requiresClientAuthrorization, nil } func (t MockTool) McpManifest() tools.McpManifest { @@ -118,8 +119,8 @@ func (t MockTool) McpManifest() tools.McpManifest { return mcpManifest } -func (t MockTool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t MockTool) GetAuthTokenHeaderName(tools.SourceProvider) (string, error) { + return "Authorization", nil } // MockPrompt is used to mock prompts in tests @@ -275,7 +276,7 @@ func setUpServer(t *testing.T, router string, tools map[string]tools.Tool, tools sseManager := newSseManager(ctx) - resourceManager := NewResourceManager(nil, nil, tools, toolsets, prompts, promptsets) + resourceManager := resources.NewResourceManager(nil, nil, tools, toolsets, prompts, promptsets) server := Server{ version: fakeVersionString, diff --git a/internal/server/config.go b/internal/server/config.go index e16d2e8327..fa0f1952a7 100644 --- a/internal/server/config.go +++ b/internal/server/config.go @@ -62,6 +62,8 @@ type ServerConfig struct { DisableReload bool // UI indicates if Toolbox UI endpoints (/ui) are available UI bool + // Specifies a list of origins permitted to access this server. 
+ AllowedOrigins []string } type logFormat string diff --git a/internal/server/mcp.go b/internal/server/mcp.go index e6a58e9b88..aecd2454f2 100644 --- a/internal/server/mcp.go +++ b/internal/server/mcp.go @@ -205,10 +205,13 @@ func (s *stdioSession) readLine(ctx context.Context) (string, error) { } // write writes to stdout with response to client -func (s *stdioSession) write(ctx context.Context, response any) error { - res, _ := json.Marshal(response) +func (s *stdioSession) write(_ context.Context, response any) error { + res, err := json.Marshal(response) + if err != nil { + return fmt.Errorf("failed to marshal response to JSON: %w", err) + } - _, err := fmt.Fprintf(s.writer, "%s\n", res) + _, err = fmt.Fprintf(s.writer, "%s\n", res) return err } @@ -519,7 +522,7 @@ func processMcpMessage(ctx context.Context, body []byte, s *Server, protocolVers err = fmt.Errorf("promptset does not exist") return "", jsonrpc.NewError(baseMessage.Id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err } - res, err := mcp.ProcessMethod(ctx, protocolVersion, baseMessage.Id, baseMessage.Method, toolset, s.ResourceMgr.GetToolsMap(), promptset, s.ResourceMgr.GetPromptsMap(), s.ResourceMgr.GetAuthServiceMap(), body, header) + res, err := mcp.ProcessMethod(ctx, protocolVersion, baseMessage.Id, baseMessage.Method, toolset, promptset, s.ResourceMgr, body, header) return "", res, err } } diff --git a/internal/server/mcp/mcp.go b/internal/server/mcp/mcp.go index 8dac5147d1..74ff1bee59 100644 --- a/internal/server/mcp/mcp.go +++ b/internal/server/mcp/mcp.go @@ -21,13 +21,13 @@ import ( "net/http" "slices" - "github.com/googleapis/genai-toolbox/internal/auth" "github.com/googleapis/genai-toolbox/internal/prompts" "github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc" mcputil "github.com/googleapis/genai-toolbox/internal/server/mcp/util" v20241105 "github.com/googleapis/genai-toolbox/internal/server/mcp/v20241105" v20250326 "github.com/googleapis/genai-toolbox/internal/server/mcp/v20250326" v20250618 "github.com/googleapis/genai-toolbox/internal/server/mcp/v20250618" + "github.com/googleapis/genai-toolbox/internal/server/resources" "github.com/googleapis/genai-toolbox/internal/tools" ) @@ -100,14 +100,14 @@ func NotificationHandler(ctx context.Context, body []byte) error { // ProcessMethod returns a response for the request. // This is the Operation phase of the lifecycle for MCP client-server connections. 
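The signature changes below thread a single `*resources.ResourceManager` through `ProcessMethod` and into the tool handlers, and tools now take a `tools.SourceProvider` in `Invoke`, `RequiresClientAuthorization`, and `GetAuthTokenHeaderName`, so source lookups happen at call time and can fail cleanly. A sketch of the tool-side counterpart, assuming the provider exposes a `GetSource` accessor like the ResourceManager added in this change; the concrete `SourceProvider` method set is not shown in this diff, and `ExampleTool` is hypothetical:

```go
package exampletool

import (
	"context"
	"fmt"

	"github.com/googleapis/genai-toolbox/internal/tools"
	"github.com/googleapis/genai-toolbox/internal/util/parameters"
)

// ExampleTool is a hypothetical tool used only to illustrate the new
// signatures; it is not part of this change.
type ExampleTool struct {
	SourceName string
}

// Invoke now receives a tools.SourceProvider so the backing source is
// resolved when the tool runs rather than held as a direct reference.
func (t ExampleTool) Invoke(ctx context.Context, srcs tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, error) {
	src, ok := srcs.GetSource(t.SourceName) // assumed accessor, mirroring ResourceManager.GetSource
	if !ok {
		// Returning an error here is what lets the handlers map a bad source
		// reference to a 404 or INTERNAL_ERROR response instead of panicking.
		return nil, fmt.Errorf("source %q not found", t.SourceName)
	}
	_ = src // execute against the source here
	return []any{t.SourceName}, nil
}

// RequiresClientAuthorization can now fail, which is why the handlers check
// the returned error before looking at the Authorization header.
func (t ExampleTool) RequiresClientAuthorization(srcs tools.SourceProvider) (bool, error) {
	if _, ok := srcs.GetSource(t.SourceName); !ok {
		return false, fmt.Errorf("source %q not found", t.SourceName)
	}
	return false, nil
}
```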
-func ProcessMethod(ctx context.Context, mcpVersion string, id jsonrpc.RequestId, method string, toolset tools.Toolset, tools map[string]tools.Tool, promptset prompts.Promptset, prompts map[string]prompts.Prompt, authServices map[string]auth.AuthService, body []byte, header http.Header) (any, error) { +func ProcessMethod(ctx context.Context, mcpVersion string, id jsonrpc.RequestId, method string, toolset tools.Toolset, promptset prompts.Promptset, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) { switch mcpVersion { case v20250618.PROTOCOL_VERSION: - return v20250618.ProcessMethod(ctx, id, method, toolset, tools, promptset, prompts, authServices, body, header) + return v20250618.ProcessMethod(ctx, id, method, toolset, promptset, resourceMgr, body, header) case v20250326.PROTOCOL_VERSION: - return v20250326.ProcessMethod(ctx, id, method, toolset, tools, promptset, prompts, authServices, body, header) + return v20250326.ProcessMethod(ctx, id, method, toolset, promptset, resourceMgr, body, header) default: - return v20241105.ProcessMethod(ctx, id, method, toolset, tools, promptset, prompts, authServices, body, header) + return v20241105.ProcessMethod(ctx, id, method, toolset, promptset, resourceMgr, body, header) } } diff --git a/internal/server/mcp/v20241105/method.go b/internal/server/mcp/v20241105/method.go index 6934033460..0cbec0d1d2 100644 --- a/internal/server/mcp/v20241105/method.go +++ b/internal/server/mcp/v20241105/method.go @@ -23,26 +23,26 @@ import ( "net/http" "strings" - "github.com/googleapis/genai-toolbox/internal/auth" "github.com/googleapis/genai-toolbox/internal/prompts" "github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc" + "github.com/googleapis/genai-toolbox/internal/server/resources" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util" ) // ProcessMethod returns a response for the request. -func ProcessMethod(ctx context.Context, id jsonrpc.RequestId, method string, toolset tools.Toolset, tools map[string]tools.Tool, promptset prompts.Promptset, prompts map[string]prompts.Prompt, authServices map[string]auth.AuthService, body []byte, header http.Header) (any, error) { +func ProcessMethod(ctx context.Context, id jsonrpc.RequestId, method string, toolset tools.Toolset, promptset prompts.Promptset, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) { switch method { case PING: return pingHandler(id) case TOOLS_LIST: return toolsListHandler(id, toolset, body) case TOOLS_CALL: - return toolsCallHandler(ctx, id, tools, authServices, body, header) + return toolsCallHandler(ctx, id, resourceMgr, body, header) case PROMPTS_LIST: return promptsListHandler(ctx, id, promptset, body) case PROMPTS_GET: - return promptsGetHandler(ctx, id, prompts, body) + return promptsGetHandler(ctx, id, resourceMgr, body) default: err := fmt.Errorf("invalid method %s", method) return jsonrpc.NewError(id, jsonrpc.METHOD_NOT_FOUND, err.Error(), nil), err @@ -83,7 +83,9 @@ func toolsListHandler(id jsonrpc.RequestId, toolset tools.Toolset, body []byte) } // toolsCallHandler generate a response for tools call. 
-func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, toolsMap map[string]tools.Tool, authServices map[string]auth.AuthService, body []byte, header http.Header) (any, error) { +func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) { + authServices := resourceMgr.GetAuthServiceMap() + // retrieve logger from context logger, err := util.LoggerFromContext(ctx) if err != nil { @@ -99,17 +101,27 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, toolsMap map[st toolName := req.Params.Name toolArgument := req.Params.Arguments logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName)) - tool, ok := toolsMap[toolName] + tool, ok := resourceMgr.GetTool(toolName) if !ok { err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName) return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err } // Get access token - accessToken := tools.AccessToken(header.Get(tool.GetAuthTokenHeaderName())) + authTokenHeadername, err := tool.GetAuthTokenHeaderName(resourceMgr) + if err != nil { + errMsg := fmt.Errorf("error during invocation: %w", err) + return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg + } + accessToken := tools.AccessToken(header.Get(authTokenHeadername)) // Check if this specific tool requires the standard authorization header - if tool.RequiresClientAuthorization() { + clientAuth, err := tool.RequiresClientAuthorization(resourceMgr) + if err != nil { + errMsg := fmt.Errorf("error during invocation: %w", err) + return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg + } + if clientAuth { if accessToken == "" { return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized } @@ -172,7 +184,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, toolsMap map[st logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params)) // run tool invocation and generate response. - results, err := tool.Invoke(ctx, params, accessToken) + results, err := tool.Invoke(ctx, resourceMgr, params, accessToken) if err != nil { errStr := err.Error() // Missing authService tokens. @@ -181,7 +193,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, toolsMap map[st } // Upstream auth error if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") { - if tool.RequiresClientAuthorization() { + if clientAuth { // Error with client credentials should pass down to the client return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err } @@ -252,7 +264,7 @@ func promptsListHandler(ctx context.Context, id jsonrpc.RequestId, promptset pro } // promptsGetHandler handles the "prompts/get" method. 
-func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, promptsMap map[string]prompts.Prompt, body []byte) (any, error) { +func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *resources.ResourceManager, body []byte) (any, error) { // retrieve logger from context logger, err := util.LoggerFromContext(ctx) if err != nil { @@ -268,7 +280,7 @@ func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, promptsMap map promptName := req.Params.Name logger.DebugContext(ctx, fmt.Sprintf("prompt name: %s", promptName)) - prompt, ok := promptsMap[promptName] + prompt, ok := resourceMgr.GetPrompt(promptName) if !ok { err := fmt.Errorf("prompt with name %q does not exist", promptName) return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err diff --git a/internal/server/mcp/v20250326/method.go b/internal/server/mcp/v20250326/method.go index 4214a7cb39..a51bb161eb 100644 --- a/internal/server/mcp/v20250326/method.go +++ b/internal/server/mcp/v20250326/method.go @@ -23,26 +23,26 @@ import ( "net/http" "strings" - "github.com/googleapis/genai-toolbox/internal/auth" "github.com/googleapis/genai-toolbox/internal/prompts" "github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc" + "github.com/googleapis/genai-toolbox/internal/server/resources" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util" ) // ProcessMethod returns a response for the request. -func ProcessMethod(ctx context.Context, id jsonrpc.RequestId, method string, toolset tools.Toolset, tools map[string]tools.Tool, promptset prompts.Promptset, prompts map[string]prompts.Prompt, authServices map[string]auth.AuthService, body []byte, header http.Header) (any, error) { +func ProcessMethod(ctx context.Context, id jsonrpc.RequestId, method string, toolset tools.Toolset, promptset prompts.Promptset, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) { switch method { case PING: return pingHandler(id) case TOOLS_LIST: return toolsListHandler(id, toolset, body) case TOOLS_CALL: - return toolsCallHandler(ctx, id, tools, authServices, body, header) + return toolsCallHandler(ctx, id, resourceMgr, body, header) case PROMPTS_LIST: return promptsListHandler(ctx, id, promptset, body) case PROMPTS_GET: - return promptsGetHandler(ctx, id, prompts, body) + return promptsGetHandler(ctx, id, resourceMgr, body) default: err := fmt.Errorf("invalid method %s", method) return jsonrpc.NewError(id, jsonrpc.METHOD_NOT_FOUND, err.Error(), nil), err @@ -83,7 +83,9 @@ func toolsListHandler(id jsonrpc.RequestId, toolset tools.Toolset, body []byte) } // toolsCallHandler generate a response for tools call. 
-func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, toolsMap map[string]tools.Tool, authServices map[string]auth.AuthService, body []byte, header http.Header) (any, error) { +func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) { + authServices := resourceMgr.GetAuthServiceMap() + // retrieve logger from context logger, err := util.LoggerFromContext(ctx) if err != nil { @@ -99,17 +101,27 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, toolsMap map[st toolName := req.Params.Name toolArgument := req.Params.Arguments logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName)) - tool, ok := toolsMap[toolName] + tool, ok := resourceMgr.GetTool(toolName) if !ok { err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName) return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err } // Get access token - accessToken := tools.AccessToken(header.Get(tool.GetAuthTokenHeaderName())) + authTokenHeadername, err := tool.GetAuthTokenHeaderName(resourceMgr) + if err != nil { + errMsg := fmt.Errorf("error during invocation: %w", err) + return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg + } + accessToken := tools.AccessToken(header.Get(authTokenHeadername)) // Check if this specific tool requires the standard authorization header - if tool.RequiresClientAuthorization() { + clientAuth, err := tool.RequiresClientAuthorization(resourceMgr) + if err != nil { + errMsg := fmt.Errorf("error during invocation: %w", err) + return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg + } + if clientAuth { if accessToken == "" { return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized } @@ -172,7 +184,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, toolsMap map[st logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params)) // run tool invocation and generate response. - results, err := tool.Invoke(ctx, params, accessToken) + results, err := tool.Invoke(ctx, resourceMgr, params, accessToken) if err != nil { errStr := err.Error() // Missing authService tokens. @@ -181,7 +193,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, toolsMap map[st } // Upstream auth error if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") { - if tool.RequiresClientAuthorization() { + if clientAuth { // Error with client credentials should pass down to the client return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err } @@ -251,7 +263,7 @@ func promptsListHandler(ctx context.Context, id jsonrpc.RequestId, promptset pro } // promptsGetHandler handles the "prompts/get" method. 
-func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, promptsMap map[string]prompts.Prompt, body []byte) (any, error) { +func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *resources.ResourceManager, body []byte) (any, error) { // retrieve logger from context logger, err := util.LoggerFromContext(ctx) if err != nil { @@ -267,7 +279,7 @@ func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, promptsMap map promptName := req.Params.Name logger.DebugContext(ctx, fmt.Sprintf("prompt name: %s", promptName)) - prompt, ok := promptsMap[promptName] + prompt, ok := resourceMgr.GetPrompt(promptName) if !ok { err := fmt.Errorf("prompt with name %q does not exist", promptName) return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err diff --git a/internal/server/mcp/v20250618/method.go b/internal/server/mcp/v20250618/method.go index 3ba18df908..ccfa5f102f 100644 --- a/internal/server/mcp/v20250618/method.go +++ b/internal/server/mcp/v20250618/method.go @@ -23,26 +23,26 @@ import ( "net/http" "strings" - "github.com/googleapis/genai-toolbox/internal/auth" "github.com/googleapis/genai-toolbox/internal/prompts" "github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc" + "github.com/googleapis/genai-toolbox/internal/server/resources" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util" ) // ProcessMethod returns a response for the request. -func ProcessMethod(ctx context.Context, id jsonrpc.RequestId, method string, toolset tools.Toolset, tools map[string]tools.Tool, promptset prompts.Promptset, prompts map[string]prompts.Prompt, authServices map[string]auth.AuthService, body []byte, header http.Header) (any, error) { +func ProcessMethod(ctx context.Context, id jsonrpc.RequestId, method string, toolset tools.Toolset, promptset prompts.Promptset, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) { switch method { case PING: return pingHandler(id) case TOOLS_LIST: return toolsListHandler(id, toolset, body) case TOOLS_CALL: - return toolsCallHandler(ctx, id, tools, authServices, body, header) + return toolsCallHandler(ctx, id, resourceMgr, body, header) case PROMPTS_LIST: return promptsListHandler(ctx, id, promptset, body) case PROMPTS_GET: - return promptsGetHandler(ctx, id, prompts, body) + return promptsGetHandler(ctx, id, resourceMgr, body) default: err := fmt.Errorf("invalid method %s", method) return jsonrpc.NewError(id, jsonrpc.METHOD_NOT_FOUND, err.Error(), nil), err @@ -76,7 +76,9 @@ func toolsListHandler(id jsonrpc.RequestId, toolset tools.Toolset, body []byte) } // toolsCallHandler generate a response for tools call. 
-func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, toolsMap map[string]tools.Tool, authServices map[string]auth.AuthService, body []byte, header http.Header) (any, error) { +func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) { + authServices := resourceMgr.GetAuthServiceMap() + // retrieve logger from context logger, err := util.LoggerFromContext(ctx) if err != nil { @@ -92,17 +94,27 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, toolsMap map[st toolName := req.Params.Name toolArgument := req.Params.Arguments logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName)) - tool, ok := toolsMap[toolName] + tool, ok := resourceMgr.GetTool(toolName) if !ok { err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName) return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err } // Get access token - accessToken := tools.AccessToken(header.Get(tool.GetAuthTokenHeaderName())) + authTokenHeadername, err := tool.GetAuthTokenHeaderName(resourceMgr) + if err != nil { + errMsg := fmt.Errorf("error during invocation: %w", err) + return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg + } + accessToken := tools.AccessToken(header.Get(authTokenHeadername)) // Check if this specific tool requires the standard authorization header - if tool.RequiresClientAuthorization() { + clientAuth, err := tool.RequiresClientAuthorization(resourceMgr) + if err != nil { + errMsg := fmt.Errorf("error during invocation: %w", err) + return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg + } + if clientAuth { if accessToken == "" { return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized } @@ -165,7 +177,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, toolsMap map[st logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params)) // run tool invocation and generate response. - results, err := tool.Invoke(ctx, params, accessToken) + results, err := tool.Invoke(ctx, resourceMgr, params, accessToken) if err != nil { errStr := err.Error() // Missing authService tokens. @@ -174,7 +186,7 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, toolsMap map[st } // Upstream auth error if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") { - if tool.RequiresClientAuthorization() { + if clientAuth { // Error with client credentials should pass down to the client return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err } @@ -244,7 +256,7 @@ func promptsListHandler(ctx context.Context, id jsonrpc.RequestId, promptset pro } // promptsGetHandler handles the "prompts/get" method. 
-func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, promptsMap map[string]prompts.Prompt, body []byte) (any, error) { +func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *resources.ResourceManager, body []byte) (any, error) { // retrieve logger from context logger, err := util.LoggerFromContext(ctx) if err != nil { @@ -260,8 +272,7 @@ func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, promptsMap map promptName := req.Params.Name logger.DebugContext(ctx, fmt.Sprintf("prompt name: %s", promptName)) - - prompt, ok := promptsMap[promptName] + prompt, ok := resourceMgr.GetPrompt(promptName) if !ok { err := fmt.Errorf("prompt with name %q does not exist", promptName) return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err diff --git a/internal/server/mcp_test.go b/internal/server/mcp_test.go index 6f344b3caa..90b8676098 100644 --- a/internal/server/mcp_test.go +++ b/internal/server/mcp_test.go @@ -29,6 +29,7 @@ import ( "github.com/googleapis/genai-toolbox/internal/log" "github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc" + "github.com/googleapis/genai-toolbox/internal/server/resources" "github.com/googleapis/genai-toolbox/internal/telemetry" ) @@ -1106,7 +1107,7 @@ func TestStdioSession(t *testing.T) { sseManager := newSseManager(ctx) - resourceManager := NewResourceManager(nil, nil, toolsMap, toolsets, promptsMap, promptsets) + resourceManager := resources.NewResourceManager(nil, nil, toolsMap, toolsets, promptsMap, promptsets) server := &Server{ version: fakeVersionString, diff --git a/internal/server/resources/resources.go b/internal/server/resources/resources.go new file mode 100644 index 0000000000..0cea0b7eaa --- /dev/null +++ b/internal/server/resources/resources.go @@ -0,0 +1,138 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package resources + +import ( + "sync" + + "github.com/googleapis/genai-toolbox/internal/auth" + "github.com/googleapis/genai-toolbox/internal/prompts" + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/tools" +) + +// ResourceManager contains available resources for the server. Should be initialized with NewResourceManager(). 
+type ResourceManager struct { + mu sync.RWMutex + sources map[string]sources.Source + authServices map[string]auth.AuthService + tools map[string]tools.Tool + toolsets map[string]tools.Toolset + prompts map[string]prompts.Prompt + promptsets map[string]prompts.Promptset +} + +func NewResourceManager( + sourcesMap map[string]sources.Source, + authServicesMap map[string]auth.AuthService, + toolsMap map[string]tools.Tool, toolsetsMap map[string]tools.Toolset, + promptsMap map[string]prompts.Prompt, promptsetsMap map[string]prompts.Promptset, + +) *ResourceManager { + resourceMgr := &ResourceManager{ + mu: sync.RWMutex{}, + sources: sourcesMap, + authServices: authServicesMap, + tools: toolsMap, + toolsets: toolsetsMap, + prompts: promptsMap, + promptsets: promptsetsMap, + } + + return resourceMgr +} + +func (r *ResourceManager) GetSource(sourceName string) (sources.Source, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + source, ok := r.sources[sourceName] + return source, ok +} + +func (r *ResourceManager) GetAuthService(authServiceName string) (auth.AuthService, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + authService, ok := r.authServices[authServiceName] + return authService, ok +} + +func (r *ResourceManager) GetTool(toolName string) (tools.Tool, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + tool, ok := r.tools[toolName] + return tool, ok +} + +func (r *ResourceManager) GetToolset(toolsetName string) (tools.Toolset, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + toolset, ok := r.toolsets[toolsetName] + return toolset, ok +} + +func (r *ResourceManager) GetPrompt(promptName string) (prompts.Prompt, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + prompt, ok := r.prompts[promptName] + return prompt, ok +} + +func (r *ResourceManager) GetPromptset(promptsetName string) (prompts.Promptset, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + promptset, ok := r.promptsets[promptsetName] + return promptset, ok +} + +func (r *ResourceManager) SetResources(sourcesMap map[string]sources.Source, authServicesMap map[string]auth.AuthService, toolsMap map[string]tools.Tool, toolsetsMap map[string]tools.Toolset, promptsMap map[string]prompts.Prompt, promptsetsMap map[string]prompts.Promptset) { + r.mu.Lock() + defer r.mu.Unlock() + r.sources = sourcesMap + r.authServices = authServicesMap + r.tools = toolsMap + r.toolsets = toolsetsMap + r.prompts = promptsMap + r.promptsets = promptsetsMap +} + +func (r *ResourceManager) GetAuthServiceMap() map[string]auth.AuthService { + r.mu.RLock() + defer r.mu.RUnlock() + copiedMap := make(map[string]auth.AuthService, len(r.authServices)) + for k, v := range r.authServices { + copiedMap[k] = v + } + return copiedMap +} + +func (r *ResourceManager) GetToolsMap() map[string]tools.Tool { + r.mu.RLock() + defer r.mu.RUnlock() + copiedMap := make(map[string]tools.Tool, len(r.tools)) + for k, v := range r.tools { + copiedMap[k] = v + } + return copiedMap +} + +func (r *ResourceManager) GetPromptsMap() map[string]prompts.Prompt { + r.mu.RLock() + defer r.mu.RUnlock() + copiedMap := make(map[string]prompts.Prompt, len(r.prompts)) + for k, v := range r.prompts { + copiedMap[k] = v + } + return copiedMap +} diff --git a/internal/server/resources/resources_test.go b/internal/server/resources/resources_test.go new file mode 100644 index 0000000000..b746abf3fc --- /dev/null +++ b/internal/server/resources/resources_test.go @@ -0,0 +1,103 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file 
except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package resources_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/auth" + "github.com/googleapis/genai-toolbox/internal/prompts" + "github.com/googleapis/genai-toolbox/internal/server/resources" + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" + "github.com/googleapis/genai-toolbox/internal/tools" +) + +func TestUpdateServer(t *testing.T) { + newSources := map[string]sources.Source{ + "example-source": &alloydbpg.Source{ + Config: alloydbpg.Config{ + Name: "example-alloydb-source", + Kind: "alloydb-postgres", + }, + }, + } + newAuth := map[string]auth.AuthService{"example-auth": nil} + newTools := map[string]tools.Tool{"example-tool": nil} + newToolsets := map[string]tools.Toolset{ + "example-toolset": { + ToolsetConfig: tools.ToolsetConfig{ + Name: "example-toolset", + }, + Tools: []*tools.Tool{}, + }, + } + newPrompts := map[string]prompts.Prompt{"example-prompt": nil} + newPromptsets := map[string]prompts.Promptset{ + "example-promptset": { + PromptsetConfig: prompts.PromptsetConfig{ + Name: "example-promptset", + }, + Prompts: []*prompts.Prompt{}, + }, + } + resMgr := resources.NewResourceManager(newSources, newAuth, newTools, newToolsets, newPrompts, newPromptsets) + + gotSource, _ := resMgr.GetSource("example-source") + if diff := cmp.Diff(gotSource, newSources["example-source"]); diff != "" { + t.Errorf("error updating server, sources (-want +got):\n%s", diff) + } + + gotAuthService, _ := resMgr.GetAuthService("example-auth") + if diff := cmp.Diff(gotAuthService, newAuth["example-auth"]); diff != "" { + t.Errorf("error updating server, authServices (-want +got):\n%s", diff) + } + + gotTool, _ := resMgr.GetTool("example-tool") + if diff := cmp.Diff(gotTool, newTools["example-tool"]); diff != "" { + t.Errorf("error updating server, tools (-want +got):\n%s", diff) + } + + gotToolset, _ := resMgr.GetToolset("example-toolset") + if diff := cmp.Diff(gotToolset, newToolsets["example-toolset"]); diff != "" { + t.Errorf("error updating server, toolset (-want +got):\n%s", diff) + } + + gotPrompt, _ := resMgr.GetPrompt("example-prompt") + if diff := cmp.Diff(gotPrompt, newPrompts["example-prompt"]); diff != "" { + t.Errorf("error updating server, prompts (-want +got):\n%s", diff) + } + + gotPromptset, _ := resMgr.GetPromptset("example-promptset") + if diff := cmp.Diff(gotPromptset, newPromptsets["example-promptset"]); diff != "" { + t.Errorf("error updating server, promptset (-want +got):\n%s", diff) + } + + updateSource := map[string]sources.Source{ + "example-source2": &alloydbpg.Source{ + Config: alloydbpg.Config{ + Name: "example-alloydb-source2", + Kind: "alloydb-postgres", + }, + }, + } + + resMgr.SetResources(updateSource, newAuth, newTools, newToolsets, newPrompts, newPromptsets) + gotSource, _ = resMgr.GetSource("example-source2") + if diff := cmp.Diff(gotSource, updateSource["example-source2"]); diff != "" { + t.Errorf("error updating server, sources (-want +got):\n%s", diff) + } +} diff 
--git a/internal/server/server.go b/internal/server/server.go index 8b5427d17c..4d2f600bd1 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -20,17 +20,19 @@ import ( "io" "net" "net/http" + "slices" "strconv" "strings" - "sync" "time" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" + "github.com/go-chi/cors" "github.com/go-chi/httplog/v2" "github.com/googleapis/genai-toolbox/internal/auth" "github.com/googleapis/genai-toolbox/internal/log" "github.com/googleapis/genai-toolbox/internal/prompts" + "github.com/googleapis/genai-toolbox/internal/server/resources" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/telemetry" "github.com/googleapis/genai-toolbox/internal/tools" @@ -48,109 +50,7 @@ type Server struct { logger log.Logger instrumentation *telemetry.Instrumentation sseManager *sseManager - ResourceMgr *ResourceManager -} - -// ResourceManager contains available resources for the server. Should be initialized with NewResourceManager(). -type ResourceManager struct { - mu sync.RWMutex - sources map[string]sources.Source - authServices map[string]auth.AuthService - tools map[string]tools.Tool - toolsets map[string]tools.Toolset - prompts map[string]prompts.Prompt - promptsets map[string]prompts.Promptset -} - -func NewResourceManager( - sourcesMap map[string]sources.Source, - authServicesMap map[string]auth.AuthService, - toolsMap map[string]tools.Tool, toolsetsMap map[string]tools.Toolset, - promptsMap map[string]prompts.Prompt, promptsetsMap map[string]prompts.Promptset, - -) *ResourceManager { - resourceMgr := &ResourceManager{ - mu: sync.RWMutex{}, - sources: sourcesMap, - authServices: authServicesMap, - tools: toolsMap, - toolsets: toolsetsMap, - prompts: promptsMap, - promptsets: promptsetsMap, - } - - return resourceMgr -} - -func (r *ResourceManager) GetSource(sourceName string) (sources.Source, bool) { - r.mu.RLock() - defer r.mu.RUnlock() - source, ok := r.sources[sourceName] - return source, ok -} - -func (r *ResourceManager) GetAuthService(authServiceName string) (auth.AuthService, bool) { - r.mu.RLock() - defer r.mu.RUnlock() - authService, ok := r.authServices[authServiceName] - return authService, ok -} - -func (r *ResourceManager) GetTool(toolName string) (tools.Tool, bool) { - r.mu.RLock() - defer r.mu.RUnlock() - tool, ok := r.tools[toolName] - return tool, ok -} - -func (r *ResourceManager) GetToolset(toolsetName string) (tools.Toolset, bool) { - r.mu.RLock() - defer r.mu.RUnlock() - toolset, ok := r.toolsets[toolsetName] - return toolset, ok -} - -func (r *ResourceManager) GetPrompt(promptName string) (prompts.Prompt, bool) { - r.mu.RLock() - defer r.mu.RUnlock() - prompt, ok := r.prompts[promptName] - return prompt, ok -} - -func (r *ResourceManager) GetPromptset(promptsetName string) (prompts.Promptset, bool) { - r.mu.RLock() - defer r.mu.RUnlock() - promptset, ok := r.promptsets[promptsetName] - return promptset, ok -} - -func (r *ResourceManager) SetResources(sourcesMap map[string]sources.Source, authServicesMap map[string]auth.AuthService, toolsMap map[string]tools.Tool, toolsetsMap map[string]tools.Toolset, promptsMap map[string]prompts.Prompt, promptsetsMap map[string]prompts.Promptset) { - r.mu.Lock() - defer r.mu.Unlock() - r.sources = sourcesMap - r.authServices = authServicesMap - r.tools = toolsMap - r.toolsets = toolsetsMap - r.prompts = promptsMap - r.promptsets = promptsetsMap -} - -func (r *ResourceManager) GetAuthServiceMap() map[string]auth.AuthService { - 
r.mu.RLock() - defer r.mu.RUnlock() - return r.authServices -} - -func (r *ResourceManager) GetToolsMap() map[string]tools.Tool { - r.mu.RLock() - defer r.mu.RUnlock() - return r.tools -} - -func (r *ResourceManager) GetPromptsMap() map[string]prompts.Prompt { - r.mu.RLock() - defer r.mu.RUnlock() - return r.prompts + ResourceMgr *resources.ResourceManager } func InitializeConfigs(ctx context.Context, cfg ServerConfig) ( @@ -388,6 +288,7 @@ func NewServer(ctx context.Context, cfg ServerConfig) (*Server, error) { // set up http serving r := chi.NewRouter() r.Use(middleware.Recoverer) + // logging logLevel, err := log.SeverityToLevel(cfg.LogLevel.String()) if err != nil { @@ -429,7 +330,7 @@ func NewServer(ctx context.Context, cfg ServerConfig) (*Server, error) { sseManager := newSseManager(ctx) - resourceManager := NewResourceManager(sourcesMap, authServicesMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap) + resourceManager := resources.NewResourceManager(sourcesMap, authServicesMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap) s := &Server{ version: cfg.Version, @@ -440,6 +341,21 @@ func NewServer(ctx context.Context, cfg ServerConfig) (*Server, error) { sseManager: sseManager, ResourceMgr: resourceManager, } + + // cors + if slices.Contains(cfg.AllowedOrigins, "*") { + s.logger.WarnContext(ctx, "wildcard (`*`) allows all origins to access the resource and is not secure. Use it with caution for public, non-sensitive data, or during local development. It is recommended to use the `--allowed-origins` flag to prevent DNS rebinding attacks") + } + corsOpts := cors.Options{ + AllowedOrigins: cfg.AllowedOrigins, + AllowedMethods: []string{"GET", "POST", "DELETE", "OPTIONS"}, + AllowCredentials: true, // required since Toolbox uses auth headers + AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token", "Mcp-Session-Id", "MCP-Protocol-Version"}, + ExposedHeaders: []string{"Mcp-Session-Id"}, // headers that are sent to clients + MaxAge: 300, // cache preflight results for 5 minutes + } + r.Use(cors.Handler(corsOpts)) + // control plane apiR, err := apiRouter(s) if err != nil { diff --git a/internal/sources/alloydbadmin/alloydbadmin.go b/internal/sources/alloydbadmin/alloydbadmin.go index d82126b2ea..633c7eb73e 100644 --- a/internal/sources/alloydbadmin/alloydbadmin.go +++ b/internal/sources/alloydbadmin/alloydbadmin.go @@ -15,8 +15,12 @@ package alloydbadmin import ( "context" + "encoding/json" "fmt" + "html/template" "net/http" + "strings" + "time" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" @@ -30,26 +34,6 @@ import ( const SourceKind string = "alloydb-admin" -type userAgentRoundTripper struct { - userAgent string - next http.RoundTripper -} - -func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - newReq := *req - newReq.Header = make(http.Header) - for k, v := range req.Header { - newReq.Header[k] = v - } - ua := newReq.Header.Get("User-Agent") - if ua == "" { - newReq.Header.Set("User-Agent", rt.userAgent) - } else { - newReq.Header.Set("User-Agent", ua+" "+rt.userAgent) - } - return rt.next.RoundTrip(&newReq) -} - // validate interface var _ sources.SourceConfig = Config{} @@ -81,16 +65,13 @@ func (r Config) SourceConfigKind() string { func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) { ua, err := util.UserAgentFromContext(ctx) if err != nil { - fmt.Printf("Error in User Agent retrieval: %s", err) + return nil, fmt.Errorf("error in User Agent
retrieval: %s", err) } var client *http.Client if r.UseClientOAuth { client = &http.Client{ - Transport: &userAgentRoundTripper{ - userAgent: ua, - next: http.DefaultTransport, - }, + Transport: util.NewUserAgentRoundTripper(ua, http.DefaultTransport), } } else { // Use Application Default Credentials @@ -99,10 +80,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So return nil, fmt.Errorf("failed to find default credentials: %w", err) } baseClient := oauth2.NewClient(ctx, creds.TokenSource) - baseClient.Transport = &userAgentRoundTripper{ - userAgent: ua, - next: baseClient.Transport, - } + baseClient.Transport = util.NewUserAgentRoundTripper(ua, baseClient.Transport) client = baseClient } @@ -136,7 +114,11 @@ func (s *Source) ToConfig() sources.SourceConfig { return s.Config } -func (s *Source) GetService(ctx context.Context, accessToken string) (*alloydbrestapi.Service, error) { +func (s *Source) GetDefaultProject() string { + return s.DefaultProject +} + +func (s *Source) getService(ctx context.Context, accessToken string) (*alloydbrestapi.Service, error) { if s.UseClientOAuth { token := &oauth2.Token{AccessToken: accessToken} client := oauth2.NewClient(ctx, oauth2.StaticTokenSource(token)) @@ -152,3 +134,287 @@ func (s *Source) GetService(ctx context.Context, accessToken string) (*alloydbre func (s *Source) UseClientAuthorization() bool { return s.UseClientOAuth } + +func (s *Source) CreateCluster(ctx context.Context, project, location, network, user, password, cluster, accessToken string) (any, error) { + // Build the request body using the type-safe Cluster struct. + clusterBody := &alloydbrestapi.Cluster{ + NetworkConfig: &alloydbrestapi.NetworkConfig{ + Network: fmt.Sprintf("projects/%s/global/networks/%s", project, network), + }, + InitialUser: &alloydbrestapi.UserPassword{ + User: user, + Password: password, + }, + } + + service, err := s.getService(ctx, accessToken) + if err != nil { + return nil, err + } + urlString := fmt.Sprintf("projects/%s/locations/%s", project, location) + + // The Create API returns a long-running operation. + resp, err := service.Projects.Locations.Clusters.Create(urlString, clusterBody).ClusterId(cluster).Do() + if err != nil { + return nil, fmt.Errorf("error creating AlloyDB cluster: %w", err) + } + return resp, nil +} + +func (s *Source) CreateInstance(ctx context.Context, project, location, cluster, instanceID, instanceType, displayName string, nodeCount int, accessToken string) (any, error) { + // Build the request body using the type-safe Instance struct. + instance := &alloydbrestapi.Instance{ + InstanceType: instanceType, + NetworkConfig: &alloydbrestapi.InstanceNetworkConfig{ + EnablePublicIp: true, + }, + DatabaseFlags: map[string]string{ + "password.enforce_complexity": "on", + }, + } + + if displayName != "" { + instance.DisplayName = displayName + } + + if instanceType == "READ_POOL" { + instance.ReadPoolConfig = &alloydbrestapi.ReadPoolConfig{ + NodeCount: int64(nodeCount), + } + } + + service, err := s.getService(ctx, accessToken) + if err != nil { + return nil, err + } + + urlString := fmt.Sprintf("projects/%s/locations/%s/clusters/%s", project, location, cluster) + + // The Create API returns a long-running operation. 
+ resp, err := service.Projects.Locations.Clusters.Instances.Create(urlString, instance).InstanceId(instanceID).Do() + if err != nil { + return nil, fmt.Errorf("error creating AlloyDB instance: %w", err) + } + return resp, nil +} + +func (s *Source) CreateUser(ctx context.Context, userType, password string, roles []string, accessToken, project, location, cluster, userID string) (any, error) { + // Build the request body using the type-safe User struct. + user := &alloydbrestapi.User{ + UserType: userType, + } + + if userType == "ALLOYDB_BUILT_IN" { + user.Password = password + } + + if len(roles) > 0 { + user.DatabaseRoles = roles + } + + service, err := s.getService(ctx, accessToken) + if err != nil { + return nil, err + } + + urlString := fmt.Sprintf("projects/%s/locations/%s/clusters/%s", project, location, cluster) + + // The Create API returns a long-running operation. + resp, err := service.Projects.Locations.Clusters.Users.Create(urlString, user).UserId(userID).Do() + if err != nil { + return nil, fmt.Errorf("error creating AlloyDB user: %w", err) + } + + return resp, nil +} + +func (s *Source) GetCluster(ctx context.Context, project, location, cluster, accessToken string) (any, error) { + service, err := s.getService(ctx, accessToken) + if err != nil { + return nil, err + } + + urlString := fmt.Sprintf("projects/%s/locations/%s/clusters/%s", project, location, cluster) + + resp, err := service.Projects.Locations.Clusters.Get(urlString).Do() + if err != nil { + return nil, fmt.Errorf("error getting AlloyDB cluster: %w", err) + } + + return resp, nil +} + +func (s *Source) GetInstance(ctx context.Context, project, location, cluster, instance, accessToken string) (any, error) { + service, err := s.getService(ctx, accessToken) + if err != nil { + return nil, err + } + + urlString := fmt.Sprintf("projects/%s/locations/%s/clusters/%s/instances/%s", project, location, cluster, instance) + + resp, err := service.Projects.Locations.Clusters.Instances.Get(urlString).Do() + if err != nil { + return nil, fmt.Errorf("error getting AlloyDB instance: %w", err) + } + return resp, nil +} + +func (s *Source) GetUsers(ctx context.Context, project, location, cluster, user, accessToken string) (any, error) { + service, err := s.getService(ctx, accessToken) + if err != nil { + return nil, err + } + + urlString := fmt.Sprintf("projects/%s/locations/%s/clusters/%s/users/%s", project, location, cluster, user) + + resp, err := service.Projects.Locations.Clusters.Users.Get(urlString).Do() + if err != nil { + return nil, fmt.Errorf("error getting AlloyDB user: %w", err) + } + return resp, nil +} + +func (s *Source) ListCluster(ctx context.Context, project, location, accessToken string) (any, error) { + service, err := s.getService(ctx, accessToken) + if err != nil { + return nil, err + } + + urlString := fmt.Sprintf("projects/%s/locations/%s", project, location) + + resp, err := service.Projects.Locations.Clusters.List(urlString).Do() + if err != nil { + return nil, fmt.Errorf("error listing AlloyDB clusters: %w", err) + } + return resp, nil +} + +func (s *Source) ListInstance(ctx context.Context, project, location, cluster, accessToken string) (any, error) { + service, err := s.getService(ctx, accessToken) + if err != nil { + return nil, err + } + + urlString := fmt.Sprintf("projects/%s/locations/%s/clusters/%s", project, location, cluster) + + resp, err := service.Projects.Locations.Clusters.Instances.List(urlString).Do() + if err != nil { + return nil, fmt.Errorf("error listing AlloyDB instances: %w", err) 
+ } + return resp, nil +} + +func (s *Source) ListUsers(ctx context.Context, project, location, cluster, accessToken string) (any, error) { + service, err := s.getService(ctx, accessToken) + if err != nil { + return nil, err + } + + urlString := fmt.Sprintf("projects/%s/locations/%s/clusters/%s", project, location, cluster) + + resp, err := service.Projects.Locations.Clusters.Users.List(urlString).Do() + if err != nil { + return nil, fmt.Errorf("error listing AlloyDB users: %w", err) + } + return resp, nil +} + +func (s *Source) GetOperations(ctx context.Context, project, location, operation, connectionMessageTemplate string, delay time.Duration, accessToken string) (any, error) { + logger, err := util.LoggerFromContext(ctx) + if err != nil { + return nil, err + } + + service, err := s.getService(ctx, accessToken) + if err != nil { + return nil, err + } + + name := fmt.Sprintf("projects/%s/locations/%s/operations/%s", project, location, operation) + + op, err := service.Projects.Locations.Operations.Get(name).Do() + if err != nil { + logger.DebugContext(ctx, fmt.Sprintf("error getting operation: %s, retrying in %v\n", err, delay)) + } else { + if op.Done { + if op.Error != nil { + var errorBytes []byte + errorBytes, err = json.Marshal(op.Error) + if err != nil { + return nil, fmt.Errorf("operation finished with error but could not marshal error object: %w", err) + } + return nil, fmt.Errorf("operation finished with error: %s", string(errorBytes)) + } + + var opBytes []byte + opBytes, err = op.MarshalJSON() + if err != nil { + return nil, fmt.Errorf("could not marshal operation: %w", err) + } + + if op.Response != nil { + var responseData map[string]any + if err := json.Unmarshal(op.Response, &responseData); err == nil && responseData != nil { + if msg, ok := generateAlloyDBConnectionMessage(responseData, connectionMessageTemplate); ok { + return msg, nil + } + } + } + + return string(opBytes), nil + } + logger.DebugContext(ctx, fmt.Sprintf("Operation not complete, retrying in %v\n", delay)) + } + return nil, nil +} + +func generateAlloyDBConnectionMessage(responseData map[string]any, connectionMessageTemplate string) (string, bool) { + resourceName, ok := responseData["name"].(string) + if !ok { + return "", false + } + + parts := strings.Split(resourceName, "/") + var project, region, cluster, instance string + + // Expected format: projects/{project}/locations/{location}/clusters/{cluster} + // or projects/{project}/locations/{location}/clusters/{cluster}/instances/{instance} + if len(parts) < 6 || parts[0] != "projects" || parts[2] != "locations" || parts[4] != "clusters" { + return "", false + } + + project = parts[1] + region = parts[3] + cluster = parts[5] + + if len(parts) >= 8 && parts[6] == "instances" { + instance = parts[7] + } else { + return "", false + } + + tmpl, err := template.New("alloydb-connection").Parse(connectionMessageTemplate) + if err != nil { + // This should not happen with a static template + return fmt.Sprintf("template parsing error: %v", err), false + } + + data := struct { + Project string + Region string + Cluster string + Instance string + }{ + Project: project, + Region: region, + Cluster: cluster, + Instance: instance, + } + + var b strings.Builder + if err := tmpl.Execute(&b, data); err != nil { + return fmt.Sprintf("template execution error: %v", err), false + } + + return b.String(), true +} diff --git a/internal/sources/alloydbpg/alloydb_pg.go b/internal/sources/alloydbpg/alloydb_pg.go index edf4310720..3adef5a051 100644 --- 
a/internal/sources/alloydbpg/alloydb_pg.go +++ b/internal/sources/alloydbpg/alloydb_pg.go @@ -101,6 +101,33 @@ func (s *Source) PostgresPool() *pgxpool.Pool { return s.Pool } +func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) { + results, err := s.Pool.Query(ctx, statement, params...) + if err != nil { + return nil, fmt.Errorf("unable to execute query: %w. Query: %v , Values: %v. Toolbox v0.19.0+ is only compatible with AlloyDB AI NL v1.0.3+. Please ensure that you are using the latest AlloyDB AI NL extension", err, statement, params) + } + + fields := results.FieldDescriptions() + + var out []any + for results.Next() { + v, err := results.Values() + if err != nil { + return nil, fmt.Errorf("unable to parse row: %w", err) + } + vMap := make(map[string]any) + for i, f := range fields { + vMap[f.Name] = v[i] + } + out = append(out, vMap) + } + // this will catch actual query execution errors + if err := results.Err(); err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + return out, nil +} + func getOpts(ipType, userAgent string, useIAM bool) ([]alloydbconn.Option, error) { opts := []alloydbconn.Option{alloydbconn.WithUserAgent(userAgent)} switch strings.ToLower(ipType) { @@ -141,7 +168,7 @@ func getConnectionConfig(ctx context.Context, user, pass, dbname string) (string // If password is provided without an username, raise an error return "", useIAM, fmt.Errorf("password is provided without a username. Please provide both a username and password, or leave both fields empty") } - email, err := sources.GetIAMPrincipalEmailFromADC(ctx) + email, err := sources.GetIAMPrincipalEmailFromADC(ctx, "postgres") if err != nil { return "", useIAM, fmt.Errorf("error getting email from ADC: %v", err) } diff --git a/internal/sources/cloudgda/cloud_gda.go b/internal/sources/cloudgda/cloud_gda.go new file mode 100644 index 0000000000..a87ff11c59 --- /dev/null +++ b/internal/sources/cloudgda/cloud_gda.go @@ -0,0 +1,133 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
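The new `cloud-gemini-data-analytics` source below exposes `GetClient`, `GetBaseURL`, and `GetProjectID` for tools built on top of it. A hypothetical sketch of issuing a request through it; the `queryGDA` helper, URL path, and empty JSON body are placeholders rather than the actual Gemini Data Analytics API surface:

```go
package example

import (
	"context"
	"fmt"
	"io"
	"net/http"
	"strings"

	"github.com/googleapis/genai-toolbox/internal/sources/cloudgda"
)

// queryGDA shows how a tool might call the Gemini Data Analytics endpoint
// through the new source. The request path and body are illustrative only.
func queryGDA(ctx context.Context, s *cloudgda.Source, accessToken string) (string, error) {
	client, err := s.GetClient(ctx, accessToken) // returns a token-bearing client when useClientOAuth is set
	if err != nil {
		return "", err
	}
	url := fmt.Sprintf("%s/v1beta/projects/%s/locations/global:chat", s.GetBaseURL(), s.GetProjectID())
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, strings.NewReader(`{}`))
	if err != nil {
		return "", err
	}
	req.Header.Set("Content-Type", "application/json")
	resp, err := client.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()
	b, err := io.ReadAll(resp.Body)
	return string(b), err
}
```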
+package cloudgda + +import ( + "context" + "fmt" + "net/http" + + "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/util" + "go.opentelemetry.io/otel/trace" + "golang.org/x/oauth2" + "golang.org/x/oauth2/google" +) + +const SourceKind string = "cloud-gemini-data-analytics" +const Endpoint string = "https://geminidataanalytics.googleapis.com" + +// validate interface +var _ sources.SourceConfig = Config{} + +func init() { + if !sources.Register(SourceKind, newConfig) { + panic(fmt.Sprintf("source kind %q already registered", SourceKind)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources.SourceConfig, error) { + actual := Config{Name: name} + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + return actual, nil +} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + ProjectID string `yaml:"projectId" validate:"required"` + UseClientOAuth bool `yaml:"useClientOAuth"` +} + +func (r Config) SourceConfigKind() string { + return SourceKind +} + +// Initialize initializes a Gemini Data Analytics Source instance. +func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) { + ua, err := util.UserAgentFromContext(ctx) + if err != nil { + return nil, fmt.Errorf("error in User Agent retrieval: %s", err) + } + + var client *http.Client + if r.UseClientOAuth { + client = &http.Client{ + Transport: util.NewUserAgentRoundTripper(ua, http.DefaultTransport), + } + } else { + // Use Application Default Credentials + // Scope: "https://www.googleapis.com/auth/cloud-platform" is generally sufficient for GDA + creds, err := google.FindDefaultCredentials(ctx, "https://www.googleapis.com/auth/cloud-platform") + if err != nil { + return nil, fmt.Errorf("failed to find default credentials: %w", err) + } + baseClient := oauth2.NewClient(ctx, creds.TokenSource) + baseClient.Transport = util.NewUserAgentRoundTripper(ua, baseClient.Transport) + client = baseClient + } + + s := &Source{ + Config: r, + Client: client, + BaseURL: Endpoint, + userAgent: ua, + } + return s, nil +} + +var _ sources.Source = &Source{} + +type Source struct { + Config + Client *http.Client + BaseURL string + userAgent string +} + +func (s *Source) SourceKind() string { + return SourceKind +} + +func (s *Source) ToConfig() sources.SourceConfig { + return s.Config +} + +func (s *Source) GetProjectID() string { + return s.ProjectID +} + +func (s *Source) GetBaseURL() string { + return s.BaseURL +} + +func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Client, error) { + if s.UseClientOAuth { + if accessToken == "" { + return nil, fmt.Errorf("client-side OAuth is enabled but no access token was provided") + } + token := &oauth2.Token{AccessToken: accessToken} + baseClient := oauth2.NewClient(ctx, oauth2.StaticTokenSource(token)) + baseClient.Transport = util.NewUserAgentRoundTripper(s.userAgent, baseClient.Transport) + return baseClient, nil + } + return s.Client, nil +} + +func (s *Source) UseClientAuthorization() bool { + return s.UseClientOAuth +} diff --git a/internal/sources/cloudgda/cloud_gda_test.go b/internal/sources/cloudgda/cloud_gda_test.go new file mode 100644 index 0000000000..30b977729d --- /dev/null +++ b/internal/sources/cloudgda/cloud_gda_test.go @@ -0,0 +1,213 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the 
"License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cloudgda_test + +import ( + "context" + "os" + "path/filepath" + "testing" + + yaml "github.com/goccy/go-yaml" + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/sources/cloudgda" + "github.com/googleapis/genai-toolbox/internal/testutils" + "go.opentelemetry.io/otel/trace/noop" +) + +func TestParseFromYamlCloudGDA(t *testing.T) { + t.Parallel() + tcs := []struct { + desc string + in string + want server.SourceConfigs + }{ + { + desc: "basic example", + in: ` + sources: + my-gda-instance: + kind: cloud-gemini-data-analytics + projectId: test-project-id + `, + want: map[string]sources.SourceConfig{ + "my-gda-instance": cloudgda.Config{ + Name: "my-gda-instance", + Kind: cloudgda.SourceKind, + ProjectID: "test-project-id", + UseClientOAuth: false, + }, + }, + }, + { + desc: "use client auth example", + in: ` + sources: + my-gda-instance: + kind: cloud-gemini-data-analytics + projectId: another-project + useClientOAuth: true + `, + want: map[string]sources.SourceConfig{ + "my-gda-instance": cloudgda.Config{ + Name: "my-gda-instance", + Kind: cloudgda.SourceKind, + ProjectID: "another-project", + UseClientOAuth: true, + }, + }, + }, + } + for _, tc := range tcs { + tc := tc + t.Run(tc.desc, func(t *testing.T) { + t.Parallel() + got := struct { + Sources server.SourceConfigs `yaml:"sources"` + }{} + // Parse contents + err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if !cmp.Equal(tc.want, got.Sources) { + t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources) + } + }) + } +} + +func TestFailParseFromYaml(t *testing.T) { + t.Parallel() + tcs := []struct { + desc string + in string + err string + }{ + { + desc: "missing projectId", + in: ` + sources: + my-gda-instance: + kind: cloud-gemini-data-analytics + `, + err: "unable to parse source \"my-gda-instance\" as \"cloud-gemini-data-analytics\": Key: 'Config.ProjectID' Error:Field validation for 'ProjectID' failed on the 'required' tag", + }, + } + for _, tc := range tcs { + tc := tc + t.Run(tc.desc, func(t *testing.T) { + t.Parallel() + got := struct { + Sources server.SourceConfigs `yaml:"sources"` + }{} + // Parse contents + err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got) + if err == nil { + t.Fatalf("expect parsing to fail") + } + errStr := err.Error() + if errStr != tc.err { + t.Fatalf("unexpected error: got %q, want %q", errStr, tc.err) + } + }) + } +} + +func TestInitialize(t *testing.T) { + // Create a dummy credentials file for testing ADC + credFile := filepath.Join(t.TempDir(), "application_default_credentials.json") + dummyCreds := `{ + "client_id": "foo", + "client_secret": "bar", + "refresh_token": "baz", + "type": "authorized_user" + }` + if err := os.WriteFile(credFile, []byte(dummyCreds), 0644); err != nil { + t.Fatalf("failed to write dummy credentials file: %v", err) + } + 
t.Setenv("GOOGLE_APPLICATION_CREDENTIALS", credFile) + + // Use ContextWithUserAgent to avoid "unable to retrieve user agent" error + ctx := testutils.ContextWithUserAgent(context.Background(), "test-user-agent") + tracer := noop.NewTracerProvider().Tracer("test") + + tcs := []struct { + desc string + cfg cloudgda.Config + wantClientOAuth bool + }{ + { + desc: "initialize with ADC", + cfg: cloudgda.Config{Name: "test-gda", Kind: cloudgda.SourceKind, ProjectID: "test-proj"}, + wantClientOAuth: false, + }, + { + desc: "initialize with client OAuth", + cfg: cloudgda.Config{Name: "test-gda-oauth", Kind: cloudgda.SourceKind, ProjectID: "test-proj", UseClientOAuth: true}, + wantClientOAuth: true, + }, + } + + for _, tc := range tcs { + tc := tc + t.Run(tc.desc, func(t *testing.T) { + t.Parallel() + src, err := tc.cfg.Initialize(ctx, tracer) + if err != nil { + t.Fatalf("failed to initialize source: %v", err) + } + + gdaSrc, ok := src.(*cloudgda.Source) + if !ok { + t.Fatalf("expected *cloudgda.Source, got %T", src) + } + + // Check that the client is non-nil + if gdaSrc.Client == nil && !tc.wantClientOAuth { + t.Fatal("expected non-nil HTTP client for ADC, got nil") + } + // When client OAuth is true, the source's client should be initialized with a base HTTP client + // that includes the user agent round tripper, but not the OAuth token. The token-aware + // client is created by GetClient. + if gdaSrc.Client == nil && tc.wantClientOAuth { + t.Fatal("expected non-nil HTTP client for client OAuth config, got nil") + } + + // Test UseClientAuthorization method + if gdaSrc.UseClientAuthorization() != tc.wantClientOAuth { + t.Errorf("UseClientAuthorization mismatch: want %t, got %t", tc.wantClientOAuth, gdaSrc.UseClientAuthorization()) + } + + // Test GetClient with accessToken for client OAuth scenarios + if tc.wantClientOAuth { + client, err := gdaSrc.GetClient(ctx, "dummy-token") + if err != nil { + t.Fatalf("GetClient with token failed: %v", err) + } + if client == nil { + t.Fatal("expected non-nil HTTP client from GetClient with token, got nil") + } + // Ensure passing empty token with UseClientOAuth enabled returns error + _, err = gdaSrc.GetClient(ctx, "") + if err == nil || err.Error() != "client-side OAuth is enabled but no access token was provided" { + t.Errorf("expected 'client-side OAuth is enabled but no access token was provided' error, got: %v", err) + } + } + }) + } +} diff --git a/internal/sources/cloudmonitoring/cloud_monitoring.go b/internal/sources/cloudmonitoring/cloud_monitoring.go index 4c6db77ed1..d43468687d 100644 --- a/internal/sources/cloudmonitoring/cloud_monitoring.go +++ b/internal/sources/cloudmonitoring/cloud_monitoring.go @@ -29,26 +29,6 @@ import ( const SourceKind string = "cloud-monitoring" -type userAgentRoundTripper struct { - userAgent string - next http.RoundTripper -} - -func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - newReq := *req - newReq.Header = make(http.Header) - for k, v := range req.Header { - newReq.Header[k] = v - } - ua := newReq.Header.Get("User-Agent") - if ua == "" { - newReq.Header.Set("User-Agent", rt.userAgent) - } else { - newReq.Header.Set("User-Agent", ua+" "+rt.userAgent) - } - return rt.next.RoundTrip(&newReq) -} - // validate interface var _ sources.SourceConfig = Config{} @@ -86,10 +66,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So var client *http.Client if r.UseClientOAuth { client = &http.Client{ - Transport: &userAgentRoundTripper{ - userAgent: ua, - 
next: http.DefaultTransport, - }, + Transport: util.NewUserAgentRoundTripper(ua, http.DefaultTransport), } } else { // Use Application Default Credentials @@ -98,18 +75,15 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So return nil, fmt.Errorf("failed to find default credentials: %w", err) } baseClient := oauth2.NewClient(ctx, creds.TokenSource) - baseClient.Transport = &userAgentRoundTripper{ - userAgent: ua, - next: baseClient.Transport, - } + baseClient.Transport = util.NewUserAgentRoundTripper(ua, baseClient.Transport) client = baseClient } s := &Source{ Config: r, - BaseURL: "https://monitoring.googleapis.com", - Client: client, - UserAgent: ua, + baseURL: "https://monitoring.googleapis.com", + client: client, + userAgent: ua, } return s, nil } @@ -118,9 +92,9 @@ var _ sources.Source = &Source{} type Source struct { Config - BaseURL string `yaml:"baseUrl"` - Client *http.Client - UserAgent string + baseURL string + client *http.Client + userAgent string } func (s *Source) SourceKind() string { @@ -131,6 +105,18 @@ func (s *Source) ToConfig() sources.SourceConfig { return s.Config } +func (s *Source) BaseURL() string { + return s.baseURL +} + +func (s *Source) Client() *http.Client { + return s.client +} + +func (s *Source) UserAgent() string { + return s.userAgent +} + func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Client, error) { if s.UseClientOAuth { if accessToken == "" { @@ -139,7 +125,7 @@ func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Clien token := &oauth2.Token{AccessToken: accessToken} return oauth2.NewClient(ctx, oauth2.StaticTokenSource(token)), nil } - return s.Client, nil + return s.client, nil } func (s *Source) UseClientAuthorization() bool { diff --git a/internal/sources/cloudsqladmin/cloud_sql_admin.go b/internal/sources/cloudsqladmin/cloud_sql_admin.go index e0827faf9d..3a3ff48caf 100644 --- a/internal/sources/cloudsqladmin/cloud_sql_admin.go +++ b/internal/sources/cloudsqladmin/cloud_sql_admin.go @@ -30,26 +30,6 @@ import ( const SourceKind string = "cloud-sql-admin" -type userAgentRoundTripper struct { - userAgent string - next http.RoundTripper -} - -func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - newReq := *req - newReq.Header = make(http.Header) - for k, v := range req.Header { - newReq.Header[k] = v - } - ua := newReq.Header.Get("User-Agent") - if ua == "" { - newReq.Header.Set("User-Agent", rt.userAgent) - } else { - newReq.Header.Set("User-Agent", ua+" "+rt.userAgent) - } - return rt.next.RoundTrip(&newReq) -} - // validate interface var _ sources.SourceConfig = Config{} @@ -88,10 +68,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So var client *http.Client if r.UseClientOAuth { client = &http.Client{ - Transport: &userAgentRoundTripper{ - userAgent: ua, - next: http.DefaultTransport, - }, + Transport: util.NewUserAgentRoundTripper(ua, http.DefaultTransport), } } else { // Use Application Default Credentials @@ -100,10 +77,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So return nil, fmt.Errorf("failed to find default credentials: %w", err) } baseClient := oauth2.NewClient(ctx, creds.TokenSource) - baseClient.Transport = &userAgentRoundTripper{ - userAgent: ua, - next: baseClient.Transport, - } + baseClient.Transport = util.NewUserAgentRoundTripper(ua, baseClient.Transport) client = baseClient } @@ -136,6 +110,10 @@ func (s *Source) ToConfig() 
sources.SourceConfig { return s.Config } +func (s *Source) GetDefaultProject() string { + return s.DefaultProject +} + func (s *Source) GetService(ctx context.Context, accessToken string) (*sqladmin.Service, error) { if s.UseClientOAuth { token := &oauth2.Token{AccessToken: accessToken} diff --git a/internal/sources/cloudsqlmysql/cloud_sql_mysql.go b/internal/sources/cloudsqlmysql/cloud_sql_mysql.go index 4bdee7f3a0..797985454b 100644 --- a/internal/sources/cloudsqlmysql/cloud_sql_mysql.go +++ b/internal/sources/cloudsqlmysql/cloud_sql_mysql.go @@ -54,8 +54,8 @@ type Config struct { Region string `yaml:"region" validate:"required"` Instance string `yaml:"instance" validate:"required"` IPType sources.IPType `yaml:"ipType"` - User string `yaml:"user" validate:"required"` - Password string `yaml:"password" validate:"required"` + User string `yaml:"user"` + Password string `yaml:"password"` Database string `yaml:"database" validate:"required"` } @@ -100,31 +100,89 @@ func (s *Source) MySQLPool() *sql.DB { return s.Pool } +func getConnectionConfig(ctx context.Context, user, pass string) (string, string, bool, error) { + useIAM := true + + // If username and password both provided, use password authentication + if user != "" && pass != "" { + useIAM = false + return user, pass, useIAM, nil + } + + // If username is empty, fetch email from ADC + // otherwise, use username as IAM email + if user == "" { + if pass != "" { + return "", "", useIAM, fmt.Errorf("password is provided without a username. Please provide both a username and password, or leave both fields empty") + } + email, err := sources.GetIAMPrincipalEmailFromADC(ctx, "mysql") + if err != nil { + return "", "", useIAM, fmt.Errorf("error getting email from ADC: %v", err) + } + user = email + } + + // Pass the user, empty password and useIAM set to true + return user, pass, useIAM, nil +} + func initCloudSQLMySQLConnectionPool(ctx context.Context, tracer trace.Tracer, name, project, region, instance, ipType, user, pass, dbname string) (*sql.DB, error) { //nolint:all // Reassigned ctx ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name) defer span.End() + // Configure the driver to connect to the database + user, pass, useIAM, err := getConnectionConfig(ctx, user, pass) + if err != nil { + return nil, fmt.Errorf("unable to get Cloud SQL connection config: %w", err) + } + // Create a new dialer with options userAgent, err := util.UserAgentFromContext(ctx) if err != nil { return nil, err } - opts, err := sources.GetCloudSQLOpts(ipType, userAgent, false) + opts, err := sources.GetCloudSQLOpts(ipType, userAgent, useIAM) if err != nil { return nil, err } - if !slices.Contains(sql.Drivers(), "cloudsql-mysql") { - _, err = mysql.RegisterDriver("cloudsql-mysql", opts...) - if err != nil { + // Use a unique driver name based on the source name. 
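+	// The connector options built above (for example IAM authentication) are baked into the
+	// registered driver, and database/sql allows each driver name to be registered only once,
+	// so every source registers its own name rather than sharing a global "cloudsql-mysql".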
+ driverName := fmt.Sprintf("cloudsql-mysql-%s", name) + + if !slices.Contains(sql.Drivers(), driverName) { + if _, err := mysql.RegisterDriver(driverName, opts...); err != nil { return nil, fmt.Errorf("unable to register driver: %w", err) } } + + var dsn string // Tell the driver to use the Cloud SQL Go Connector to create connections - dsn := fmt.Sprintf("%s:%s@cloudsql-mysql(%s:%s:%s)/%s?connectionAttributes=program_name:%s", user, pass, project, region, instance, dbname, url.QueryEscape(userAgent)) + if useIAM { + dsn = fmt.Sprintf("%s@%s(%s:%s:%s)/%s?connectionAttributes=program_name:%s", + user, + driverName, + project, + region, + instance, + dbname, + url.QueryEscape(userAgent), + ) + } else { + dsn = fmt.Sprintf("%s:%s@%s(%s:%s:%s)/%s?connectionAttributes=program_name:%s", + user, + pass, + driverName, + project, + region, + instance, + dbname, + url.QueryEscape(userAgent), + ) + } + db, err := sql.Open( - "cloudsql-mysql", + driverName, dsn, ) if err != nil { diff --git a/internal/sources/cloudsqlpg/cloud_sql_pg.go b/internal/sources/cloudsqlpg/cloud_sql_pg.go index f13c67d5d0..3de83993bb 100644 --- a/internal/sources/cloudsqlpg/cloud_sql_pg.go +++ b/internal/sources/cloudsqlpg/cloud_sql_pg.go @@ -120,7 +120,7 @@ func getConnectionConfig(ctx context.Context, user, pass, dbname string) (string // If password is provided without an username, raise an error return "", useIAM, fmt.Errorf("password is provided without a username. Please provide both a username and password, or leave both fields empty") } - email, err := sources.GetIAMPrincipalEmailFromADC(ctx) + email, err := sources.GetIAMPrincipalEmailFromADC(ctx, "postgres") if err != nil { return "", useIAM, fmt.Errorf("error getting email from ADC: %v", err) } diff --git a/internal/sources/firebird/firebird.go b/internal/sources/firebird/firebird.go index 5f4a5dfa2d..43775be70c 100644 --- a/internal/sources/firebird/firebird.go +++ b/internal/sources/firebird/firebird.go @@ -21,8 +21,10 @@ import ( "time" "github.com/goccy/go-yaml" - "github.com/googleapis/genai-toolbox/internal/sources" + _ "github.com/nakagami/firebirdsql" "go.opentelemetry.io/otel/trace" + + "github.com/googleapis/genai-toolbox/internal/sources" ) const SourceKind string = "firebird" diff --git a/internal/sources/http/http.go b/internal/sources/http/http.go index 8f51e84114..b4e9fdd937 100644 --- a/internal/sources/http/http.go +++ b/internal/sources/http/http.go @@ -107,7 +107,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So s := &Source{ Config: r, - Client: &client, + client: &client, } return s, nil @@ -117,7 +117,7 @@ var _ sources.Source = &Source{} type Source struct { Config - Client *http.Client + client *http.Client } func (s *Source) SourceKind() string { @@ -127,3 +127,19 @@ func (s *Source) SourceKind() string { func (s *Source) ToConfig() sources.SourceConfig { return s.Config } + +func (s *Source) HttpDefaultHeaders() map[string]string { + return s.DefaultHeaders +} + +func (s *Source) HttpBaseURL() string { + return s.BaseURL +} + +func (s *Source) HttpQueryParams() map[string]string { + return s.QueryParams +} + +func (s *Source) Client() *http.Client { + return s.client +} diff --git a/internal/sources/looker/looker.go b/internal/sources/looker/looker.go index d88883a7ad..3b60127a55 100644 --- a/internal/sources/looker/looker.go +++ b/internal/sources/looker/looker.go @@ -160,10 +160,6 @@ func (s *Source) ToConfig() sources.SourceConfig { return s.Config } -func (s *Source) GetApiSettings() 
*rtl.ApiSettings { - return s.ApiSettings -} - func (s *Source) UseClientAuthorization() bool { return strings.ToLower(s.UseClientOAuth) != "false" } @@ -188,6 +184,30 @@ func (s *Source) GoogleCloudTokenSourceWithScope(ctx context.Context, scope stri return google.DefaultTokenSource(ctx, scope) } +func (s *Source) LookerClient() *v4.LookerSDK { + return s.Client +} + +func (s *Source) LookerApiSettings() *rtl.ApiSettings { + return s.ApiSettings +} + +func (s *Source) LookerShowHiddenFields() bool { + return s.ShowHiddenFields +} + +func (s *Source) LookerShowHiddenModels() bool { + return s.ShowHiddenModels +} + +func (s *Source) LookerShowHiddenExplores() bool { + return s.ShowHiddenExplores +} + +func (s *Source) LookerSessionLength() int64 { + return s.SessionLength +} + func initGoogleCloudConnection(ctx context.Context) (oauth2.TokenSource, error) { cred, err := google.FindDefaultCredentials(ctx, geminidataanalytics.DefaultAuthScopes()...) if err != nil { diff --git a/internal/sources/oracle/oracle.go b/internal/sources/oracle/oracle.go index 3b37560004..4de64b402b 100644 --- a/internal/sources/oracle/oracle.go +++ b/internal/sources/oracle/oracle.go @@ -9,9 +9,11 @@ import ( "strings" "github.com/goccy/go-yaml" + _ "github.com/godror/godror" // OCI driver + _ "github.com/sijms/go-ora/v2" // Pure Go driver + "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/util" - _ "github.com/sijms/go-ora/v2" "go.opentelemetry.io/otel/trace" ) @@ -32,7 +34,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources return nil, err } - // Validate that we have one of: tns_alias, connection_string, or host+service_name + // Validate that we have one of: tnsAlias, connectionString, or host+service_name if err := actual.validate(); err != nil { return nil, fmt.Errorf("invalid Oracle configuration: %w", err) } @@ -43,21 +45,24 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` - ConnectionString string `yaml:"connectionString,omitempty"` // Direct connection string (hostname[:port]/servicename) - TnsAlias string `yaml:"tnsAlias,omitempty"` // TNS alias from tnsnames.ora - Host string `yaml:"host,omitempty"` // Optional when using connectionString/tnsAlias - Port int `yaml:"port,omitempty"` // Explicit port support - ServiceName string `yaml:"serviceName,omitempty"` // Optional when using connectionString/tnsAlias + ConnectionString string `yaml:"connectionString,omitempty"` + TnsAlias string `yaml:"tnsAlias,omitempty"` + TnsAdmin string `yaml:"tnsAdmin,omitempty"` + Host string `yaml:"host,omitempty"` + Port int `yaml:"port,omitempty"` + ServiceName string `yaml:"serviceName,omitempty"` User string `yaml:"user" validate:"required"` Password string `yaml:"password" validate:"required"` - TnsAdmin string `yaml:"tnsAdmin,omitempty"` // Optional: override TNS_ADMIN environment variable + UseOCI bool `yaml:"useOCI,omitempty"` + WalletLocation string `yaml:"walletLocation,omitempty"` } -// validate ensures we have one of: tns_alias, connection_string, or host+service_name func (c Config) validate() error { + hasTnsAdmin := strings.TrimSpace(c.TnsAdmin) != "" hasTnsAlias := strings.TrimSpace(c.TnsAlias) != "" hasConnStr := strings.TrimSpace(c.ConnectionString) != "" hasHostService := strings.TrimSpace(c.Host) != "" && strings.TrimSpace(c.ServiceName) != "" + hasWallet := 
strings.TrimSpace(c.WalletLocation) != "" connectionMethods := 0 if hasTnsAlias { @@ -78,6 +83,14 @@ func (c Config) validate() error { return fmt.Errorf("provide only one connection method: 'tns_alias', 'connection_string', or 'host'+'service_name'") } + if hasTnsAdmin && !c.UseOCI { + return fmt.Errorf("`tnsAdmin` can only be used when `UseOCI` is true, or use `walletLocation` instead") + } + + if hasWallet && c.UseOCI { + return fmt.Errorf("when using an OCI driver, use `tnsAdmin` to specify credentials file location instead") + } + return nil } @@ -132,7 +145,8 @@ func initOracleConnection(ctx context.Context, tracer trace.Tracer, config Confi panic(err) } - // Set TNS_ADMIN environment variable if specified in config. + hasWallet := strings.TrimSpace(config.WalletLocation) != "" + if config.TnsAdmin != "" { originalTnsAdmin := os.Getenv("TNS_ADMIN") os.Setenv("TNS_ADMIN", config.TnsAdmin) @@ -147,28 +161,49 @@ func initOracleConnection(ctx context.Context, tracer trace.Tracer, config Confi }() } - var serverString string + var connectStringBase string if config.TnsAlias != "" { - // Use TNS alias - serverString = strings.TrimSpace(config.TnsAlias) + connectStringBase = strings.TrimSpace(config.TnsAlias) } else if config.ConnectionString != "" { - // Use provided connection string directly (hostname[:port]/servicename format) - serverString = strings.TrimSpace(config.ConnectionString) + connectStringBase = strings.TrimSpace(config.ConnectionString) } else { - // Build connection string from host and service_name if config.Port > 0 { - serverString = fmt.Sprintf("%s:%d/%s", config.Host, config.Port, config.ServiceName) + connectStringBase = fmt.Sprintf("%s:%d/%s", config.Host, config.Port, config.ServiceName) } else { - serverString = fmt.Sprintf("%s/%s", config.Host, config.ServiceName) + connectStringBase = fmt.Sprintf("%s/%s", config.Host, config.ServiceName) } } - connStr := fmt.Sprintf("oracle://%s:%s@%s", - config.User, config.Password, serverString) + var driverName string + var finalConnStr string - db, err := sql.Open("oracle", connStr) + if config.UseOCI { + // Use godror driver (requires OCI) + driverName = "godror" + finalConnStr = fmt.Sprintf(`user="%s" password="%s" connectString="%s"`, + config.User, config.Password, connectStringBase) + logger.DebugContext(ctx, fmt.Sprintf("Using godror driver (OCI-based) with connectString: %s\n", connectStringBase)) + } else { + // Use go-ora driver (pure Go) + driverName = "oracle" + + user := config.User + password := config.Password + + if hasWallet { + finalConnStr = fmt.Sprintf("oracle://%s:%s@%s?ssl=true&wallet=%s", + user, password, connectStringBase, config.WalletLocation) + } else { + // Standard go-ora connection + finalConnStr = fmt.Sprintf("oracle://%s:%s@%s", + config.User, config.Password, connectStringBase) + logger.DebugContext(ctx, fmt.Sprintf("Using go-ora driver (pure-Go) with serverString: %s\n", connectStringBase)) + } + } + + db, err := sql.Open(driverName, finalConnStr) if err != nil { - return nil, fmt.Errorf("unable to open Oracle connection: %w", err) + return nil, fmt.Errorf("unable to open Oracle connection with driver %s: %w", driverName, err) } return db, nil diff --git a/internal/sources/oracle/oracle_test.go b/internal/sources/oracle/oracle_test.go new file mode 100644 index 0000000000..3d8f4c7ba5 --- /dev/null +++ b/internal/sources/oracle/oracle_test.go @@ -0,0 +1,200 @@ +// Copyright © 2025, Oracle and/or its affiliates. 
+ +package oracle_test + +import ( + "strings" + "testing" + + yaml "github.com/goccy/go-yaml" + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/sources/oracle" + "github.com/googleapis/genai-toolbox/internal/testutils" +) + +func TestParseFromYamlOracle(t *testing.T) { + tcs := []struct { + desc string + in string + want server.SourceConfigs + }{ + { + desc: "connection string and useOCI=true", + in: ` + sources: + my-oracle-cs: + kind: oracle + connectionString: "my-host:1521/XEPDB1" + user: my_user + password: my_pass + useOCI: true + `, + want: server.SourceConfigs{ + "my-oracle-cs": oracle.Config{ + Name: "my-oracle-cs", + Kind: oracle.SourceKind, + ConnectionString: "my-host:1521/XEPDB1", + User: "my_user", + Password: "my_pass", + UseOCI: true, + }, + }, + }, + { + desc: "host/port/serviceName and default useOCI=false", + in: ` + sources: + my-oracle-host: + kind: oracle + host: my-host + port: 1521 + serviceName: ORCLPDB + user: my_user + password: my_pass + `, + want: server.SourceConfigs{ + "my-oracle-host": oracle.Config{ + Name: "my-oracle-host", + Kind: oracle.SourceKind, + Host: "my-host", + Port: 1521, + ServiceName: "ORCLPDB", + User: "my_user", + Password: "my_pass", + UseOCI: false, + }, + }, + }, + { + desc: "tnsAlias and TnsAdmin specified with explicit useOCI=true", + in: ` + sources: + my-oracle-tns-oci: + kind: oracle + tnsAlias: FINANCE_DB + tnsAdmin: /opt/oracle/network/admin + user: my_user + password: my_pass + useOCI: true + `, + want: server.SourceConfigs{ + "my-oracle-tns-oci": oracle.Config{ + Name: "my-oracle-tns-oci", + Kind: oracle.SourceKind, + TnsAlias: "FINANCE_DB", + TnsAdmin: "/opt/oracle/network/admin", + User: "my_user", + Password: "my_pass", + UseOCI: true, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Sources server.SourceConfigs `yaml:"sources"` + }{} + + err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if !cmp.Equal(tc.want, got.Sources) { + t.Fatalf("incorrect parse:\nwant: %v\ngot: %v\ndiff: %s", tc.want, got.Sources, cmp.Diff(tc.want, got.Sources)) + } + }) + } +} + +func TestFailParseFromYamlOracle(t *testing.T) { + tcs := []struct { + desc string + in string + err string + }{ + { + desc: "extra field", + in: ` + sources: + my-oracle-instance: + kind: oracle + host: my-host + serviceName: ORCL + user: my_user + password: my_pass + extraField: value + `, + err: "unable to parse source \"my-oracle-instance\" as \"oracle\": [1:1] unknown field \"extraField\"\n> 1 | extraField: value\n ^\n 2 | host: my-host\n 3 | kind: oracle\n 4 | password: my_pass\n 5 | ", + }, + { + desc: "missing required password field", + in: ` + sources: + my-oracle-instance: + kind: oracle + host: my-host + serviceName: ORCL + user: my_user + `, + err: "unable to parse source \"my-oracle-instance\" as \"oracle\": Key: 'Config.Password' Error:Field validation for 'Password' failed on the 'required' tag", + }, + { + desc: "missing connection method fields (validate fails)", + in: ` + sources: + my-oracle-instance: + kind: oracle + user: my_user + password: my_pass + `, + err: "unable to parse source \"my-oracle-instance\" as \"oracle\": invalid Oracle configuration: must provide one of: 'tns_alias', 'connection_string', or both 'host' and 'service_name'", + }, + { + desc: "multiple connection methods provided (validate fails)", + in: ` + sources: + 
my-oracle-instance: + kind: oracle + host: my-host + serviceName: ORCL + connectionString: "my-host:1521/XEPDB1" + user: my_user + password: my_pass + `, + err: "unable to parse source \"my-oracle-instance\" as \"oracle\": invalid Oracle configuration: provide only one connection method: 'tns_alias', 'connection_string', or 'host'+'service_name'", + }, + { + desc: "fail on tnsAdmin with useOCI=false", + in: ` + sources: + my-oracle-fail: + kind: oracle + tnsAlias: FINANCE_DB + tnsAdmin: /opt/oracle/network/admin + user: my_user + password: my_pass + useOCI: false + `, + err: "unable to parse source \"my-oracle-fail\" as \"oracle\": invalid Oracle configuration: `tnsAdmin` can only be used when `UseOCI` is true, or use `walletLocation` instead", + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Sources server.SourceConfigs `yaml:"sources"` + }{} + + err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got) + if err == nil { + t.Fatalf("expect parsing to fail") + } + errStr := strings.ReplaceAll(err.Error(), "\r", "") + + if errStr != tc.err { + t.Fatalf("unexpected error:\ngot:\n%q\nwant:\n%q\n", errStr, tc.err) + } + }) + } +} diff --git a/internal/sources/serverlessspark/serverlessspark.go b/internal/sources/serverlessspark/serverlessspark.go index 2e95199ecd..c63adb6863 100644 --- a/internal/sources/serverlessspark/serverlessspark.go +++ b/internal/sources/serverlessspark/serverlessspark.go @@ -96,6 +96,14 @@ func (s *Source) ToConfig() sources.SourceConfig { return s.Config } +func (s *Source) GetProject() string { + return s.Project +} + +func (s *Source) GetLocation() string { + return s.Location +} + func (s *Source) GetBatchControllerClient() *dataproc.BatchControllerClient { return s.Client } diff --git a/internal/sources/util.go b/internal/sources/util.go index 0a78c1b2a6..d2b2210ddd 100644 --- a/internal/sources/util.go +++ b/internal/sources/util.go @@ -48,7 +48,7 @@ func GetCloudSQLOpts(ipType, userAgent string, useIAM bool) ([]cloudsqlconn.Opti } // GetIAMPrincipalEmailFromADC finds the email associated with ADC -func GetIAMPrincipalEmailFromADC(ctx context.Context) (string, error) { +func GetIAMPrincipalEmailFromADC(ctx context.Context, dbType string) (string, error) { // Finds ADC and returns an HTTP client associated with it client, err := google.DefaultClient(ctx, "https://www.googleapis.com/auth/userinfo.email") @@ -83,9 +83,31 @@ func GetIAMPrincipalEmailFromADC(ctx context.Context) (string, error) { if !ok { return "", fmt.Errorf("email not found in response: %v", err) } - // service account email used for IAM should trim the suffix - email := strings.TrimSuffix(emailValue.(string), ".gserviceaccount.com") - return email, nil + + fullEmail, ok := emailValue.(string) + if !ok { + return "", fmt.Errorf("email field is not a string") + } + + var username string + // Format the username based on Database Type + switch strings.ToLower(dbType) { + case "mysql": + username, _, _ = strings.Cut(fullEmail, "@") + + case "postgres": + // service account email used for IAM should trim the suffix + username = strings.TrimSuffix(fullEmail, ".gserviceaccount.com") + + default: + return "", fmt.Errorf("unsupported dbType: %s. 
Use 'mysql' or 'postgres'", dbType) + } + + if username == "" { + return "", fmt.Errorf("username from ADC cannot be an empty string") + } + + return username, nil } func GetIAMAccessToken(ctx context.Context) (string, error) { diff --git a/internal/testutils/testutils.go b/internal/testutils/testutils.go index f6c30b859c..82975fb321 100644 --- a/internal/testutils/testutils.go +++ b/internal/testutils/testutils.go @@ -46,6 +46,11 @@ func ContextWithNewLogger() (context.Context, error) { return util.WithLogger(ctx, logger), nil } +// ContextWithUserAgent creates a new context with a specified user agent string. +func ContextWithUserAgent(ctx context.Context, userAgent string) context.Context { + return util.WithUserAgent(ctx, userAgent) +} + // WaitForString waits until the server logs a single line that matches the provided regex. // returns the output of whatever the server sent so far. func WaitForString(ctx context.Context, re *regexp.Regexp, pr io.ReadCloser) (string, error) { diff --git a/internal/tools/alloydb/alloydbcreatecluster/alloydbcreatecluster.go b/internal/tools/alloydb/alloydbcreatecluster/alloydbcreatecluster.go index 4b5a1f5efd..0702f6388b 100644 --- a/internal/tools/alloydb/alloydbcreatecluster/alloydbcreatecluster.go +++ b/internal/tools/alloydb/alloydbcreatecluster/alloydbcreatecluster.go @@ -20,10 +20,8 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" - "google.golang.org/api/alloydb/v1" ) const kind string = "alloydb-create-cluster" @@ -42,6 +40,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + GetDefaultProject() string + UseClientAuthorization() bool + CreateCluster(context.Context, string, string, string, string, string, string, string) (any, error) +} + // Configuration for the create-cluster tool. type Config struct { Name string `yaml:"name" validate:"required"` @@ -66,12 +70,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return nil, fmt.Errorf("source %q not found", cfg.Source) } - s, ok := rawS.(*alloydbadmin.Source) + s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `alloydb-admin`", kind) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } - project := s.DefaultProject + project := s.GetDefaultProject() var projectParam parameters.Parameter if project != "" { projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.") @@ -97,7 +101,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, - Source: s, AllParams: allParameters, manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -107,7 +110,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // Tool represents the create-cluster tool. 
type Tool struct { Config - Source *alloydbadmin.Source AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest @@ -119,7 +121,12 @@ func (t Tool) ToConfig() tools.ToolConfig { } // Invoke executes the tool's logic. -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) if !ok || project == "" { @@ -151,31 +158,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("invalid 'user' parameter; expected a string") } - service, err := t.Source.GetService(ctx, string(accessToken)) - if err != nil { - return nil, err - } - - urlString := fmt.Sprintf("projects/%s/locations/%s", project, location) - - // Build the request body using the type-safe Cluster struct. - clusterBody := &alloydb.Cluster{ - NetworkConfig: &alloydb.NetworkConfig{ - Network: fmt.Sprintf("projects/%s/global/networks/%s", project, network), - }, - InitialUser: &alloydb.UserPassword{ - User: user, - Password: password, - }, - } - - // The Create API returns a long-running operation. - resp, err := service.Projects.Locations.Clusters.Create(urlString, clusterBody).ClusterId(clusterID).Do() - if err != nil { - return nil, fmt.Errorf("error creating AlloyDB cluster: %w", err) - } - - return resp, nil + return source.CreateCluster(ctx, project, location, network, user, password, clusterID, string(accessToken)) } // ParseParams parses the parameters for the tool. 
@@ -198,10 +181,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return true } -func (t Tool) RequiresClientAuthorization() bool { - return t.Source.UseClientAuthorization() +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/alloydb/alloydbcreateinstance/alloydbcreateinstance.go b/internal/tools/alloydb/alloydbcreateinstance/alloydbcreateinstance.go index b401278148..6a0aefa4ec 100644 --- a/internal/tools/alloydb/alloydbcreateinstance/alloydbcreateinstance.go +++ b/internal/tools/alloydb/alloydbcreateinstance/alloydbcreateinstance.go @@ -20,10 +20,8 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" - "google.golang.org/api/alloydb/v1" ) const kind string = "alloydb-create-instance" @@ -42,6 +40,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + GetDefaultProject() string + UseClientAuthorization() bool + CreateInstance(context.Context, string, string, string, string, string, string, int, string) (any, error) +} + // Configuration for the create-instance tool. type Config struct { Name string `yaml:"name" validate:"required"` @@ -66,12 +70,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return nil, fmt.Errorf("source %q not found", cfg.Source) } - s, ok := rawS.(*alloydbadmin.Source) + s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `alloydb-admin`", kind) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } - project := s.DefaultProject + project := s.GetDefaultProject() var projectParam parameters.Parameter if project != "" { projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.") @@ -98,7 +102,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, - Source: s, AllParams: allParameters, manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -108,7 +111,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // Tool represents the create-instance tool. type Tool struct { Config - Source *alloydbadmin.Source AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest @@ -120,7 +122,12 @@ func (t Tool) ToConfig() tools.ToolConfig { } // Invoke executes the tool's logic. 
-func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) if !ok || project == "" { @@ -147,45 +154,17 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("invalid 'instanceType' parameter; expected 'PRIMARY' or 'READ_POOL'") } - service, err := t.Source.GetService(ctx, string(accessToken)) - if err != nil { - return nil, err - } - - urlString := fmt.Sprintf("projects/%s/locations/%s/clusters/%s", project, location, cluster) - - // Build the request body using the type-safe Instance struct. - instance := &alloydb.Instance{ - InstanceType: instanceType, - NetworkConfig: &alloydb.InstanceNetworkConfig{ - EnablePublicIp: true, - }, - DatabaseFlags: map[string]string{ - "password.enforce_complexity": "on", - }, - } - - if displayName, ok := paramsMap["displayName"].(string); ok && displayName != "" { - instance.DisplayName = displayName - } + displayName, _ := paramsMap["displayName"].(string) + var nodeCount int if instanceType == "READ_POOL" { - nodeCount, ok := paramsMap["nodeCount"].(int) + nodeCount, ok = paramsMap["nodeCount"].(int) if !ok { return nil, fmt.Errorf("invalid 'nodeCount' parameter; expected an integer for READ_POOL") } - instance.ReadPoolConfig = &alloydb.ReadPoolConfig{ - NodeCount: int64(nodeCount), - } } - // The Create API returns a long-running operation. - resp, err := service.Projects.Locations.Clusters.Instances.Create(urlString, instance).InstanceId(instanceID).Do() - if err != nil { - return nil, fmt.Errorf("error creating AlloyDB instance: %w", err) - } - - return resp, nil + return source.CreateInstance(ctx, project, location, cluster, instanceID, instanceType, displayName, nodeCount, string(accessToken)) } // ParseParams parses the parameters for the tool. 
@@ -208,10 +187,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return true } -func (t Tool) RequiresClientAuthorization() bool { - return t.Source.UseClientAuthorization() +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/alloydb/alloydbcreateuser/alloydbcreateuser.go b/internal/tools/alloydb/alloydbcreateuser/alloydbcreateuser.go index 82a15e9943..8378a2af45 100644 --- a/internal/tools/alloydb/alloydbcreateuser/alloydbcreateuser.go +++ b/internal/tools/alloydb/alloydbcreateuser/alloydbcreateuser.go @@ -20,10 +20,8 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" - "google.golang.org/api/alloydb/v1" ) const kind string = "alloydb-create-user" @@ -42,6 +40,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + GetDefaultProject() string + UseClientAuthorization() bool + CreateUser(context.Context, string, string, []string, string, string, string, string, string) (any, error) +} + // Configuration for the create-user tool. type Config struct { Name string `yaml:"name" validate:"required"` @@ -66,12 +70,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return nil, fmt.Errorf("source %q not found", cfg.Source) } - s, ok := rawS.(*alloydbadmin.Source) + s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `alloydb-admin`", kind) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } - project := s.DefaultProject + project := s.GetDefaultProject() var projectParam parameters.Parameter if project != "" { projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.") @@ -98,7 +102,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, - Source: s, AllParams: allParameters, manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -108,9 +111,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // Tool represents the create-user tool. type Tool struct { Config - Source *alloydbadmin.Source - AllParams parameters.Parameters `yaml:"allParams"` - + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } @@ -120,7 +121,12 @@ func (t Tool) ToConfig() tools.ToolConfig { } // Invoke executes the tool's logic. 
-func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) if !ok || project == "" { @@ -146,46 +152,24 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT if !ok || (userType != "ALLOYDB_BUILT_IN" && userType != "ALLOYDB_IAM_USER") { return nil, fmt.Errorf("invalid or missing 'userType' parameter; expected 'ALLOYDB_BUILT_IN' or 'ALLOYDB_IAM_USER'") } - - service, err := t.Source.GetService(ctx, string(accessToken)) - if err != nil { - return nil, err - } - - urlString := fmt.Sprintf("projects/%s/locations/%s/clusters/%s", project, location, cluster) - - // Build the request body using the type-safe User struct. - user := &alloydb.User{ - UserType: userType, - } + var password string if userType == "ALLOYDB_BUILT_IN" { - password, ok := paramsMap["password"].(string) + password, ok = paramsMap["password"].(string) if !ok || password == "" { return nil, fmt.Errorf("password is required when userType is ALLOYDB_BUILT_IN") } - user.Password = password } + var roles []string if dbRolesRaw, ok := paramsMap["databaseRoles"].([]any); ok && len(dbRolesRaw) > 0 { - var roles []string for _, r := range dbRolesRaw { if role, ok := r.(string); ok { roles = append(roles, role) } } - if len(roles) > 0 { - user.DatabaseRoles = roles - } } - - // The Create API returns a long-running operation. - resp, err := service.Projects.Locations.Clusters.Users.Create(urlString, user).UserId(userID).Do() - if err != nil { - return nil, fmt.Errorf("error creating AlloyDB user: %w", err) - } - - return resp, nil + return source.CreateUser(ctx, userType, password, roles, string(accessToken), project, location, cluster, userID) } // ParseParams parses the parameters for the tool. 
@@ -208,10 +192,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return true } -func (t Tool) RequiresClientAuthorization() bool { - return t.Source.UseClientAuthorization() +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/alloydb/alloydbgetcluster/alloydbgetcluster.go b/internal/tools/alloydb/alloydbgetcluster/alloydbgetcluster.go index f4fa519f9c..2d12579de4 100644 --- a/internal/tools/alloydb/alloydbgetcluster/alloydbgetcluster.go +++ b/internal/tools/alloydb/alloydbgetcluster/alloydbgetcluster.go @@ -20,7 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -41,6 +40,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + GetDefaultProject() string + UseClientAuthorization() bool + GetCluster(context.Context, string, string, string, string) (any, error) +} + // Configuration for the get-cluster tool. type Config struct { Name string `yaml:"name" validate:"required"` @@ -66,12 +71,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return nil, fmt.Errorf("source %q not found", cfg.Source) } - s, ok := rawS.(*alloydbadmin.Source) + s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } - project := s.DefaultProject + project := s.GetDefaultProject() var projectParam parameters.Parameter if project != "" { projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.") @@ -94,7 +99,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, - Source: s, AllParams: allParameters, manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -104,7 +108,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // Tool represents the get-cluster tool. type Tool struct { Config - Source *alloydbadmin.Source AllParams parameters.Parameters manifest tools.Manifest @@ -116,7 +119,12 @@ func (t Tool) ToConfig() tools.ToolConfig { } // Invoke executes the tool's logic. 
-func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) @@ -132,19 +140,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("invalid 'cluster' parameter; expected a string") } - service, err := t.Source.GetService(ctx, string(accessToken)) - if err != nil { - return nil, err - } - - urlString := fmt.Sprintf("projects/%s/locations/%s/clusters/%s", project, location, cluster) - - resp, err := service.Projects.Locations.Clusters.Get(urlString).Do() - if err != nil { - return nil, fmt.Errorf("error getting AlloyDB cluster: %w", err) - } - - return resp, nil + return source.GetCluster(ctx, project, location, cluster, string(accessToken)) } // ParseParams parses the parameters for the tool. @@ -167,10 +163,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return true } -func (t Tool) RequiresClientAuthorization() bool { - return t.Source.UseClientAuthorization() +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/alloydb/alloydbgetinstance/alloydbgetinstance.go b/internal/tools/alloydb/alloydbgetinstance/alloydbgetinstance.go index c62e798058..9b76b9b9b5 100644 --- a/internal/tools/alloydb/alloydbgetinstance/alloydbgetinstance.go +++ b/internal/tools/alloydb/alloydbgetinstance/alloydbgetinstance.go @@ -20,7 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -41,6 +40,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + GetDefaultProject() string + UseClientAuthorization() bool + GetInstance(context.Context, string, string, string, string, string) (any, error) +} + // Configuration for the get-instance tool. 
type Config struct { Name string `yaml:"name" validate:"required"` @@ -66,12 +71,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return nil, fmt.Errorf("source %q not found", cfg.Source) } - s, ok := rawS.(*alloydbadmin.Source) + s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } - project := s.DefaultProject + project := s.GetDefaultProject() var projectParam parameters.Parameter if project != "" { projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.") @@ -95,7 +100,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, - Source: s, AllParams: allParameters, manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -105,9 +109,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // Tool represents the get-instance tool. type Tool struct { Config - Source *alloydbadmin.Source - AllParams parameters.Parameters - + AllParams parameters.Parameters manifest tools.Manifest mcpManifest tools.McpManifest } @@ -117,7 +119,12 @@ func (t Tool) ToConfig() tools.ToolConfig { } // Invoke executes the tool's logic. -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) @@ -137,19 +144,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("invalid 'instance' parameter; expected a string") } - service, err := t.Source.GetService(ctx, string(accessToken)) - if err != nil { - return nil, err - } - - urlString := fmt.Sprintf("projects/%s/locations/%s/clusters/%s/instances/%s", project, location, cluster, instance) - - resp, err := service.Projects.Locations.Clusters.Instances.Get(urlString).Do() - if err != nil { - return nil, fmt.Errorf("error getting AlloyDB instance: %w", err) - } - - return resp, nil + return source.GetInstance(ctx, project, location, cluster, instance, string(accessToken)) } // ParseParams parses the parameters for the tool. 
@@ -172,10 +167,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return true } -func (t Tool) RequiresClientAuthorization() bool { - return t.Source.UseClientAuthorization() +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/alloydb/alloydbgetuser/alloydbgetuser.go b/internal/tools/alloydb/alloydbgetuser/alloydbgetuser.go index ab057edb9e..a56da8dbda 100644 --- a/internal/tools/alloydb/alloydbgetuser/alloydbgetuser.go +++ b/internal/tools/alloydb/alloydbgetuser/alloydbgetuser.go @@ -20,7 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -41,6 +40,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + GetDefaultProject() string + UseClientAuthorization() bool + GetUsers(context.Context, string, string, string, string, string) (any, error) +} + // Configuration for the get-user tool. type Config struct { Name string `yaml:"name" validate:"required"` @@ -66,12 +71,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return nil, fmt.Errorf("source %q not found", cfg.Source) } - s, ok := rawS.(*alloydbadmin.Source) + s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } - project := s.DefaultProject + project := s.GetDefaultProject() var projectParam parameters.Parameter if project != "" { projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.") @@ -95,7 +100,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, - Source: s, AllParams: allParameters, manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -105,9 +109,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // Tool represents the get-user tool. type Tool struct { Config - Source *alloydbadmin.Source - AllParams parameters.Parameters - + AllParams parameters.Parameters manifest tools.Manifest mcpManifest tools.McpManifest } @@ -117,7 +119,12 @@ func (t Tool) ToConfig() tools.ToolConfig { } // Invoke executes the tool's logic. 
-func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) @@ -137,19 +144,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("invalid 'user' parameter; expected a string") } - service, err := t.Source.GetService(ctx, string(accessToken)) - if err != nil { - return nil, err - } - - urlString := fmt.Sprintf("projects/%s/locations/%s/clusters/%s/users/%s", project, location, cluster, user) - - resp, err := service.Projects.Locations.Clusters.Users.Get(urlString).Do() - if err != nil { - return nil, fmt.Errorf("error getting AlloyDB user: %w", err) - } - - return resp, nil + return source.GetUsers(ctx, project, location, cluster, user, string(accessToken)) } // ParseParams parses the parameters for the tool. @@ -172,10 +167,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return true } -func (t Tool) RequiresClientAuthorization() bool { - return t.Source.UseClientAuthorization() +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/alloydb/alloydblistclusters/alloydblistclusters.go b/internal/tools/alloydb/alloydblistclusters/alloydblistclusters.go index 83f2d7bea9..f408dbeda6 100644 --- a/internal/tools/alloydb/alloydblistclusters/alloydblistclusters.go +++ b/internal/tools/alloydb/alloydblistclusters/alloydblistclusters.go @@ -20,7 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -41,6 +40,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + GetDefaultProject() string + UseClientAuthorization() bool + ListCluster(context.Context, string, string, string) (any, error) +} + // Configuration for the list-clusters tool. 
type Config struct { Name string `yaml:"name" validate:"required"` @@ -66,12 +71,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return nil, fmt.Errorf("source %q not found", cfg.Source) } - s, ok := rawS.(*alloydbadmin.Source) + s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } - project := s.DefaultProject + project := s.GetDefaultProject() var projectParam parameters.Parameter if project != "" { projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.") @@ -93,7 +98,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, - Source: s, AllParams: allParameters, manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -103,9 +107,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // Tool represents the list-clusters tool. type Tool struct { Config - Source *alloydbadmin.Source - AllParams parameters.Parameters `yaml:"allParams"` - + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } @@ -115,7 +117,12 @@ func (t Tool) ToConfig() tools.ToolConfig { } // Invoke executes the tool's logic. -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) @@ -127,19 +134,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("invalid 'location' parameter; expected a string") } - service, err := t.Source.GetService(ctx, string(accessToken)) - if err != nil { - return nil, err - } - - urlString := fmt.Sprintf("projects/%s/locations/%s", project, location) - - resp, err := service.Projects.Locations.Clusters.List(urlString).Do() - if err != nil { - return nil, fmt.Errorf("error listing AlloyDB clusters: %w", err) - } - - return resp, nil + return source.ListCluster(ctx, project, location, string(accessToken)) } // ParseParams parses the parameters for the tool. 
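The API call deleted from the tool above presumably moves into a source-side method such as `ListCluster`, which is not shown in this diff. A sketch reconstructed from the removed tool code, assuming the generated v1 AlloyDB Admin API client; `Source` is stubbed here only to keep the sketch self-contained:

```go
package example

import (
	"context"
	"fmt"

	alloydbapi "google.golang.org/api/alloydb/v1"
)

// Source stands in for the real alloydbadmin.Source.
type Source struct{}

// GetService is assumed to exist on the source already (the deleted tool
// code called it); stubbed here.
func (s *Source) GetService(ctx context.Context, accessToken string) (*alloydbapi.Service, error) {
	return nil, fmt.Errorf("elided")
}

// ListCluster re-homes the call that used to live in the tool's Invoke.
func (s *Source) ListCluster(ctx context.Context, project, location, accessToken string) (any, error) {
	service, err := s.GetService(ctx, accessToken)
	if err != nil {
		return nil, err
	}
	parent := fmt.Sprintf("projects/%s/locations/%s", project, location)
	resp, err := service.Projects.Locations.Clusters.List(parent).Do()
	if err != nil {
		return nil, fmt.Errorf("error listing AlloyDB clusters: %w", err)
	}
	return resp, nil
}
```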
@@ -162,10 +157,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return true } -func (t Tool) RequiresClientAuthorization() bool { - return t.Source.UseClientAuthorization() +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/alloydb/alloydblistinstances/alloydblistinstances.go b/internal/tools/alloydb/alloydblistinstances/alloydblistinstances.go index ff73c28478..b355d055a2 100644 --- a/internal/tools/alloydb/alloydblistinstances/alloydblistinstances.go +++ b/internal/tools/alloydb/alloydblistinstances/alloydblistinstances.go @@ -20,7 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -41,6 +40,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + GetDefaultProject() string + UseClientAuthorization() bool + ListInstance(context.Context, string, string, string, string) (any, error) +} + // Configuration for the list-instances tool. type Config struct { Name string `yaml:"name" validate:"required"` @@ -66,12 +71,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return nil, fmt.Errorf("source %q not found", cfg.Source) } - s, ok := rawS.(*alloydbadmin.Source) + s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } - project := s.DefaultProject + project := s.GetDefaultProject() var projectParam parameters.Parameter if project != "" { projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.") @@ -94,7 +99,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, - Source: s, AllParams: allParameters, manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -104,9 +108,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // Tool represents the list-instances tool. type Tool struct { Config - Source *alloydbadmin.Source - AllParams parameters.Parameters `yaml:"allParams"` - + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } @@ -116,7 +118,12 @@ func (t Tool) ToConfig() tools.ToolConfig { } // Invoke executes the tool's logic. 
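Because `RequiresClientAuthorization` and `GetAuthTokenHeaderName` now take a `tools.SourceProvider` and return an error, call sites have to handle a failed source lookup. A short caller-side sketch under those new signatures (the `server` package and `needsClientAuth` helper are illustrative, not part of this patch):

```go
package server // illustrative location, not part of this patch

import "github.com/googleapis/genai-toolbox/internal/tools"

// needsClientAuth shows how a call site adapts to the new signatures:
// both methods take the provider and can fail if the source is missing
// or incompatible.
func needsClientAuth(t tools.Tool, provider tools.SourceProvider) (bool, string, error) {
	required, err := t.RequiresClientAuthorization(provider)
	if err != nil {
		return false, "", err
	}
	header, err := t.GetAuthTokenHeaderName(provider)
	if err != nil {
		return false, "", err
	}
	return required, header, nil
}
```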
-func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) @@ -132,19 +139,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("invalid 'cluster' parameter; expected a string") } - service, err := t.Source.GetService(ctx, string(accessToken)) - if err != nil { - return nil, err - } - - urlString := fmt.Sprintf("projects/%s/locations/%s/clusters/%s", project, location, cluster) - - resp, err := service.Projects.Locations.Clusters.Instances.List(urlString).Do() - if err != nil { - return nil, fmt.Errorf("error listing AlloyDB instances: %w", err) - } - - return resp, nil + return source.ListInstance(ctx, project, location, cluster, string(accessToken)) } // ParseParams parses the parameters for the tool. @@ -167,10 +162,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return true } -func (t Tool) RequiresClientAuthorization() bool { - return t.Source.UseClientAuthorization() +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/alloydb/alloydblistusers/alloydblistusers.go b/internal/tools/alloydb/alloydblistusers/alloydblistusers.go index 3b4d09b2ff..4148c75972 100644 --- a/internal/tools/alloydb/alloydblistusers/alloydblistusers.go +++ b/internal/tools/alloydb/alloydblistusers/alloydblistusers.go @@ -20,7 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -41,6 +40,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + GetDefaultProject() string + UseClientAuthorization() bool + ListUsers(context.Context, string, string, string, string) (any, error) +} + // Configuration for the list-users tool. 
type Config struct { Name string `yaml:"name" validate:"required"` @@ -66,12 +71,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return nil, fmt.Errorf("source %q not found", cfg.Source) } - s, ok := rawS.(*alloydbadmin.Source) + s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } - project := s.DefaultProject + project := s.GetDefaultProject() var projectParam parameters.Parameter if project != "" { projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.") @@ -94,7 +99,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, - Source: s, AllParams: allParameters, manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -104,9 +108,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // Tool represents the list-users tool. type Tool struct { Config - Source *alloydbadmin.Source - AllParams parameters.Parameters `yaml:"allParams"` - + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } @@ -116,7 +118,12 @@ func (t Tool) ToConfig() tools.ToolConfig { } // Invoke executes the tool's logic. -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) @@ -132,19 +139,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("invalid 'cluster' parameter; expected a string") } - service, err := t.Source.GetService(ctx, string(accessToken)) - if err != nil { - return nil, err - } - - urlString := fmt.Sprintf("projects/%s/locations/%s/clusters/%s", project, location, cluster) - - resp, err := service.Projects.Locations.Clusters.Users.List(urlString).Do() - if err != nil { - return nil, fmt.Errorf("error listing AlloyDB users: %w", err) - } - - return resp, nil + return source.ListUsers(ctx, project, location, cluster, string(accessToken)) } // ParseParams parses the parameters for the tool. 
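The `alloydbwaitforoperation` diff further below keeps the tool's polling knobs (Delay, MaxDelay, Multiplier, MaxRetries) but delegates the Operations.Get call to the source. A minimal sketch of that retry loop, assuming `GetOperations` returns a non-nil result only once the operation is done and that the delay grows geometrically up to the cap (the delay-update lines are elided in the hunk):

```go
package example

import (
	"context"
	"fmt"
	"time"
)

// pollOperation mirrors the loop in the rewritten Invoke: check for
// cancellation, ask the source for the operation, and back off between
// attempts until maxRetries is exhausted.
func pollOperation(
	ctx context.Context,
	get func(context.Context) (any, error), // stands in for source.GetOperations
	delay, maxDelay time.Duration,
	multiplier float64,
	maxRetries int,
) (any, error) {
	for i := 0; i < maxRetries; i++ {
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		default:
		}
		op, err := get(ctx)
		if err != nil {
			return nil, err
		}
		if op != nil { // operation finished
			return op, nil
		}
		time.Sleep(delay)
		delay = time.Duration(float64(delay) * multiplier)
		if delay > maxDelay {
			delay = maxDelay
		}
	}
	return nil, fmt.Errorf("exceeded max retries waiting for operation")
}
```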
@@ -167,10 +162,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return true } -func (t Tool) RequiresClientAuthorization() bool { - return t.Source.UseClientAuthorization() +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/alloydb/alloydbwaitforoperation/alloydbwaitforoperation.go b/internal/tools/alloydb/alloydbwaitforoperation/alloydbwaitforoperation.go index 6ca877c540..18fc0fe6c6 100644 --- a/internal/tools/alloydb/alloydbwaitforoperation/alloydbwaitforoperation.go +++ b/internal/tools/alloydb/alloydbwaitforoperation/alloydbwaitforoperation.go @@ -16,16 +16,12 @@ package alloydbwaitforoperation import ( "context" - "encoding/json" "fmt" "net/http" - "strings" - "text/template" "time" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -50,8 +46,8 @@ Update the MCP server configuration with the following environment variables: "ALLOYDB_POSTGRES_CLUSTER": "{{.Cluster}}", {{if .Instance}} "ALLOYDB_POSTGRES_INSTANCE": "{{.Instance}}", {{end}} "ALLOYDB_POSTGRES_DATABASE": "postgres", - "ALLOYDB_POSTGRES_USER": ""{{.User}}",", - "ALLOYDB_POSTGRES_PASSWORD": ""{{.Password}}", + "ALLOYDB_POSTGRES_USER": "", + "ALLOYDB_POSTGRES_PASSWORD": "" } } } @@ -89,6 +85,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + GetDefaultProject() string + UseClientAuthorization() bool + GetOperations(context.Context, string, string, string, string, time.Duration, string) (any, error) +} + // Config defines the configuration for the wait-for-operation tool. type Config struct { Name string `yaml:"name" validate:"required"` @@ -119,12 +121,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return nil, fmt.Errorf("no source named %q configured", cfg.Source) } - s, ok := rawS.(*alloydbadmin.Source) + s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, alloydbadmin.SourceKind) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } - project := s.DefaultProject + project := s.GetDefaultProject() var projectParam parameters.Parameter if project != "" { projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. 
This is pre-configured; do not ask for it unless the user explicitly provides a different one.") @@ -180,7 +182,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, - Source: s, AllParams: allParameters, manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -194,19 +195,16 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // Tool represents the wait-for-operation tool. type Tool struct { Config - - Source *alloydbadmin.Source - AllParams parameters.Parameters `yaml:"allParams"` + AllParams parameters.Parameters `yaml:"allParams"` + Client *http.Client + manifest tools.Manifest + mcpManifest tools.McpManifest // Polling configuration Delay time.Duration MaxDelay time.Duration Multiplier float64 MaxRetries int - - Client *http.Client - manifest tools.Manifest - mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -214,7 +212,12 @@ func (t Tool) ToConfig() tools.ToolConfig { } // Invoke executes the tool's logic. -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) @@ -230,16 +233,9 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("missing 'operation' parameter") } - service, err := t.Source.GetService(ctx, string(accessToken)) - if err != nil { - return nil, err - } - ctx, cancel := context.WithTimeout(ctx, 30*time.Minute) defer cancel() - name := fmt.Sprintf("projects/%s/locations/%s/operations/%s", project, location, operation) - delay := t.Delay maxDelay := t.MaxDelay multiplier := t.Multiplier @@ -253,33 +249,11 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT default: } - op, err := service.Projects.Locations.Operations.Get(name).Do() + op, err := source.GetOperations(ctx, project, location, operation, alloyDBConnectionMessageTemplate, delay, string(accessToken)) if err != nil { - fmt.Printf("error getting operation: %s, retrying in %v\n", err, delay) - } else { - if op.Done { - if op.Error != nil { - var errorBytes []byte - errorBytes, err = json.Marshal(op.Error) - if err != nil { - return nil, fmt.Errorf("operation finished with error but could not marshal error object: %w", err) - } - return nil, fmt.Errorf("operation finished with error: %s", string(errorBytes)) - } - - var opBytes []byte - opBytes, err = op.MarshalJSON() - if err != nil { - return nil, fmt.Errorf("could not marshal operation: %w", err) - } - - if msg, ok := t.generateAlloyDBConnectionMessage(map[string]any{"response": op.Response}); ok { - return msg, nil - } - - return string(opBytes), nil - } - fmt.Printf("Operation not complete, retrying in %v\n", delay) + return nil, err + } else if op != nil { + return op, nil } time.Sleep(delay) @@ -292,57 +266,6 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("exceeded max retries waiting for operation") } -func (t Tool) generateAlloyDBConnectionMessage(responseData map[string]any) (string, bool) 
{ - resourceName, ok := responseData["name"].(string) - if !ok { - return "", false - } - - parts := strings.Split(resourceName, "/") - var project, region, cluster, instance string - - // Expected format: projects/{project}/locations/{location}/clusters/{cluster} - // or projects/{project}/locations/{location}/clusters/{cluster}/instances/{instance} - if len(parts) < 6 || parts[0] != "projects" || parts[2] != "locations" || parts[4] != "clusters" { - return "", false - } - - project = parts[1] - region = parts[3] - cluster = parts[5] - - if len(parts) >= 8 && parts[6] == "instances" { - instance = parts[7] - } else { - return "", false - } - - tmpl, err := template.New("alloydb-connection").Parse(alloyDBConnectionMessageTemplate) - if err != nil { - // This should not happen with a static template - return fmt.Sprintf("template parsing error: %v", err), false - } - - data := struct { - Project string - Region string - Cluster string - Instance string - }{ - Project: project, - Region: region, - Cluster: cluster, - Instance: instance, - } - - var b strings.Builder - if err := tmpl.Execute(&b, data); err != nil { - return fmt.Sprintf("template execution error: %v", err), false - } - - return b.String(), true -} - // ParseParams parses the parameters for the tool. func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { return parameters.ParseParams(t.AllParams, data, claims) @@ -363,10 +286,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return true } -func (t Tool) RequiresClientAuthorization() bool { - return t.Source.UseClientAuthorization() +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/alloydbainl/alloydbainl.go b/internal/tools/alloydbainl/alloydbainl.go index e7bc102c30..8c3b468091 100644 --- a/internal/tools/alloydbainl/alloydbainl.go +++ b/internal/tools/alloydbainl/alloydbainl.go @@ -21,7 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" @@ -45,13 +44,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { PostgresPool() *pgxpool.Pool + RunSQL(context.Context, string, []any) (any, error) } -// validate compatible sources are still compatible -var _ compatibleSource = &alloydbpg.Source{} - -var compatibleSources = [...]string{alloydbpg.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -70,18 +65,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if 
!ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - numParams := len(cfg.NLConfigParameters) quotedNameParts := make([]string, 0, numParams) placeholderParts := make([]string, 0, numParams) @@ -126,7 +109,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) Config: cfg, Parameters: cfg.NLConfigParameters, Statement: stmt, - Pool: s.PostgresPool(), manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.NLConfigParameters.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -139,9 +121,7 @@ var _ tools.Tool = Tool{} type Tool struct { Config - Parameters parameters.Parameters `yaml:"parameters"` - - Pool *pgxpool.Pool + Parameters parameters.Parameters `yaml:"parameters"` Statement string manifest tools.Manifest mcpManifest tools.McpManifest @@ -151,7 +131,12 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + sliceParams := params.AsSlice() allParamValues := make([]any, len(sliceParams)+1) allParamValues[0] = fmt.Sprintf("%s", sliceParams[0]) // nl_question @@ -160,31 +145,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT allParamValues[i+2] = fmt.Sprintf("%s", param) } - results, err := t.Pool.Query(ctx, t.Statement, allParamValues...) - if err != nil { - return nil, fmt.Errorf("unable to execute query: %w. Query: %v , Values: %v. Toolbox v0.19.0+ is only compatible with AlloyDB AI NL v1.0.3+. 
Please ensure that you are using the latest AlloyDB AI NL extension", err, t.Statement, allParamValues) - } - - fields := results.FieldDescriptions() - - var out []any - for results.Next() { - v, err := results.Values() - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - vMap := make(map[string]any) - for i, f := range fields { - vMap[f.Name] = v[i] - } - out = append(out, vMap) - } - // this will catch actual query execution errors - if err := results.Err(); err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - - return out, nil + return source.RunSQL(ctx, t.Statement, allParamValues) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { @@ -203,10 +164,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution.go b/internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution.go index 25c9439986..61b90a1d11 100644 --- a/internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution.go +++ b/internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution.go @@ -57,11 +57,6 @@ type compatibleSource interface { BigQuerySession() bigqueryds.BigQuerySessionProvider } -// validate compatible sources are still compatible -var _ compatibleSource = &bigqueryds.Source{} - -var compatibleSources = [...]string{bigqueryds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -87,7 +82,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // verify the source is compatible s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } allowedDatasets := s.BigQueryAllowedDatasets() @@ -136,17 +131,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - ClientCreator: s.BigQueryClientCreator(), - Client: s.BigQueryClient(), - RestService: s.BigQueryRestService(), - IsDatasetAllowed: s.IsDatasetAllowed, - AllowedDatasets: allowedDatasets, - SessionProvider: s.BigQuerySession(), - manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -156,17 +144,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool `yaml:"useClientOAuth"` - Parameters parameters.Parameters `yaml:"parameters"` - - Client *bigqueryapi.Client - RestService *bigqueryrestapi.Service 
- ClientCreator bigqueryds.BigqueryClientCreator - IsDatasetAllowed func(projectID, datasetID string) bool - AllowedDatasets []string - SessionProvider bigqueryds.BigQuerySessionProvider - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -174,24 +154,28 @@ func (t Tool) ToConfig() tools.ToolConfig { } // Invoke runs the contribution analysis. -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() inputData, ok := paramsMap["input_data"].(string) if !ok { return nil, fmt.Errorf("unable to cast input_data parameter %s", paramsMap["input_data"]) } - bqClient := t.Client - restService := t.RestService - var err error + bqClient := source.BigQueryClient() + restService := source.BigQueryRestService() // Initialize new client if using user OAuth token - if t.UseClientOAuth { + if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { return nil, fmt.Errorf("error parsing access token: %w", err) } - bqClient, restService, err = t.ClientCreator(tokenStr, true) + bqClient, restService, err = source.BigQueryClientCreator()(tokenStr, true) if err != nil { return nil, fmt.Errorf("error creating client from OAuth access token: %w", err) } @@ -229,9 +213,9 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT var inputDataSource string trimmedUpperInputData := strings.TrimSpace(strings.ToUpper(inputData)) if strings.HasPrefix(trimmedUpperInputData, "SELECT") || strings.HasPrefix(trimmedUpperInputData, "WITH") { - if len(t.AllowedDatasets) > 0 { + if len(source.BigQueryAllowedDatasets()) > 0 { var connProps []*bigqueryapi.ConnectionProperty - session, err := t.SessionProvider(ctx) + session, err := source.BigQuerySession()(ctx) if err != nil { return nil, fmt.Errorf("failed to get BigQuery session: %w", err) } @@ -240,7 +224,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT {Key: "session_id", Value: session.ID}, } } - dryRunJob, err := bqutil.DryRunQuery(ctx, restService, t.Client.Project(), t.Client.Location, inputData, nil, connProps) + dryRunJob, err := bqutil.DryRunQuery(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, inputData, nil, connProps) if err != nil { return nil, fmt.Errorf("query validation failed: %w", err) } @@ -252,7 +236,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT queryStats := dryRunJob.Statistics.Query if queryStats != nil { for _, tableRef := range queryStats.ReferencedTables { - if !t.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) { + if !source.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) { return nil, fmt.Errorf("query in input_data accesses dataset '%s.%s', which is not in the allowed list", tableRef.ProjectId, tableRef.DatasetId) } } @@ -262,18 +246,18 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } inputDataSource = fmt.Sprintf("(%s)", inputData) } else { - if len(t.AllowedDatasets) > 
0 { + if len(source.BigQueryAllowedDatasets()) > 0 { parts := strings.Split(inputData, ".") var projectID, datasetID string switch len(parts) { case 3: // project.dataset.table projectID, datasetID = parts[0], parts[1] case 2: // dataset.table - projectID, datasetID = t.Client.Project(), parts[0] + projectID, datasetID = source.BigQueryClient().Project(), parts[0] default: return nil, fmt.Errorf("invalid table ID format for 'input_data': %q. Expected 'dataset.table' or 'project.dataset.table'", inputData) } - if !t.IsDatasetAllowed(projectID, datasetID) { + if !source.IsDatasetAllowed(projectID, datasetID) { return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", projectID, datasetID, inputData) } } @@ -292,7 +276,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT // Get session from provider if in protected mode. // Otherwise, a new session will be created by the first query. - session, err := t.SessionProvider(ctx) + session, err := source.BigQuerySession()(ctx) if err != nil { return nil, fmt.Errorf("failed to get BigQuery session: %w", err) } @@ -385,10 +369,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/bigquery/bigquerycommon/conversion_test.go b/internal/tools/bigquery/bigquerycommon/conversion_test.go new file mode 100644 index 0000000000..c735d0ebe1 --- /dev/null +++ b/internal/tools/bigquery/bigquerycommon/conversion_test.go @@ -0,0 +1,123 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bigquerycommon + +import ( + "math/big" + "reflect" + "testing" +) + +func TestNormalizeValue(t *testing.T) { + tests := []struct { + name string + input any + expected any + }{ + { + name: "big.Rat 1/3 (NUMERIC scale 9)", + input: new(big.Rat).SetFrac64(1, 3), // 0.33333333333... 
+ expected: "0.33333333333333333333333333333333333333", // FloatString(38) + }, + { + name: "big.Rat 19/2 (9.5)", + input: new(big.Rat).SetFrac64(19, 2), + expected: "9.5", + }, + { + name: "big.Rat 12341/10 (1234.1)", + input: new(big.Rat).SetFrac64(12341, 10), + expected: "1234.1", + }, + { + name: "big.Rat 10/1 (10)", + input: new(big.Rat).SetFrac64(10, 1), + expected: "10", + }, + { + name: "string", + input: "hello", + expected: "hello", + }, + { + name: "int", + input: 123, + expected: 123, + }, + { + name: "nested slice of big.Rat", + input: []any{ + new(big.Rat).SetFrac64(19, 2), + new(big.Rat).SetFrac64(1, 4), + }, + expected: []any{"9.5", "0.25"}, + }, + { + name: "nested map of big.Rat", + input: map[string]any{ + "val1": new(big.Rat).SetFrac64(19, 2), + "val2": new(big.Rat).SetFrac64(1, 2), + }, + expected: map[string]any{ + "val1": "9.5", + "val2": "0.5", + }, + }, + { + name: "complex nested structure", + input: map[string]any{ + "list": []any{ + map[string]any{ + "rat": new(big.Rat).SetFrac64(3, 2), + }, + }, + }, + expected: map[string]any{ + "list": []any{ + map[string]any{ + "rat": "1.5", + }, + }, + }, + }, + { + name: "slice of *big.Rat", + input: []*big.Rat{ + new(big.Rat).SetFrac64(19, 2), + new(big.Rat).SetFrac64(1, 4), + }, + expected: []any{"9.5", "0.25"}, + }, + { + name: "slice of strings", + input: []string{"a", "b"}, + expected: []any{"a", "b"}, + }, + { + name: "byte slice (BYTES)", + input: []byte("hello"), + expected: []byte("hello"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := NormalizeValue(tt.input) + if !reflect.DeepEqual(got, tt.expected) { + t.Errorf("NormalizeValue() = %v, want %v", got, tt.expected) + } + }) + } +} diff --git a/internal/tools/bigquery/bigquerycommon/util.go b/internal/tools/bigquery/bigquerycommon/util.go index 5486ac36ed..d9b6fd0283 100644 --- a/internal/tools/bigquery/bigquerycommon/util.go +++ b/internal/tools/bigquery/bigquerycommon/util.go @@ -17,6 +17,8 @@ package bigquerycommon import ( "context" "fmt" + "math/big" + "reflect" "sort" "strings" @@ -118,3 +120,54 @@ func InitializeDatasetParameters( return projectParam, datasetParam } + +// NormalizeValue converts BigQuery specific types to standard JSON-compatible types. +// Specifically, it handles *big.Rat (used for NUMERIC/BIGNUMERIC) by converting +// them to decimal strings with up to 38 digits of precision, trimming trailing zeros. +// It recursively handles slices (arrays) and maps (structs) using reflection. +func NormalizeValue(v any) any { + if v == nil { + return nil + } + + // Handle *big.Rat specifically. + if rat, ok := v.(*big.Rat); ok { + // Convert big.Rat to a decimal string. + // Use a precision of 38 digits (enough for BIGNUMERIC and NUMERIC) + // and trim trailing zeros to match BigQuery's behavior. + s := rat.FloatString(38) + if strings.Contains(s, ".") { + s = strings.TrimRight(s, "0") + s = strings.TrimRight(s, ".") + } + return s + } + + // Use reflection for slices and maps to handle various underlying types. + rv := reflect.ValueOf(v) + switch rv.Kind() { + case reflect.Slice, reflect.Array: + // Preserve []byte as is, so json.Marshal encodes it as Base64 string (BigQuery BYTES behavior). + if rv.Type().Elem().Kind() == reflect.Uint8 { + return v + } + newSlice := make([]any, rv.Len()) + for i := 0; i < rv.Len(); i++ { + newSlice[i] = NormalizeValue(rv.Index(i).Interface()) + } + return newSlice + case reflect.Map: + // Ensure keys are strings to produce a JSON-compatible map. 
+ if rv.Type().Key().Kind() != reflect.String { + return v + } + newMap := make(map[string]any, rv.Len()) + iter := rv.MapRange() + for iter.Next() { + newMap[iter.Key().String()] = NormalizeValue(iter.Value().Interface()) + } + return newMap + } + + return v +} diff --git a/internal/tools/bigquery/bigqueryconversationalanalytics/bigqueryconversationalanalytics.go b/internal/tools/bigquery/bigqueryconversationalanalytics/bigqueryconversationalanalytics.go index a1c5119f08..6d54f000b1 100644 --- a/internal/tools/bigquery/bigqueryconversationalanalytics/bigqueryconversationalanalytics.go +++ b/internal/tools/bigquery/bigqueryconversationalanalytics/bigqueryconversationalanalytics.go @@ -26,7 +26,6 @@ import ( bigqueryapi "cloud.google.com/go/bigquery" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -105,11 +104,6 @@ type CAPayload struct { ClientIdEnum string `json:"clientIdEnum"` } -// validate compatible sources are still compatible -var _ compatibleSource = &bigqueryds.Source{} - -var compatibleSources = [...]string{bigqueryds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -135,7 +129,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // verify the source is compatible s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } allowedDatasets := s.BigQueryAllowedDatasets() @@ -153,31 +147,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) params := parameters.Parameters{userQueryParameter, tableRefsParameter} mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil) - // Get cloud-platform token source for Gemini Data Analytics API during initialization - var bigQueryTokenSourceWithScope oauth2.TokenSource - if !s.UseClientAuthorization() { - ctx := context.Background() - ts, err := s.BigQueryTokenSourceWithScope(ctx, "https://www.googleapis.com/auth/cloud-platform") - if err != nil { - return nil, fmt.Errorf("failed to get cloud-platform token source: %w", err) - } - bigQueryTokenSourceWithScope = ts - } - // finish tool setup t := Tool{ - Config: cfg, - Project: s.BigQueryProject(), - Location: s.BigQueryLocation(), - Parameters: params, - Client: s.BigQueryClient(), - UseClientOAuth: s.UseClientAuthorization(), - TokenSource: bigQueryTokenSourceWithScope, - manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, - MaxQueryResultRows: s.GetMaxQueryResultRows(), - IsDatasetAllowed: s.IsDatasetAllowed, - AllowedDatasets: allowedDatasets, + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -187,30 +162,25 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool `yaml:"useClientOAuth"` - Parameters parameters.Parameters `yaml:"parameters"` - - Project string - Location 
string - Client *bigqueryapi.Client - TokenSource oauth2.TokenSource - manifest tools.Manifest - mcpManifest tools.McpManifest - MaxQueryResultRows int - IsDatasetAllowed func(projectID, datasetID string) bool - AllowedDatasets []string + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + var tokenStr string - var err error // Get credentials for the API call - if t.UseClientOAuth { + if source.UseClientAuthorization() { // Use client-side access token if accessToken == "" { return nil, fmt.Errorf("tool is configured for client OAuth but no token was provided in the request header: %w", util.ErrUnauthorized) @@ -220,11 +190,17 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("error parsing access token: %w", err) } } else { + // Get cloud-platform token source for Gemini Data Analytics API during initialization + tokenSource, err := source.BigQueryTokenSourceWithScope(ctx, "https://www.googleapis.com/auth/cloud-platform") + if err != nil { + return nil, fmt.Errorf("failed to get cloud-platform token source: %w", err) + } + // Use cloud-platform token source for Gemini Data Analytics API - if t.TokenSource == nil { + if tokenSource == nil { return nil, fmt.Errorf("cloud-platform token source is missing") } - token, err := t.TokenSource.Token() + token, err := tokenSource.Token() if err != nil { return nil, fmt.Errorf("failed to get token from cloud-platform token source: %w", err) } @@ -245,17 +221,17 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } } - if len(t.AllowedDatasets) > 0 { + if len(source.BigQueryAllowedDatasets()) > 0 { for _, tableRef := range tableRefs { - if !t.IsDatasetAllowed(tableRef.ProjectID, tableRef.DatasetID) { + if !source.IsDatasetAllowed(tableRef.ProjectID, tableRef.DatasetID) { return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", tableRef.ProjectID, tableRef.DatasetID, tableRef.TableID) } } } // Construct URL, headers, and payload - projectID := t.Project - location := t.Location + projectID := source.BigQueryProject() + location := source.BigQueryLocation() if location == "" { location = "us" } @@ -279,7 +255,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } // Call the streaming API - response, err := getStream(caURL, payload, headers, t.MaxQueryResultRows) + response, err := getStream(caURL, payload, headers, source.GetMaxQueryResultRows()) if err != nil { return nil, fmt.Errorf("failed to get response from conversational analytics API: %w", err) } @@ -303,8 +279,12 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if 
err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } // StreamMessage represents a single message object from the streaming API response. @@ -580,6 +560,6 @@ func appendMessage(messages []map[string]any, newMessage map[string]any) []map[s return append(messages, newMessage) } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/bigquery/bigqueryexecutesql/bigqueryexecutesql.go b/internal/tools/bigquery/bigqueryexecutesql/bigqueryexecutesql.go index c1bbeadaf8..a70d4d342d 100644 --- a/internal/tools/bigquery/bigqueryexecutesql/bigqueryexecutesql.go +++ b/internal/tools/bigquery/bigqueryexecutesql/bigqueryexecutesql.go @@ -60,11 +60,6 @@ type compatibleSource interface { BigQueryAllowedDatasets() []string } -// validate compatible sources are still compatible -var _ compatibleSource = &bigqueryds.Source{} - -var compatibleSources = [...]string{bigqueryds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -90,7 +85,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // verify the source is compatible s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } var sqlDescriptionBuilder strings.Builder @@ -136,18 +131,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - ClientCreator: s.BigQueryClientCreator(), - Client: s.BigQueryClient(), - RestService: s.BigQueryRestService(), - WriteMode: s.BigQueryWriteMode(), - SessionProvider: s.BigQuerySession(), - IsDatasetAllowed: s.IsDatasetAllowed, - AllowedDatasets: allowedDatasets, - manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -157,25 +144,21 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool `yaml:"useClientOAuth"` - Parameters parameters.Parameters `yaml:"parameters"` - - Client *bigqueryapi.Client - RestService *bigqueryrestapi.Service - WriteMode string - SessionProvider bigqueryds.BigQuerySessionProvider - ClientCreator bigqueryds.BigqueryClientCreator - IsDatasetAllowed func(projectID, datasetID string) bool - AllowedDatasets []string - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, 
err + } + paramsMap := params.AsMap() sql, ok := paramsMap["sql"].(string) if !ok { @@ -186,17 +169,16 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("unable to cast dry_run parameter %s", paramsMap["dry_run"]) } - bqClient := t.Client - restService := t.RestService + bqClient := source.BigQueryClient() + restService := source.BigQueryRestService() - var err error // Initialize new client if using user OAuth token - if t.UseClientOAuth { + if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { return nil, fmt.Errorf("error parsing access token: %w", err) } - bqClient, restService, err = t.ClientCreator(tokenStr, true) + bqClient, restService, err = source.BigQueryClientCreator()(tokenStr, true) if err != nil { return nil, fmt.Errorf("error creating client from OAuth access token: %w", err) } @@ -204,8 +186,8 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT var connProps []*bigqueryapi.ConnectionProperty var session *bigqueryds.Session - if t.WriteMode == bigqueryds.WriteModeProtected { - session, err = t.SessionProvider(ctx) + if source.BigQueryWriteMode() == bigqueryds.WriteModeProtected { + session, err = source.BigQuerySession()(ctx) if err != nil { return nil, fmt.Errorf("failed to get BigQuery session for protected mode: %w", err) } @@ -221,7 +203,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT statementType := dryRunJob.Statistics.Query.StatementType - switch t.WriteMode { + switch source.BigQueryWriteMode() { case bigqueryds.WriteModeBlocked: if statementType != "SELECT" { return nil, fmt.Errorf("write mode is 'blocked', only SELECT statements are allowed") @@ -235,7 +217,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } } - if len(t.AllowedDatasets) > 0 { + if len(source.BigQueryAllowedDatasets()) > 0 { switch statementType { case "CREATE_SCHEMA", "DROP_SCHEMA", "ALTER_SCHEMA": return nil, fmt.Errorf("dataset-level operations like '%s' are not allowed when dataset restrictions are in place", statementType) @@ -270,7 +252,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } else if statementType != "SELECT" { // If dry run yields no tables, fall back to the parser for non-SELECT statements // to catch unsafe operations like EXECUTE IMMEDIATE. - parsedTables, parseErr := bqutil.TableParser(sql, t.Client.Project()) + parsedTables, parseErr := bqutil.TableParser(sql, source.BigQueryClient().Project()) if parseErr != nil { // If parsing fails (e.g., EXECUTE IMMEDIATE), we cannot guarantee safety, so we must fail. 
return nil, fmt.Errorf("could not parse tables from query to validate against allowed datasets: %w", parseErr) @@ -282,7 +264,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT parts := strings.Split(tableID, ".") if len(parts) == 3 { projectID, datasetID := parts[0], parts[1] - if !t.IsDatasetAllowed(projectID, datasetID) { + if !source.IsDatasetAllowed(projectID, datasetID) { return nil, fmt.Errorf("query accesses dataset '%s.%s', which is not in the allowed list", projectID, datasetID) } } @@ -337,7 +319,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT schema := it.Schema row := orderedmap.Row{} for i, field := range schema { - row.Add(field.Name, val[i]) + row.Add(field.Name, bqutil.NormalizeValue(val[i])) } out = append(out, row) } @@ -374,10 +356,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/bigquery/bigqueryforecast/bigqueryforecast.go b/internal/tools/bigquery/bigqueryforecast/bigqueryforecast.go index 1e4262b995..034bce3501 100644 --- a/internal/tools/bigquery/bigqueryforecast/bigqueryforecast.go +++ b/internal/tools/bigquery/bigqueryforecast/bigqueryforecast.go @@ -57,11 +57,6 @@ type compatibleSource interface { BigQuerySession() bigqueryds.BigQuerySessionProvider } -// validate compatible sources are still compatible -var _ compatibleSource = &bigqueryds.Source{} - -var compatibleSources = [...]string{bigqueryds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -87,7 +82,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // verify the source is compatible s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } allowedDatasets := s.BigQueryAllowedDatasets() @@ -116,17 +111,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - ClientCreator: s.BigQueryClientCreator(), - Client: s.BigQueryClient(), - RestService: s.BigQueryRestService(), - IsDatasetAllowed: s.IsDatasetAllowed, - SessionProvider: s.BigQuerySession(), - AllowedDatasets: allowedDatasets, - manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -136,24 +124,21 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool 
`yaml:"useClientOAuth"` - Parameters parameters.Parameters `yaml:"parameters"` - - Client *bigqueryapi.Client - RestService *bigqueryrestapi.Service - ClientCreator bigqueryds.BigqueryClientCreator - IsDatasetAllowed func(projectID, datasetID string) bool - AllowedDatasets []string - SessionProvider bigqueryds.BigQuerySessionProvider - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() historyData, ok := paramsMap["history_data"].(string) if !ok { @@ -188,17 +173,16 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } } - bqClient := t.Client - restService := t.RestService - var err error + bqClient := source.BigQueryClient() + restService := source.BigQueryRestService() // Initialize new client if using user OAuth token - if t.UseClientOAuth { + if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { return nil, fmt.Errorf("error parsing access token: %w", err) } - bqClient, restService, err = t.ClientCreator(tokenStr, false) + bqClient, restService, err = source.BigQueryClientCreator()(tokenStr, false) if err != nil { return nil, fmt.Errorf("error creating client from OAuth access token: %w", err) } @@ -207,9 +191,9 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT var historyDataSource string trimmedUpperHistoryData := strings.TrimSpace(strings.ToUpper(historyData)) if strings.HasPrefix(trimmedUpperHistoryData, "SELECT") || strings.HasPrefix(trimmedUpperHistoryData, "WITH") { - if len(t.AllowedDatasets) > 0 { + if len(source.BigQueryAllowedDatasets()) > 0 { var connProps []*bigqueryapi.ConnectionProperty - session, err := t.SessionProvider(ctx) + session, err := source.BigQuerySession()(ctx) if err != nil { return nil, fmt.Errorf("failed to get BigQuery session: %w", err) } @@ -218,7 +202,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT {Key: "session_id", Value: session.ID}, } } - dryRunJob, err := bqutil.DryRunQuery(ctx, restService, t.Client.Project(), t.Client.Location, historyData, nil, connProps) + dryRunJob, err := bqutil.DryRunQuery(ctx, restService, source.BigQueryClient().Project(), source.BigQueryClient().Location, historyData, nil, connProps) if err != nil { return nil, fmt.Errorf("query validation failed: %w", err) } @@ -230,7 +214,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT queryStats := dryRunJob.Statistics.Query if queryStats != nil { for _, tableRef := range queryStats.ReferencedTables { - if !t.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) { + if !source.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) { return nil, fmt.Errorf("query in history_data accesses dataset '%s.%s', which is not in the allowed list", tableRef.ProjectId, tableRef.DatasetId) } } @@ -240,7 +224,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, 
accessT } historyDataSource = fmt.Sprintf("(%s)", historyData) } else { - if len(t.AllowedDatasets) > 0 { + if len(source.BigQueryAllowedDatasets()) > 0 { parts := strings.Split(historyData, ".") var projectID, datasetID string @@ -249,13 +233,13 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT projectID = parts[0] datasetID = parts[1] case 2: // dataset.table - projectID = t.Client.Project() + projectID = source.BigQueryClient().Project() datasetID = parts[0] default: return nil, fmt.Errorf("invalid table ID format for 'history_data': %q. Expected 'dataset.table' or 'project.dataset.table'", historyData) } - if !t.IsDatasetAllowed(projectID, datasetID) { + if !source.IsDatasetAllowed(projectID, datasetID) { return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", projectID, datasetID, historyData) } } @@ -279,7 +263,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT // JobStatistics.QueryStatistics.StatementType query := bqClient.Query(sql) query.Location = bqClient.Location - session, err := t.SessionProvider(ctx) + session, err := source.BigQuerySession()(ctx) if err != nil { return nil, fmt.Errorf("failed to get BigQuery session: %w", err) } @@ -349,10 +333,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/bigquery/bigquerygetdatasetinfo/bigquerygetdatasetinfo.go b/internal/tools/bigquery/bigquerygetdatasetinfo/bigquerygetdatasetinfo.go index e850df70f3..b083c49e2c 100644 --- a/internal/tools/bigquery/bigquerygetdatasetinfo/bigquerygetdatasetinfo.go +++ b/internal/tools/bigquery/bigquerygetdatasetinfo/bigquerygetdatasetinfo.go @@ -54,11 +54,6 @@ type compatibleSource interface { BigQueryAllowedDatasets() []string } -// validate compatible sources are still compatible -var _ compatibleSource = &bigqueryds.Source{} - -var compatibleSources = [...]string{bigqueryds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -84,7 +79,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // verify the source is compatible s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } defaultProjectID := s.BigQueryProject() @@ -104,14 +99,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - ClientCreator: s.BigQueryClientCreator(), - Client: s.BigQueryClient(), - IsDatasetAllowed: s.IsDatasetAllowed, - manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: 
cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -121,22 +112,21 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool `yaml:"useClientOAuth"` - Parameters parameters.Parameters `yaml:"parameters"` - - Client *bigqueryapi.Client - ClientCreator bigqueryds.BigqueryClientCreator - Statement string - IsDatasetAllowed func(projectID, datasetID string) bool - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + mapParams := params.AsMap() projectId, ok := mapParams[projectKey].(string) if !ok { @@ -148,22 +138,21 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", datasetKey) } - bqClient := t.Client - var err error + bqClient := source.BigQueryClient() // Initialize new client if using user OAuth token - if t.UseClientOAuth { + if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { return nil, fmt.Errorf("error parsing access token: %w", err) } - bqClient, _, err = t.ClientCreator(tokenStr, false) + bqClient, _, err = source.BigQueryClientCreator()(tokenStr, false) if err != nil { return nil, fmt.Errorf("error creating client from OAuth access token: %w", err) } } - if !t.IsDatasetAllowed(projectId, datasetId) { + if !source.IsDatasetAllowed(projectId, datasetId) { return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId) } @@ -193,10 +182,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/bigquery/bigquerygettableinfo/bigquerygettableinfo.go b/internal/tools/bigquery/bigquerygettableinfo/bigquerygettableinfo.go index d903a4b9bf..b896244ed0 100644 --- a/internal/tools/bigquery/bigquerygettableinfo/bigquerygettableinfo.go +++ b/internal/tools/bigquery/bigquerygettableinfo/bigquerygettableinfo.go @@ -55,11 +55,6 @@ type compatibleSource interface { BigQueryAllowedDatasets() []string } -// validate compatible sources are still compatible -var _ compatibleSource = 
&bigqueryds.Source{} - -var compatibleSources = [...]string{bigqueryds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -85,7 +80,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // verify the source is compatible s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } defaultProjectID := s.BigQueryProject() @@ -108,14 +103,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - ClientCreator: s.BigQueryClientCreator(), - Client: s.BigQueryClient(), - IsDatasetAllowed: s.IsDatasetAllowed, - manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -125,22 +116,21 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool `yaml:"useClientOAuth"` - Parameters parameters.Parameters `yaml:"parameters"` - - Client *bigqueryapi.Client - ClientCreator bigqueryds.BigqueryClientCreator - Statement string - IsDatasetAllowed func(projectID, datasetID string) bool - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + mapParams := params.AsMap() projectId, ok := mapParams[projectKey].(string) if !ok { @@ -157,20 +147,19 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", tableKey) } - if !t.IsDatasetAllowed(projectId, datasetId) { + if !source.IsDatasetAllowed(projectId, datasetId) { return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId) } - bqClient := t.Client + bqClient := source.BigQueryClient() - var err error // Initialize new client if using user OAuth token - if t.UseClientOAuth { + if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { return nil, fmt.Errorf("error parsing access token: %w", err) } - bqClient, _, err = t.ClientCreator(tokenStr, false) + bqClient, _, err = source.BigQueryClientCreator()(tokenStr, false) if err != nil { return nil, fmt.Errorf("error creating client from OAuth access token: %w", err) } @@ -203,10 +192,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) 
RequiresClientAuthorization() bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/bigquery/bigquerylistdatasetids/bigquerylistdatasetids.go b/internal/tools/bigquery/bigquerylistdatasetids/bigquerylistdatasetids.go index 5b77282d4f..dafe9b2246 100644 --- a/internal/tools/bigquery/bigquerylistdatasetids/bigquerylistdatasetids.go +++ b/internal/tools/bigquery/bigquerylistdatasetids/bigquerylistdatasetids.go @@ -52,11 +52,6 @@ type compatibleSource interface { BigQueryAllowedDatasets() []string } -// validate compatible sources are still compatible -var _ compatibleSource = &bigqueryds.Source{} - -var compatibleSources = [...]string{bigqueryds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -82,7 +77,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // verify the source is compatible s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } var projectParameter parameters.Parameter @@ -103,14 +98,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - ClientCreator: s.BigQueryClientCreator(), - Client: s.BigQueryClient(), - AllowedDatasets: allowedDatasets, - manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -120,24 +111,23 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool `yaml:"useClientOAuth"` - Parameters parameters.Parameters `yaml:"parameters"` - - Client *bigqueryapi.Client - ClientCreator bigqueryds.BigqueryClientCreator - Statement string - AllowedDatasets []string - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - if len(t.AllowedDatasets) > 0 { - return t.AllowedDatasets, nil +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + + if len(source.BigQueryAllowedDatasets()) > 0 { + return source.BigQueryAllowedDatasets(), nil } mapParams := params.AsMap() projectId, ok := 
mapParams[projectKey].(string) @@ -145,14 +135,14 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", projectKey) } - bqClient := t.Client + bqClient := source.BigQueryClient() // Initialize new client if using user OAuth token - if t.UseClientOAuth { + if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { return nil, fmt.Errorf("error parsing access token: %w", err) } - bqClient, _, err = t.ClientCreator(tokenStr, false) + bqClient, _, err = source.BigQueryClientCreator()(tokenStr, false) if err != nil { return nil, fmt.Errorf("error creating client from OAuth access token: %w", err) } @@ -197,10 +187,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/bigquery/bigquerylisttableids/bigquerylisttableids.go b/internal/tools/bigquery/bigquerylisttableids/bigquerylisttableids.go index eed9ea10e9..11987c6dac 100644 --- a/internal/tools/bigquery/bigquerylisttableids/bigquerylisttableids.go +++ b/internal/tools/bigquery/bigquerylisttableids/bigquerylisttableids.go @@ -55,11 +55,6 @@ type compatibleSource interface { BigQueryAllowedDatasets() []string } -// validate compatible sources are still compatible -var _ compatibleSource = &bigqueryds.Source{} - -var compatibleSources = [...]string{bigqueryds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -85,7 +80,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // verify the source is compatible s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } defaultProjectID := s.BigQueryProject() @@ -107,14 +102,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - ClientCreator: s.BigQueryClientCreator(), - Client: s.BigQueryClient(), - IsDatasetAllowed: s.IsDatasetAllowed, - manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -124,22 +115,21 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool `yaml:"useClientOAuth"` - Parameters parameters.Parameters `yaml:"parameters"` - - Client *bigqueryapi.Client - ClientCreator bigqueryds.BigqueryClientCreator - IsDatasetAllowed 
func(projectID, datasetID string) bool - Statement string - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + mapParams := params.AsMap() projectId, ok := mapParams[projectKey].(string) if !ok { @@ -151,18 +141,18 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", datasetKey) } - if !t.IsDatasetAllowed(projectId, datasetId) { + if !source.IsDatasetAllowed(projectId, datasetId) { return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId) } - bqClient := t.Client + bqClient := source.BigQueryClient() // Initialize new client if using user OAuth token - if t.UseClientOAuth { + if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { return nil, fmt.Errorf("error parsing access token: %w", err) } - bqClient, _, err = t.ClientCreator(tokenStr, false) + bqClient, _, err = source.BigQueryClientCreator()(tokenStr, false) if err != nil { return nil, fmt.Errorf("error creating client from OAuth access token: %w", err) } @@ -208,10 +198,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/bigquery/bigquerysearchcatalog/bigquerysearchcatalog.go b/internal/tools/bigquery/bigquerysearchcatalog/bigquerysearchcatalog.go index baeda85773..e134e9f298 100644 --- a/internal/tools/bigquery/bigquerysearchcatalog/bigquerysearchcatalog.go +++ b/internal/tools/bigquery/bigquerysearchcatalog/bigquerysearchcatalog.go @@ -51,11 +51,6 @@ type compatibleSource interface { UseClientAuthorization() bool } -// validate compatible sources are still compatible -var _ compatibleSource = &bigqueryds.Source{} - -var compatibleSources = [...]string{bigqueryds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -72,20 +67,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // Initialize the search configuration with the provided sources - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - // verify the source is 
compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - - // Get the Dataplex client using the method from the source - makeCatalogClient := s.MakeDataplexCatalogClient() - prompt := parameters.NewStringParameter("prompt", "Prompt representing search intention. Do not rewrite the prompt.") datasetIds := parameters.NewArrayParameterWithDefault("datasetIds", []any{}, "Array of dataset IDs.", parameters.NewStringParameter("datasetId", "The IDs of the bigquery dataset.")) projectIds := parameters.NewArrayParameterWithDefault("projectIds", []any{}, "Array of project IDs.", parameters.NewStringParameter("projectId", "The IDs of the bigquery project.")) @@ -100,11 +81,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, params, nil) t := Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - MakeCatalogClient: makeCatalogClient, - ProjectID: s.BigQueryProject(), + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -117,12 +95,9 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) type Tool struct { Config - Parameters parameters.Parameters - UseClientOAuth bool - MakeCatalogClient func() (*dataplexapi.CatalogClient, bigqueryds.DataplexClientCreator, error) - ProjectID string - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -133,8 +108,12 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } func constructSearchQueryHelper(predicate string, operator string, items []string) string { @@ -206,7 +185,12 @@ func ExtractType(resourceString string) string { return typeMap[resourceString[lastIndex+1:]] } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() pageSize := int32(paramsMap["pageSize"].(int)) prompt, _ := paramsMap["prompt"].(string) @@ -228,14 +212,14 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT req := &dataplexpb.SearchEntriesRequest{ Query: fmt.Sprintf("%s %s", prompt, constructSearchQuery(projectIds, datasetIds, types)), - Name: fmt.Sprintf("projects/%s/locations/global", t.ProjectID), + Name: fmt.Sprintf("projects/%s/locations/global", source.BigQueryProject()), PageSize: pageSize, SemanticSearch: true, } - catalogClient, dataplexClientCreator, _ := t.MakeCatalogClient() + catalogClient, dataplexClientCreator, _ := 
source.MakeDataplexCatalogClient()() - if t.UseClientOAuth { + if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { return nil, fmt.Errorf("error parsing access token: %w", err) @@ -248,7 +232,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT it := catalogClient.SearchEntries(ctx, req) if it == nil { - return nil, fmt.Errorf("failed to create search entries iterator for project %q", t.ProjectID) + return nil, fmt.Errorf("failed to create search entries iterator for project %q", source.BigQueryProject()) } var results []Response @@ -288,6 +272,6 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/bigquery/bigquerysql/bigquerysql.go b/internal/tools/bigquery/bigquerysql/bigquerysql.go index 2a4f78249e..fa02f658eb 100644 --- a/internal/tools/bigquery/bigquerysql/bigquerysql.go +++ b/internal/tools/bigquery/bigquerysql/bigquerysql.go @@ -57,11 +57,6 @@ type compatibleSource interface { UseClientAuthorization() bool } -// validate compatible sources are still compatible -var _ compatibleSource = &bigqueryds.Source{} - -var compatibleSources = [...]string{bigqueryds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -81,18 +76,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters) if err != nil { return nil, err @@ -102,15 +85,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ - Config: cfg, - AllParams: allParameters, - UseClientOAuth: s.UseClientAuthorization(), - Client: s.BigQueryClient(), - RestService: s.BigQueryRestService(), - SessionProvider: s.BigQuerySession(), - ClientCreator: s.BigQueryClientCreator(), - manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + AllParams: allParameters, + manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -120,22 +98,21 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool `yaml:"useClientOAuth"` - AllParams parameters.Parameters `yaml:"allParams"` - - Client *bigqueryapi.Client - RestService *bigqueryrestapi.Service - SessionProvider bigqueryds.BigQuerySessionProvider - ClientCreator bigqueryds.BigqueryClientCreator - manifest tools.Manifest - mcpManifest tools.McpManifest + AllParams parameters.Parameters `yaml:"allParams"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken 
tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + highLevelParams := make([]bigqueryapi.QueryParameter, 0, len(t.Parameters)) lowLevelParams := make([]*bigqueryrestapi.QueryParameter, 0, len(t.Parameters)) @@ -212,16 +189,16 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT lowLevelParams = append(lowLevelParams, lowLevelParam) } - bqClient := t.Client - restService := t.RestService + bqClient := source.BigQueryClient() + restService := source.BigQueryRestService() // Initialize new client if using user OAuth token - if t.UseClientOAuth { + if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { return nil, fmt.Errorf("error parsing access token: %w", err) } - bqClient, restService, err = t.ClientCreator(tokenStr, true) + bqClient, restService, err = source.BigQueryClientCreator()(tokenStr, true) if err != nil { return nil, fmt.Errorf("error creating client from OAuth access token: %w", err) } @@ -232,8 +209,8 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT query.Location = bqClient.Location connProps := []*bigqueryapi.ConnectionProperty{} - if t.SessionProvider != nil { - session, err := t.SessionProvider(ctx) + if source.BigQuerySession() != nil { + session, err := source.BigQuerySession()(ctx) if err != nil { return nil, fmt.Errorf("failed to get BigQuery session: %w", err) } @@ -274,7 +251,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } vMap := make(map[string]any) for key, value := range row { - vMap[key] = value + vMap[key] = bqutil.NormalizeValue(value) } out = append(out, vMap) } @@ -311,10 +288,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/bigtable/bigtable.go b/internal/tools/bigtable/bigtable.go index 52b6be1de0..fe93630f95 100644 --- a/internal/tools/bigtable/bigtable.go +++ b/internal/tools/bigtable/bigtable.go @@ -21,7 +21,6 @@ import ( "cloud.google.com/go/bigtable" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - bigtabledb "github.com/googleapis/genai-toolbox/internal/sources/bigtable" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -46,11 +45,6 @@ type compatibleSource interface { BigtableClient() *bigtable.Client } -// validate compatible sources are still compatible -var _ compatibleSource = &bigtabledb.Source{} - -var compatibleSources = [...]string{bigtabledb.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" 
validate:"required"` @@ -70,18 +64,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters) if err != nil { return nil, err @@ -93,7 +75,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, AllParams: allParameters, - Client: s.BigtableClient(), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -105,9 +86,7 @@ var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters `yaml:"allParams"` - - Client *bigtable.Client + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } @@ -155,7 +134,12 @@ func getMapParamsType(tparams parameters.Parameters, params parameters.ParamValu return btParamTypes, nil } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { @@ -172,7 +156,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("fail to get map params: %w", err) } - ps, err := t.Client.PrepareStatement( + ps, err := source.BigtableClient().PrepareStatement( ctx, newStatement, mapParamsType, @@ -224,10 +208,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cassandra/cassandracql/cassandracql.go b/internal/tools/cassandra/cassandracql/cassandracql.go index bfcf4d883c..a05d0815ba 100644 --- a/internal/tools/cassandra/cassandracql/cassandracql.go +++ b/internal/tools/cassandra/cassandracql/cassandracql.go @@ -21,7 +21,6 @@ import ( gocql "github.com/apache/cassandra-gocql-driver/v2" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/cassandra" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -46,10 +45,6 @@ type compatibleSource interface { CassandraSession() *gocql.Session } -var _ compatibleSource = &cassandra.Source{} - -var compatibleSources = 
[...]string{cassandra.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -61,20 +56,15 @@ type Config struct { TemplateParameters parameters.Parameters `yaml:"templateParameters"` } +var _ tools.ToolConfig = Config{} + +// ToolConfigKind implements tools.ToolConfig. +func (c Config) ToolConfigKind() string { + return kind +} + // Initialize implements tools.ToolConfig. func (c Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[c.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", c.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters, paramManifest, err := parameters.ProcessParameters(c.TemplateParameters, c.Parameters) if err != nil { return nil, err @@ -85,25 +75,17 @@ func (c Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { t := Tool{ Config: c, AllParams: allParameters, - Session: s.CassandraSession(), manifest: tools.Manifest{Description: c.Description, Parameters: paramManifest, AuthRequired: c.AuthRequired}, mcpManifest: mcpManifest, } return t, nil } -// ToolConfigKind implements tools.ToolConfig. -func (c Config) ToolConfigKind() string { - return kind -} - -var _ tools.ToolConfig = Config{} +var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters `yaml:"allParams"` - - Session *gocql.Session + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } @@ -113,8 +95,8 @@ func (t Tool) ToConfig() tools.ToolConfig { } // RequiresClientAuthorization implements tools.Tool. -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } // Authorized implements tools.Tool. @@ -123,7 +105,12 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { } // Invoke implements tools.Tool. 
-func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { @@ -135,7 +122,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("unable to extract standard params %w", err) } sliceParams := newParams.AsSlice() - iter := t.Session.Query(newStatement, sliceParams...).IterContext(ctx) + iter := source.CassandraSession().Query(newStatement, sliceParams...).IterContext(ctx) // Create a slice to store the out var out []map[string]interface{} @@ -170,8 +157,6 @@ func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) return parameters.ParseParams(t.AllParams, data, claims) } -var _ tools.Tool = Tool{} - -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/clickhouse/clickhouseexecutesql/clickhouseexecutesql.go b/internal/tools/clickhouse/clickhouseexecutesql/clickhouseexecutesql.go index d42c2c76d0..826d20d482 100644 --- a/internal/tools/clickhouse/clickhouseexecutesql/clickhouseexecutesql.go +++ b/internal/tools/clickhouse/clickhouseexecutesql/clickhouseexecutesql.go @@ -25,12 +25,6 @@ import ( "github.com/googleapis/genai-toolbox/internal/util/parameters" ) -type compatibleSource interface { - ClickHousePool() *sql.DB -} - -var compatibleSources = []string{"clickhouse"} - const executeSQLKind string = "clickhouse-execute-sql" func init() { @@ -47,6 +41,10 @@ func newExecuteSQLConfig(ctx context.Context, name string, decoder *yaml.Decoder return actual, nil } +type compatibleSource interface { + ClickHousePool() *sql.DB +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -62,16 +60,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", executeSQLKind, compatibleSources) - } - sqlParameter := parameters.NewStringParameter("sql", "The SQL statement to execute.") params := parameters.Parameters{sqlParameter} @@ -80,7 +68,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, Parameters: params, - Pool: s.ClickHousePool(), manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -91,9 +78,7 @@ var _ tools.Tool = Tool{} type Tool struct { Config - Parameters parameters.Parameters `yaml:"parameters"` - - Pool *sql.DB + Parameters parameters.Parameters `yaml:"parameters"` manifest tools.Manifest mcpManifest tools.McpManifest } @@ -102,14 +87,19 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) 
Invoke(ctx context.Context, params parameters.ParamValues, token tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() sql, ok := paramsMap["sql"].(string) if !ok { return nil, fmt.Errorf("unable to cast sql parameter %s", paramsMap["sql"]) } - results, err := t.Pool.QueryContext(ctx, sql) + results, err := source.ClickHousePool().QueryContext(ctx, sql) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -183,10 +173,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/clickhouse/clickhouselistdatabases/clickhouselistdatabases.go b/internal/tools/clickhouse/clickhouselistdatabases/clickhouselistdatabases.go index 15fe368e8b..e6df548907 100644 --- a/internal/tools/clickhouse/clickhouselistdatabases/clickhouselistdatabases.go +++ b/internal/tools/clickhouse/clickhouselistdatabases/clickhouselistdatabases.go @@ -25,12 +25,6 @@ import ( "github.com/googleapis/genai-toolbox/internal/util/parameters" ) -type compatibleSource interface { - ClickHousePool() *sql.DB -} - -var compatibleSources = []string{"clickhouse"} - const listDatabasesKind string = "clickhouse-list-databases" func init() { @@ -47,6 +41,10 @@ func newListDatabasesConfig(ctx context.Context, name string, decoder *yaml.Deco return actual, nil } +type compatibleSource interface { + ClickHousePool() *sql.DB +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -63,23 +61,12 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", listDatabasesKind, compatibleSources) - } - allParameters, paramManifest, _ := parameters.ProcessParameters(nil, cfg.Parameters) mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) t := Tool{ Config: cfg, AllParams: allParameters, - Pool: s.ClickHousePool(), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -90,9 +77,7 @@ var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters `yaml:"allParams"` - - Pool *sql.DB + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } @@ -101,11 +86,16 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, token tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx 
context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + // Query to list all databases query := "SHOW DATABASES" - results, err := t.Pool.QueryContext(ctx, query) + results, err := source.ClickHousePool().QueryContext(ctx, query) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -146,10 +136,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/clickhouse/clickhouselistdatabases/clickhouselistdatabases_test.go b/internal/tools/clickhouse/clickhouselistdatabases/clickhouselistdatabases_test.go index 768b942b41..ca6d9b21b7 100644 --- a/internal/tools/clickhouse/clickhouselistdatabases/clickhouselistdatabases_test.go +++ b/internal/tools/clickhouse/clickhouselistdatabases/clickhouselistdatabases_test.go @@ -20,7 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/google/go-cmp/cmp" "github.com/googleapis/genai-toolbox/internal/server" - "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/testutils" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -32,21 +31,6 @@ func TestListDatabasesConfigToolConfigKind(t *testing.T) { } } -func TestListDatabasesConfigInitializeMissingSource(t *testing.T) { - cfg := Config{ - Name: "test-list-databases", - Kind: listDatabasesKind, - Source: "missing-source", - Description: "Test list databases tool", - } - - srcs := map[string]sources.Source{} - _, err := cfg.Initialize(srcs) - if err == nil { - t.Error("expected error for missing source") - } -} - func TestParseFromYamlClickHouseListDatabases(t *testing.T) { ctx, err := testutils.ContextWithNewLogger() if err != nil { diff --git a/internal/tools/clickhouse/clickhouselisttables/clickhouselisttables.go b/internal/tools/clickhouse/clickhouselisttables/clickhouselisttables.go index ca98bbea30..e882a88ea5 100644 --- a/internal/tools/clickhouse/clickhouselisttables/clickhouselisttables.go +++ b/internal/tools/clickhouse/clickhouselisttables/clickhouselisttables.go @@ -25,12 +25,6 @@ import ( "github.com/googleapis/genai-toolbox/internal/util/parameters" ) -type compatibleSource interface { - ClickHousePool() *sql.DB -} - -var compatibleSources = []string{"clickhouse"} - const listTablesKind string = "clickhouse-list-tables" const databaseKey string = "database" @@ -48,6 +42,10 @@ func newListTablesConfig(ctx context.Context, name string, decoder *yaml.Decoder return actual, nil } +type compatibleSource interface { + ClickHousePool() *sql.DB +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -64,16 +62,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - s, ok 
:= rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", listTablesKind, compatibleSources) - } - databaseParameter := parameters.NewStringParameter(databaseKey, "The database to list tables from.") params := parameters.Parameters{databaseParameter} @@ -83,7 +71,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, AllParams: allParameters, - Pool: s.ClickHousePool(), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -94,9 +81,7 @@ var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters `yaml:"allParams"` - - Pool *sql.DB + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } @@ -105,7 +90,12 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, token tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + mapParams := params.AsMap() database, ok := mapParams[databaseKey].(string) if !ok { @@ -115,13 +105,13 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, token t // Query to list all tables in the specified database query := fmt.Sprintf("SHOW TABLES FROM %s", database) - results, err := t.Pool.QueryContext(ctx, query) + results, err := source.ClickHousePool().QueryContext(ctx, query) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } defer results.Close() - var tables []map[string]any + tables := []map[string]any{} for results.Next() { var tableName string err := results.Scan(&tableName) @@ -157,10 +147,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/clickhouse/clickhouselisttables/clickhouselisttables_test.go b/internal/tools/clickhouse/clickhouselisttables/clickhouselisttables_test.go index 2705ded3fc..4500dac099 100644 --- a/internal/tools/clickhouse/clickhouselisttables/clickhouselisttables_test.go +++ b/internal/tools/clickhouse/clickhouselisttables/clickhouselisttables_test.go @@ -20,7 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/google/go-cmp/cmp" "github.com/googleapis/genai-toolbox/internal/server" - "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/testutils" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -32,21 +31,6 @@ func TestListTablesConfigToolConfigKind(t *testing.T) { } } -func TestListTablesConfigInitializeMissingSource(t *testing.T) { - cfg := Config{ - Name: "test-list-tables", - Kind: listTablesKind, - Source: "missing-source", - Description: "Test list tables tool", - } - - srcs := 
map[string]sources.Source{} - _, err := cfg.Initialize(srcs) - if err == nil { - t.Error("expected error for missing source") - } -} - func TestParseFromYamlClickHouseListTables(t *testing.T) { ctx, err := testutils.ContextWithNewLogger() if err != nil { diff --git a/internal/tools/clickhouse/clickhousesql/clickhousesql.go b/internal/tools/clickhouse/clickhousesql/clickhousesql.go index 969eadef66..83a2f1ee9d 100644 --- a/internal/tools/clickhouse/clickhousesql/clickhousesql.go +++ b/internal/tools/clickhouse/clickhousesql/clickhousesql.go @@ -25,21 +25,15 @@ import ( "github.com/googleapis/genai-toolbox/internal/util/parameters" ) -type compatibleSource interface { - ClickHousePool() *sql.DB -} - -var compatibleSources = []string{"clickhouse"} - const sqlKind string = "clickhouse-sql" func init() { - if !tools.Register(sqlKind, newSQLConfig) { + if !tools.Register(sqlKind, newConfig) { panic(fmt.Sprintf("tool kind %q already registered", sqlKind)) } } -func newSQLConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { actual := Config{Name: name} if err := decoder.DecodeContext(ctx, &actual); err != nil { return nil, err @@ -47,6 +41,10 @@ func newSQLConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tool return actual, nil } +type compatibleSource interface { + ClickHousePool() *sql.DB +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -65,23 +63,12 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", sqlKind, compatibleSources) - } - allParameters, paramManifest, _ := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters) mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) t := Tool{ Config: cfg, AllParams: allParameters, - Pool: s.ClickHousePool(), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -93,7 +80,6 @@ var _ tools.Tool = Tool{} type Tool struct { Config AllParams parameters.Parameters `yaml:"allParams"` - Pool *sql.DB manifest tools.Manifest mcpManifest tools.McpManifest } @@ -102,7 +88,12 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, token tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, token tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { @@ -115,7 +106,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, token t } sliceParams := newParams.AsSlice() - results, err := t.Pool.QueryContext(ctx, newStatement, sliceParams...) 
+ results, err := source.ClickHousePool().QueryContext(ctx, newStatement, sliceParams...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -191,10 +182,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/clickhouse/clickhousesql/clickhousesql_test.go b/internal/tools/clickhouse/clickhousesql/clickhousesql_test.go index 4c127bd734..3c50305e28 100644 --- a/internal/tools/clickhouse/clickhousesql/clickhousesql_test.go +++ b/internal/tools/clickhouse/clickhousesql/clickhousesql_test.go @@ -142,66 +142,6 @@ func TestSQLConfigInitializeValidSource(t *testing.T) { } } -func TestSQLConfigInitializeMissingSource(t *testing.T) { - config := Config{ - Name: "test-tool", - Kind: sqlKind, - Source: "missing-source", - Description: "Test tool", - Statement: "SELECT 1", - Parameters: parameters.Parameters{}, - } - - sources := map[string]sources.Source{} - - _, err := config.Initialize(sources) - if err == nil { - t.Fatal("Expected error for missing source, got nil") - } - - expectedErr := `no source named "missing-source" configured` - if err.Error() != expectedErr { - t.Errorf("Expected error %q, got %q", expectedErr, err.Error()) - } -} - -// mockIncompatibleSource is a mock source that doesn't implement the compatibleSource interface -type mockIncompatibleSource struct{} - -func (m *mockIncompatibleSource) SourceKind() string { - return "mock" -} - -func (m *mockIncompatibleSource) ToConfig() sources.SourceConfig { - return nil -} - -func TestSQLConfigInitializeIncompatibleSource(t *testing.T) { - config := Config{ - Name: "test-tool", - Kind: sqlKind, - Source: "incompatible-source", - Description: "Test tool", - Statement: "SELECT 1", - Parameters: parameters.Parameters{}, - } - - mockSource := &mockIncompatibleSource{} - - sources := map[string]sources.Source{ - "incompatible-source": mockSource, - } - - _, err := config.Initialize(sources) - if err == nil { - t.Fatal("Expected error for incompatible source, got nil") - } - - if err.Error() == "" { - t.Error("Expected non-empty error message") - } -} - func TestToolManifest(t *testing.T) { tool := Tool{ manifest: tools.Manifest{ diff --git a/internal/tools/cloudgda/cloudgda.go b/internal/tools/cloudgda/cloudgda.go new file mode 100644 index 0000000000..bf54c26c3f --- /dev/null +++ b/internal/tools/cloudgda/cloudgda.go @@ -0,0 +1,206 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
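Editor's note: the ClickHouse edits above, and the new cloudgda tool whose implementation follows, share the same cross-cutting refactor. Tools no longer capture a connection pool or source handles on the Tool struct during Initialize; instead Invoke now receives a tools.SourceProvider and resolves a compatible source on every call through the generic tools.GetCompatibleSource helper, which also explains why RequiresClientAuthorization and GetAuthTokenHeaderName gained a SourceProvider argument. A minimal sketch of the pattern follows; the exampletool package name, the ExamplePool accessor, and the query are hypothetical, while tools.SourceProvider, tools.GetCompatibleSource, and tools.AccessToken are the identifiers introduced by this change.

// Sketch only: Invoke-time source resolution as used by the refactored tools.
package exampletool

import (
	"context"
	"database/sql"

	"github.com/googleapis/genai-toolbox/internal/tools"
	"github.com/googleapis/genai-toolbox/internal/util/parameters"
)

// compatibleSource declares only what this tool needs from its source.
type compatibleSource interface {
	ExamplePool() *sql.DB // hypothetical pool accessor, not part of this change
}

type Tool struct {
	Source, Name, Kind string
}

func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, _ parameters.ParamValues, _ tools.AccessToken) (any, error) {
	// Resolve the source lazily instead of storing a *sql.DB on the Tool.
	source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
	if err != nil {
		return nil, err
	}
	rows, err := source.ExamplePool().QueryContext(ctx, "SELECT 1")
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var n int
	for rows.Next() {
		n++
	}
	return n, rows.Err()
}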
+ +package cloudgda + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util/parameters" +) + +const kind string = "cloud-gemini-data-analytics-query" + +func init() { + if !tools.Register(kind, newConfig) { + panic(fmt.Sprintf("tool kind %q already registered", kind)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { + actual := Config{Name: name} + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + return actual, nil +} + +type compatibleSource interface { + GetProjectID() string + GetBaseURL() string + UseClientAuthorization() bool + GetClient(context.Context, string) (*http.Client, error) +} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description" validate:"required"` + Location string `yaml:"location" validate:"required"` + Context *QueryDataContext `yaml:"context" validate:"required"` + GenerationOptions *GenerationOptions `yaml:"generationOptions,omitempty"` + AuthRequired []string `yaml:"authRequired"` +} + +// validate interface +var _ tools.ToolConfig = Config{} + +func (cfg Config) ToolConfigKind() string { + return kind +} + +func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { + // Define the parameters for the Gemini Data Analytics Query API + // The prompt is the only input parameter. + allParameters := parameters.Parameters{ + parameters.NewStringParameterWithRequired("prompt", "The natural language question to ask.", true), + } + + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) + + return Tool{ + Config: cfg, + AllParams: allParameters, + manifest: tools.Manifest{Description: cfg.Description, Parameters: allParameters.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, + }, nil +} + +// validate interface +var _ tools.Tool = Tool{} + +type Tool struct { + Config + AllParams parameters.Parameters + manifest tools.Manifest + mcpManifest tools.McpManifest +} + +func (t Tool) ToConfig() tools.ToolConfig { + return t.Config +} + +// Invoke executes the tool logic +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + + paramsMap := params.AsMap() + prompt, ok := paramsMap["prompt"].(string) + if !ok { + return nil, fmt.Errorf("prompt parameter not found or not a string") + } + + // The API endpoint itself always uses the "global" location. + apiLocation := "global" + apiParent := fmt.Sprintf("projects/%s/locations/%s", source.GetProjectID(), apiLocation) + apiURL := fmt.Sprintf("%s/v1beta/%s:queryData", source.GetBaseURL(), apiParent) + + // The parent in the request payload uses the tool's configured location. 
+ payloadParent := fmt.Sprintf("projects/%s/locations/%s", source.GetProjectID(), t.Location) + + payload := &QueryDataRequest{ + Parent: payloadParent, + Prompt: prompt, + Context: t.Context, + GenerationOptions: t.GenerationOptions, + } + + bodyBytes, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("failed to marshal request payload: %w", err) + } + + // Parse the access token if provided + var tokenStr string + if source.UseClientAuthorization() { + var err error + tokenStr, err = accessToken.ParseBearerToken() + if err != nil { + return nil, fmt.Errorf("error parsing access token: %w", err) + } + } + + client, err := source.GetClient(ctx, tokenStr) + if err != nil { + return nil, fmt.Errorf("failed to get HTTP client: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(bodyBytes)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to execute request: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(respBody)) + } + + var result map[string]any + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + return result, nil +} + +func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { + return parameters.ParseParams(t.AllParams, data, claims) +} + +func (t Tool) Manifest() tools.Manifest { + return t.manifest +} + +func (t Tool) McpManifest() tools.McpManifest { + return t.mcpManifest +} + +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +} + +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + +func (t Tool) GetAuthTokenHeaderName(_ tools.SourceProvider) (string, error) { + return "Authorization", nil +} diff --git a/internal/tools/cloudgda/cloudgda_test.go b/internal/tools/cloudgda/cloudgda_test.go new file mode 100644 index 0000000000..0d57032904 --- /dev/null +++ b/internal/tools/cloudgda/cloudgda_test.go @@ -0,0 +1,353 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
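Editor's note: the tests below exercise the Invoke path of the new cloud-gemini-data-analytics-query tool. For orientation, a minimal sketch of the request it assembles: the REST endpoint is always addressed against the "global" location, while the parent inside the JSON payload carries the tool's configured location. The project, location, prompt, and base URL below are placeholders; the real base URL comes from the cloudgda source's configured endpoint, and only the /v1beta/...:queryData path shape and the parent/prompt fields are taken from the implementation above.

// Sketch only: shape of the queryData request built by the new tool.
package main

import (
	"encoding/json"
	"fmt"
)

type queryDataRequest struct {
	Parent string `json:"parent"`
	Prompt string `json:"prompt"`
}

func main() {
	project, location := "my-project", "us-central1" // placeholders
	baseURL := "https://geminidataanalytics.googleapis.com" // assumed; supplied by the source in practice

	// The API endpoint itself always uses the "global" location.
	apiURL := fmt.Sprintf("%s/v1beta/projects/%s/locations/global:queryData", baseURL, project)

	// The parent in the request payload uses the tool's configured location.
	body, _ := json.Marshal(queryDataRequest{
		Parent: fmt.Sprintf("projects/%s/locations/%s", project, location),
		Prompt: "How many rows are in my table?",
	})
	fmt.Println("POST", apiURL, string(body))
}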
+ +package cloudgda_test + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "testing" + + yaml "github.com/goccy/go-yaml" + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/server/resources" + "github.com/googleapis/genai-toolbox/internal/sources" + cloudgdasrc "github.com/googleapis/genai-toolbox/internal/sources/cloudgda" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools" + cloudgdatool "github.com/googleapis/genai-toolbox/internal/tools/cloudgda" + "github.com/googleapis/genai-toolbox/internal/util/parameters" +) + +func TestParseFromYaml(t *testing.T) { + t.Parallel() + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example", + in: ` + tools: + my-gda-query-tool: + kind: cloud-gemini-data-analytics-query + source: gda-api-source + description: Test Description + location: us-central1 + context: + datasourceReferences: + spannerReference: + databaseReference: + projectId: "cloud-db-nl2sql" + region: "us-central1" + instanceId: "evalbench" + databaseId: "financial" + engine: "GOOGLE_SQL" + agentContextReference: + contextSetId: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates" + generationOptions: + generateQueryResult: true + `, + want: map[string]tools.ToolConfig{ + "my-gda-query-tool": cloudgdatool.Config{ + Name: "my-gda-query-tool", + Kind: "cloud-gemini-data-analytics-query", + Source: "gda-api-source", + Description: "Test Description", + Location: "us-central1", + AuthRequired: []string{}, + Context: &cloudgdatool.QueryDataContext{ + DatasourceReferences: &cloudgdatool.DatasourceReferences{ + SpannerReference: &cloudgdatool.SpannerReference{ + DatabaseReference: &cloudgdatool.SpannerDatabaseReference{ + ProjectID: "cloud-db-nl2sql", + Region: "us-central1", + InstanceID: "evalbench", + DatabaseID: "financial", + Engine: cloudgdatool.SpannerEngineGoogleSQL, + }, + AgentContextReference: &cloudgdatool.AgentContextReference{ + ContextSetID: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates", + }, + }, + }, + }, + GenerationOptions: &cloudgdatool.GenerationOptions{ + GenerateQueryResult: true, + }, + }, + }, + }, + } + for _, tc := range tcs { + tc := tc + t.Run(tc.desc, func(t *testing.T) { + t.Parallel() + got := struct { + Tools server.ToolConfigs `yaml:"tools"` + }{} + // Parse contents + err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if !cmp.Equal(tc.want, got.Tools) { + t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Tools) + } + }) + } +} + +// authRoundTripper is a mock http.RoundTripper that adds a dummy Authorization header. 
+type authRoundTripper struct { + Token string + Next http.RoundTripper +} + +func (rt *authRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + newReq := *req + newReq.Header = make(http.Header) + for k, v := range req.Header { + newReq.Header[k] = v + } + newReq.Header.Set("Authorization", rt.Token) + if rt.Next == nil { + return http.DefaultTransport.RoundTrip(&newReq) + } + return rt.Next.RoundTrip(&newReq) +} + +type mockSource struct { + kind string + client *http.Client // Can be used to inject a specific client + baseURL string // BaseURL is needed to implement sources.Source.BaseURL + config cloudgdasrc.Config // to return from ToConfig +} + +func (m *mockSource) SourceKind() string { return m.kind } +func (m *mockSource) ToConfig() sources.SourceConfig { return m.config } +func (m *mockSource) GetClient(ctx context.Context, token string) (*http.Client, error) { + if m.client != nil { + return m.client, nil + } + // Default client for testing if not explicitly set + transport := &http.Transport{} + authTransport := &authRoundTripper{ + Token: "Bearer test-access-token", // Dummy token + Next: transport, + } + return &http.Client{Transport: authTransport}, nil +} +func (m *mockSource) UseClientAuthorization() bool { return false } +func (m *mockSource) Initialize(ctx context.Context, tracer interface{}) (sources.Source, error) { + return m, nil +} +func (m *mockSource) BaseURL() string { return m.baseURL } + +func TestInitialize(t *testing.T) { + t.Parallel() + + srcs := map[string]sources.Source{ + "gda-api-source": &cloudgdasrc.Source{ + Config: cloudgdasrc.Config{Name: "gda-api-source", Kind: cloudgdasrc.SourceKind, ProjectID: "test-project"}, + Client: &http.Client{}, + BaseURL: cloudgdasrc.Endpoint, + }, + } + + tcs := []struct { + desc string + cfg cloudgdatool.Config + }{ + { + desc: "successful initialization", + cfg: cloudgdatool.Config{ + Name: "my-gda-query-tool", + Kind: "cloud-gemini-data-analytics-query", + Source: "gda-api-source", + Description: "Test Description", + Location: "us-central1", + }, + }, + } + + // Add an incompatible source for testing + srcs["incompatible-source"] = &mockSource{kind: "another-kind"} + + for _, tc := range tcs { + tc := tc + t.Run(tc.desc, func(t *testing.T) { + t.Parallel() + tool, err := tc.cfg.Initialize(srcs) + if err != nil { + t.Fatalf("did not expect an error but got: %v", err) + } + // Basic sanity check on the returned tool + _ = tool // Avoid unused variable error + }) + } +} + +func TestInvoke(t *testing.T) { + t.Parallel() + // Mock the HTTP client and server for Invoke testing + serverMux := http.NewServeMux() + // Update expected URL path to include the location "us-central1" + serverMux.HandleFunc("/v1beta/projects/test-project/locations/global:queryData", func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("expected POST method, got %s", r.Method) + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + if r.Header.Get("Content-Type") != "application/json" { + t.Errorf("expected Content-Type application/json, got %s", r.Header.Get("Content-Type")) + http.Error(w, "Bad request", http.StatusBadRequest) + return + } + + // Read and unmarshal the request body + bodyBytes, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("failed to read request body: %v", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + var reqPayload cloudgdatool.QueryDataRequest + if err := json.Unmarshal(bodyBytes, 
&reqPayload); err != nil { + t.Errorf("failed to unmarshal request payload: %v", err) + http.Error(w, "Bad request", http.StatusBadRequest) + return + } + + // Verify expected fields + if r.Header.Get("Authorization") == "" { + t.Errorf("expected Authorization header, got empty") + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + if reqPayload.Prompt != "How many accounts who have region in Prague are eligible for loans?" { + t.Errorf("unexpected prompt: %s", reqPayload.Prompt) + } + + // Verify payload's parent uses the tool's configured location + if reqPayload.Parent != fmt.Sprintf("projects/%s/locations/%s", "test-project", "us-central1") { + t.Errorf("unexpected payload parent: got %q, want %q", reqPayload.Parent, fmt.Sprintf("projects/%s/locations/%s", "test-project", "us-central1")) + } + + // Verify context from config + if reqPayload.Context == nil || + reqPayload.Context.DatasourceReferences == nil || + reqPayload.Context.DatasourceReferences.SpannerReference == nil || + reqPayload.Context.DatasourceReferences.SpannerReference.DatabaseReference == nil || + reqPayload.Context.DatasourceReferences.SpannerReference.DatabaseReference.ProjectID != "cloud-db-nl2sql" { + t.Errorf("unexpected context: %v", reqPayload.Context) + } + + // Verify generation options from config + if reqPayload.GenerationOptions == nil || !reqPayload.GenerationOptions.GenerateQueryResult { + t.Errorf("unexpected generation options: %v", reqPayload.GenerationOptions) + } + + // Simulate a successful response + resp := map[string]any{ + "queryResult": "SELECT count(*) FROM accounts WHERE region = 'Prague' AND eligible_for_loans = true;", + "naturalLanguageAnswer": "There are 5 accounts in Prague eligible for loans.", + } + _ = json.NewEncoder(w).Encode(resp) + }) + + mockServer := httptest.NewServer(serverMux) + defer mockServer.Close() + + ctx := testutils.ContextWithUserAgent(context.Background(), "test-user-agent") + + // Create an authenticated client that uses the mock server + authTransport := &authRoundTripper{ + Token: "Bearer test-access-token", + Next: mockServer.Client().Transport, + } + authClient := &http.Client{Transport: authTransport} + + // Create a real cloudgdasrc.Source but inject the authenticated client + mockGdaSource := &cloudgdasrc.Source{ + Config: cloudgdasrc.Config{Name: "mock-gda-source", Kind: cloudgdasrc.SourceKind, ProjectID: "test-project"}, + Client: authClient, + BaseURL: mockServer.URL, + } + srcs := map[string]sources.Source{ + "mock-gda-source": mockGdaSource, + } + + // Initialize the tool config with context + toolCfg := cloudgdatool.Config{ + Name: "query-data-tool", + Kind: "cloud-gemini-data-analytics-query", + Source: "mock-gda-source", + Description: "Query Gemini Data Analytics", + Location: "us-central1", // Set location for the test + Context: &cloudgdatool.QueryDataContext{ + DatasourceReferences: &cloudgdatool.DatasourceReferences{ + SpannerReference: &cloudgdatool.SpannerReference{ + DatabaseReference: &cloudgdatool.SpannerDatabaseReference{ + ProjectID: "cloud-db-nl2sql", + Region: "us-central1", + InstanceID: "evalbench", + DatabaseID: "financial", + Engine: cloudgdatool.SpannerEngineGoogleSQL, + }, + AgentContextReference: &cloudgdatool.AgentContextReference{ + ContextSetID: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates", + }, + }, + }, + }, + GenerationOptions: &cloudgdatool.GenerationOptions{ + GenerateQueryResult: true, + }, + } + + tool, err := toolCfg.Initialize(srcs) + if err != nil { + 
t.Fatalf("failed to initialize tool: %v", err) + } + + // Prepare parameters for invocation - ONLY prompt + params := parameters.ParamValues{ + {Name: "prompt", Value: "How many accounts who have region in Prague are eligible for loans?"}, + } + + resourceMgr := resources.NewResourceManager(srcs, nil, nil, nil, nil, nil) + + // Invoke the tool + result, err := tool.Invoke(ctx, resourceMgr, params, "") // No accessToken needed for ADC client + if err != nil { + t.Fatalf("tool invocation failed: %v", err) + } + + // Validate the result + expectedResult := map[string]any{ + "queryResult": "SELECT count(*) FROM accounts WHERE region = 'Prague' AND eligible_for_loans = true;", + "naturalLanguageAnswer": "There are 5 accounts in Prague eligible for loans.", + } + + if !cmp.Equal(expectedResult, result) { + t.Errorf("unexpected result: got %v, want %v", result, expectedResult) + } +} diff --git a/internal/tools/cloudgda/types.go b/internal/tools/cloudgda/types.go new file mode 100644 index 0000000000..8e82cb50c2 --- /dev/null +++ b/internal/tools/cloudgda/types.go @@ -0,0 +1,116 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cloudgda + +// See full service definition at: https://github.com/googleapis/googleapis/blob/master/google/cloud/geminidataanalytics/v1beta/data_chat_service.proto + +// QueryDataRequest represents the JSON body for the queryData API +type QueryDataRequest struct { + Parent string `json:"parent"` + Prompt string `json:"prompt"` + Context *QueryDataContext `json:"context,omitempty"` + GenerationOptions *GenerationOptions `json:"generationOptions,omitempty"` +} + +// QueryDataContext reflects the proto definition for the query context. +type QueryDataContext struct { + DatasourceReferences *DatasourceReferences `json:"datasourceReferences,omitempty" yaml:"datasourceReferences,omitempty"` +} + +// DatasourceReferences reflects the proto definition for datasource references, using a oneof. +type DatasourceReferences struct { + SpannerReference *SpannerReference `json:"spannerReference,omitempty" yaml:"spannerReference,omitempty"` + AlloyDBReference *AlloyDBReference `json:"alloydb,omitempty" yaml:"alloydb,omitempty"` + CloudSQLReference *CloudSQLReference `json:"cloudSqlReference,omitempty" yaml:"cloudSqlReference,omitempty"` +} + +// SpannerReference reflects the proto definition for Spanner database reference. +type SpannerReference struct { + DatabaseReference *SpannerDatabaseReference `json:"databaseReference,omitempty" yaml:"databaseReference,omitempty"` + AgentContextReference *AgentContextReference `json:"agentContextReference,omitempty" yaml:"agentContextReference,omitempty"` +} + +// SpannerDatabaseReference reflects the proto definition for a Spanner database reference. 
+type SpannerDatabaseReference struct { + Engine SpannerEngine `json:"engine,omitempty" yaml:"engine,omitempty"` + ProjectID string `json:"projectId,omitempty" yaml:"projectId,omitempty"` + Region string `json:"region,omitempty" yaml:"region,omitempty"` + InstanceID string `json:"instanceId,omitempty" yaml:"instanceId,omitempty"` + DatabaseID string `json:"databaseId,omitempty" yaml:"databaseId,omitempty"` + TableIDs []string `json:"tableIds,omitempty" yaml:"tableIds,omitempty"` +} + +// SpannerEngine represents the engine of the Spanner instance. +type SpannerEngine string + +const ( + SpannerEngineUnspecified SpannerEngine = "ENGINE_UNSPECIFIED" + SpannerEngineGoogleSQL SpannerEngine = "GOOGLE_SQL" + SpannerEnginePostgreSQL SpannerEngine = "POSTGRESQL" +) + +// AlloyDBReference reflects the proto definition for an AlloyDB database reference. +type AlloyDBReference struct { + DatabaseReference *AlloyDBDatabaseReference `json:"databaseReference,omitempty" yaml:"databaseReference,omitempty"` + AgentContextReference *AgentContextReference `json:"agentContextReference,omitempty" yaml:"agentContextReference,omitempty"` +} + +// AlloyDBDatabaseReference reflects the proto definition for an AlloyDB database reference. +type AlloyDBDatabaseReference struct { + ProjectID string `json:"projectId,omitempty" yaml:"projectId,omitempty"` + Region string `json:"region,omitempty" yaml:"region,omitempty"` + ClusterID string `json:"clusterId,omitempty" yaml:"clusterId,omitempty"` + InstanceID string `json:"instanceId,omitempty" yaml:"instanceId,omitempty"` + DatabaseID string `json:"databaseId,omitempty" yaml:"databaseId,omitempty"` + TableIDs []string `json:"tableIds,omitempty" yaml:"tableIds,omitempty"` +} + +// CloudSQLReference reflects the proto definition for a Cloud SQL database reference. +type CloudSQLReference struct { + DatabaseReference *CloudSQLDatabaseReference `json:"databaseReference,omitempty" yaml:"databaseReference,omitempty"` + AgentContextReference *AgentContextReference `json:"agentContextReference,omitempty" yaml:"agentContextReference,omitempty"` +} + +// CloudSQLDatabaseReference reflects the proto definition for a Cloud SQL database reference. +type CloudSQLDatabaseReference struct { + Engine CloudSQLEngine `json:"engine,omitempty" yaml:"engine,omitempty"` + ProjectID string `json:"projectId,omitempty" yaml:"projectId,omitempty"` + Region string `json:"region,omitempty" yaml:"region,omitempty"` + InstanceID string `json:"instanceId,omitempty" yaml:"instanceId,omitempty"` + DatabaseID string `json:"databaseId,omitempty" yaml:"databaseId,omitempty"` + TableIDs []string `json:"tableIds,omitempty" yaml:"tableIds,omitempty"` +} + +// CloudSQLEngine represents the engine of the Cloud SQL instance. +type CloudSQLEngine string + +const ( + CloudSQLEngineUnspecified CloudSQLEngine = "ENGINE_UNSPECIFIED" + CloudSQLEnginePostgreSQL CloudSQLEngine = "POSTGRESQL" + CloudSQLEngineMySQL CloudSQLEngine = "MYSQL" +) + +// AgentContextReference reflects the proto definition for agent context. +type AgentContextReference struct { + ContextSetID string `json:"contextSetId,omitempty" yaml:"contextSetId,omitempty"` +} + +// GenerationOptions reflects the proto definition for generation options. 
+type GenerationOptions struct { + GenerateQueryResult bool `json:"generateQueryResult" yaml:"generateQueryResult"` + GenerateNaturalLanguageAnswer bool `json:"generateNaturalLanguageAnswer" yaml:"generateNaturalLanguageAnswer"` + GenerateExplanation bool `json:"generateExplanation" yaml:"generateExplanation"` + GenerateDisambiguationQuestion bool `json:"generateDisambiguationQuestion" yaml:"generateDisambiguationQuestion"` +} diff --git a/internal/tools/cloudhealthcare/cloudhealthcarefhirfetchpage/cloudhealthcarefhirfetchpage.go b/internal/tools/cloudhealthcare/cloudhealthcarefhirfetchpage/cloudhealthcarefhirfetchpage.go index 21b0ee2c8e..025ca9310f 100644 --- a/internal/tools/cloudhealthcare/cloudhealthcarefhirfetchpage/cloudhealthcarefhirfetchpage.go +++ b/internal/tools/cloudhealthcare/cloudhealthcarefhirfetchpage/cloudhealthcarefhirfetchpage.go @@ -62,11 +62,6 @@ type compatibleSource interface { UseClientAuthorization() bool } -// validate compatible sources are still compatible -var _ compatibleSource = &healthcareds.Source{} - -var compatibleSources = [...]string{healthcareds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -83,35 +78,16 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - urlParameter := parameters.NewStringParameter(pageURLKey, "The full URL of the FHIR page to fetch. 
This would be the value of `Bundle.entry.link.url` field within the response returned from FHIR search or FHIR patient everything operations.") params := parameters.Parameters{urlParameter} mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil) // finish tool setup t := Tool{ - Config: cfg, - Parameters: params, - Project: s.Project(), - Region: s.Region(), - Dataset: s.DatasetID(), - AllowedStores: s.AllowedFHIRStores(), - UseClientOAuth: s.UseClientAuthorization(), - ServiceCreator: s.ServiceCreator(), - Service: s.Service(), - manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -121,28 +97,28 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool `yaml:"useClientOAuth"` - Parameters parameters.Parameters `yaml:"parameters"` - Project, Region, Dataset string - AllowedStores map[string]struct{} - Service *healthcare.Service - ServiceCreator healthcareds.HealthcareServiceCreator - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + url, ok := params.AsMap()[pageURLKey].(string) if !ok { return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", pageURLKey) } var httpClient *http.Client - if t.UseClientOAuth { + if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { return nil, fmt.Errorf("error parsing access token: %w", err) @@ -150,7 +126,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: tokenStr}) httpClient = oauth2.NewClient(ctx, ts) } else { - // The t.Service object holds a client with the default credentials. + // The source.Service() object holds a client with the default credentials. // However, the client is not exported, so we have to create a new one. 
var err error httpClient, err = google.DefaultClient(ctx, healthcare.CloudHealthcareScope) @@ -201,10 +177,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudhealthcare/cloudhealthcarefhirpatienteverything/cloudhealthcarefhirpatienteverything.go b/internal/tools/cloudhealthcare/cloudhealthcarefhirpatienteverything/cloudhealthcarefhirpatienteverything.go index bedf91f3c1..b00d7c35ac 100644 --- a/internal/tools/cloudhealthcare/cloudhealthcarefhirpatienteverything/cloudhealthcarefhirpatienteverything.go +++ b/internal/tools/cloudhealthcare/cloudhealthcarefhirpatienteverything/cloudhealthcarefhirpatienteverything.go @@ -62,11 +62,6 @@ type compatibleSource interface { UseClientAuthorization() bool } -// validate compatible sources are still compatible -var _ compatibleSource = &healthcareds.Source{} - -var compatibleSources = [...]string{healthcareds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -92,7 +87,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // verify the source is compatible s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } idParameter := parameters.NewStringParameter(patientIDKey, "The ID of the patient FHIR resource for which the information is required") @@ -106,17 +101,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ - Config: cfg, - Parameters: params, - Project: s.Project(), - Region: s.Region(), - Dataset: s.DatasetID(), - AllowedStores: s.AllowedFHIRStores(), - UseClientOAuth: s.UseClientAuthorization(), - ServiceCreator: s.ServiceCreator(), - Service: s.Service(), - manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -126,23 +114,22 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool `yaml:"useClientOAuth"` - Parameters parameters.Parameters `yaml:"parameters"` - - Project, Region, Dataset string - AllowedStores map[string]struct{} - Service *healthcare.Service - ServiceCreator healthcareds.HealthcareServiceCreator - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params 
parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores) +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + + storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores()) if err != nil { return nil, err } @@ -151,20 +138,20 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", patientIDKey) } - svc := t.Service + svc := source.Service() // Initialize new service if using user OAuth token - if t.UseClientOAuth { + if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { return nil, fmt.Errorf("error parsing access token: %w", err) } - svc, err = t.ServiceCreator(tokenStr) + svc, err = source.ServiceCreator()(tokenStr) if err != nil { return nil, fmt.Errorf("error creating service from OAuth access token: %w", err) } } - name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/Patient/%s", t.Project, t.Region, t.Dataset, storeID, patientID) + name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/Patient/%s", source.Project(), source.Region(), source.DatasetID(), storeID, patientID) var opts []googleapi.CallOption if val, ok := params.AsMap()[typeFilterKey]; ok { types, ok := val.([]any) @@ -225,10 +212,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudhealthcare/cloudhealthcarefhirpatientsearch/cloudhealthcarefhirpatientsearch.go b/internal/tools/cloudhealthcare/cloudhealthcarefhirpatientsearch/cloudhealthcarefhirpatientsearch.go index e4fd8f42d7..c1cf43b59f 100644 --- a/internal/tools/cloudhealthcare/cloudhealthcarefhirpatientsearch/cloudhealthcarefhirpatientsearch.go +++ b/internal/tools/cloudhealthcare/cloudhealthcarefhirpatientsearch/cloudhealthcarefhirpatientsearch.go @@ -78,11 +78,6 @@ type compatibleSource interface { UseClientAuthorization() bool } -// validate compatible sources are still compatible -var _ compatibleSource = &healthcareds.Source{} - -var compatibleSources = [...]string{healthcareds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -108,7 +103,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // verify the source is compatible s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + return nil, fmt.Errorf("invalid source for %q tool: source %q not 
compatible", kind, cfg.Source) } params := parameters.Parameters{ @@ -140,17 +135,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ - Config: cfg, - Parameters: params, - Project: s.Project(), - Region: s.Region(), - Dataset: s.DatasetID(), - AllowedStores: s.AllowedFHIRStores(), - UseClientOAuth: s.UseClientAuthorization(), - ServiceCreator: s.ServiceCreator(), - Service: s.Service(), - manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -160,35 +148,34 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool `yaml:"useClientOAuth"` - Parameters parameters.Parameters `yaml:"parameters"` - - Project, Region, Dataset string - AllowedStores map[string]struct{} - Service *healthcare.Service - ServiceCreator healthcareds.HealthcareServiceCreator - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores) +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) if err != nil { return nil, err } - svc := t.Service + storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores()) + if err != nil { + return nil, err + } + + svc := source.Service() // Initialize new service if using user OAuth token - if t.UseClientOAuth { + if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { return nil, fmt.Errorf("error parsing access token: %w", err) } - svc, err = t.ServiceCreator(tokenStr) + svc, err = source.ServiceCreator()(tokenStr) if err != nil { return nil, fmt.Errorf("error creating service from OAuth access token: %w", err) } @@ -261,7 +248,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT opts = append(opts, googleapi.QueryParameter("_summary", "text")) } - name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", t.Project, t.Region, t.Dataset, storeID) + name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID) resp, err := svc.Projects.Locations.Datasets.FhirStores.Fhir.SearchType(name, "Patient", &healthcare.SearchResourcesRequest{ResourceType: "Patient"}).Do(opts...) 
if err != nil { return nil, fmt.Errorf("failed to search patient resources: %w", err) @@ -298,10 +285,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudhealthcare/cloudhealthcaregetdataset/cloudhealthcaregetdataset.go b/internal/tools/cloudhealthcare/cloudhealthcaregetdataset/cloudhealthcaregetdataset.go index 79dd1ef727..d3386cb657 100644 --- a/internal/tools/cloudhealthcare/cloudhealthcaregetdataset/cloudhealthcaregetdataset.go +++ b/internal/tools/cloudhealthcare/cloudhealthcaregetdataset/cloudhealthcaregetdataset.go @@ -51,11 +51,6 @@ type compatibleSource interface { UseClientAuthorization() bool } -// validate compatible sources are still compatible -var _ compatibleSource = &healthcareds.Source{} - -var compatibleSources = [...]string{healthcareds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -72,33 +67,15 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - params := parameters.Parameters{} mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil) // finish tool setup t := Tool{ - Config: cfg, - Parameters: params, - Project: s.Project(), - Region: s.Region(), - Dataset: s.DatasetID(), - UseClientOAuth: s.UseClientAuthorization(), - ServiceCreator: s.ServiceCreator(), - Service: s.Service(), - manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -108,36 +85,36 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool `yaml:"useClientOAuth"` - Parameters parameters.Parameters `yaml:"parameters"` - Project, Region, Dataset string - Service *healthcare.Service - ServiceCreator healthcareds.HealthcareServiceCreator - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - svc := t.Service - var err error +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, 
params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + + svc := source.Service() // Initialize new service if using user OAuth token - if t.UseClientOAuth { + if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { return nil, fmt.Errorf("error parsing access token: %w", err) } - svc, err = t.ServiceCreator(tokenStr) + svc, err = source.ServiceCreator()(tokenStr) if err != nil { return nil, fmt.Errorf("error creating service from OAuth access token: %w", err) } } - datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", t.Project, t.Region, t.Dataset) + datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", source.Project(), source.Region(), source.DatasetID()) dataset, err := svc.Projects.Locations.Datasets.Get(datasetName).Do() if err != nil { return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err) @@ -161,10 +138,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudhealthcare/cloudhealthcaregetdicomstore/cloudhealthcaregetdicomstore.go b/internal/tools/cloudhealthcare/cloudhealthcaregetdicomstore/cloudhealthcaregetdicomstore.go index 88ad99d3e3..d8da9c096e 100644 --- a/internal/tools/cloudhealthcare/cloudhealthcaregetdicomstore/cloudhealthcaregetdicomstore.go +++ b/internal/tools/cloudhealthcare/cloudhealthcaregetdicomstore/cloudhealthcaregetdicomstore.go @@ -53,11 +53,6 @@ type compatibleSource interface { UseClientAuthorization() bool } -// validate compatible sources are still compatible -var _ compatibleSource = &healthcareds.Source{} - -var compatibleSources = [...]string{healthcareds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -83,7 +78,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // verify the source is compatible s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } params := parameters.Parameters{} @@ -94,17 +89,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ - Config: cfg, - Parameters: params, - Project: s.Project(), - Region: s.Region(), - Dataset: s.DatasetID(), - AllowedStores: s.AllowedDICOMStores(), - UseClientOAuth: s.UseClientAuthorization(), - ServiceCreator: s.ServiceCreator(), - Service: s.Service(), - manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + 
Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -114,41 +102,40 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool `yaml:"useClientOAuth"` - Parameters parameters.Parameters `yaml:"parameters"` - - Project, Region, Dataset string - AllowedStores map[string]struct{} - Service *healthcare.Service - ServiceCreator healthcareds.HealthcareServiceCreator - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores) +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) if err != nil { return nil, err } - svc := t.Service + storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores()) + if err != nil { + return nil, err + } + + svc := source.Service() // Initialize new service if using user OAuth token - if t.UseClientOAuth { + if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { return nil, fmt.Errorf("error parsing access token: %w", err) } - svc, err = t.ServiceCreator(tokenStr) + svc, err = source.ServiceCreator()(tokenStr) if err != nil { return nil, fmt.Errorf("error creating service from OAuth access token: %w", err) } } - storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", t.Project, t.Region, t.Dataset, storeID) + storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID) store, err := svc.Projects.Locations.Datasets.DicomStores.Get(storeName).Do() if err != nil { return nil, fmt.Errorf("failed to get DICOM store %q: %w", storeName, err) @@ -172,10 +159,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudhealthcare/cloudhealthcaregetdicomstoremetrics/cloudhealthcaregetdicomstoremetrics.go b/internal/tools/cloudhealthcare/cloudhealthcaregetdicomstoremetrics/cloudhealthcaregetdicomstoremetrics.go index e2d309445b..03f73dd0a4 100644 --- a/internal/tools/cloudhealthcare/cloudhealthcaregetdicomstoremetrics/cloudhealthcaregetdicomstoremetrics.go +++ b/internal/tools/cloudhealthcare/cloudhealthcaregetdicomstoremetrics/cloudhealthcaregetdicomstoremetrics.go @@ -53,11 +53,6 @@ type compatibleSource interface { UseClientAuthorization() 
bool } -// validate compatible sources are still compatible -var _ compatibleSource = &healthcareds.Source{} - -var compatibleSources = [...]string{healthcareds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -83,7 +78,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // verify the source is compatible s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } params := parameters.Parameters{} @@ -94,17 +89,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ - Config: cfg, - Parameters: params, - Project: s.Project(), - Region: s.Region(), - Dataset: s.DatasetID(), - AllowedStores: s.AllowedDICOMStores(), - UseClientOAuth: s.UseClientAuthorization(), - ServiceCreator: s.ServiceCreator(), - Service: s.Service(), - manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -114,41 +102,40 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool `yaml:"useClientOAuth"` - Parameters parameters.Parameters `yaml:"parameters"` - - Project, Region, Dataset string - AllowedStores map[string]struct{} - Service *healthcare.Service - ServiceCreator healthcareds.HealthcareServiceCreator - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores) +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) if err != nil { return nil, err } - svc := t.Service + storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores()) + if err != nil { + return nil, err + } + + svc := source.Service() // Initialize new service if using user OAuth token - if t.UseClientOAuth { + if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { return nil, fmt.Errorf("error parsing access token: %w", err) } - svc, err = t.ServiceCreator(tokenStr) + svc, err = source.ServiceCreator()(tokenStr) if err != nil { return nil, fmt.Errorf("error creating service from OAuth access token: %w", err) } } - storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", t.Project, t.Region, t.Dataset, storeID) + storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID) store, err := svc.Projects.Locations.Datasets.DicomStores.GetDICOMStoreMetrics(storeName).Do() if err != nil { return nil, fmt.Errorf("failed to get metrics for DICOM store %q: 
%w", storeName, err) @@ -172,10 +159,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudhealthcare/cloudhealthcaregetfhirresource/cloudhealthcaregetfhirresource.go b/internal/tools/cloudhealthcare/cloudhealthcaregetfhirresource/cloudhealthcaregetfhirresource.go index c4dbe540dc..41c4e71db2 100644 --- a/internal/tools/cloudhealthcare/cloudhealthcaregetfhirresource/cloudhealthcaregetfhirresource.go +++ b/internal/tools/cloudhealthcare/cloudhealthcaregetfhirresource/cloudhealthcaregetfhirresource.go @@ -59,11 +59,6 @@ type compatibleSource interface { UseClientAuthorization() bool } -// validate compatible sources are still compatible -var _ compatibleSource = &healthcareds.Source{} - -var compatibleSources = [...]string{healthcareds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -89,7 +84,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // verify the source is compatible s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } typeParameter := parameters.NewStringParameter(typeKey, "The FHIR resource type to retrieve (e.g., Patient, Observation).") @@ -102,17 +97,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ - Config: cfg, - Parameters: params, - Project: s.Project(), - Region: s.Region(), - Dataset: s.DatasetID(), - AllowedStores: s.AllowedFHIRStores(), - UseClientOAuth: s.UseClientAuthorization(), - ServiceCreator: s.ServiceCreator(), - Service: s.Service(), - manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -122,23 +110,22 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool `yaml:"useClientOAuth"` - Parameters parameters.Parameters `yaml:"parameters"` - - Project, Region, Dataset string - AllowedStores map[string]struct{} - Service *healthcare.Service - ServiceCreator healthcareds.HealthcareServiceCreator - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - storeID, err := common.ValidateAndFetchStoreID(params, 
t.AllowedStores) +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + + storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores()) if err != nil { return nil, err } @@ -152,20 +139,20 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", idKey) } - svc := t.Service + svc := source.Service() // Initialize new service if using user OAuth token - if t.UseClientOAuth { + if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { return nil, fmt.Errorf("error parsing access token: %w", err) } - svc, err = t.ServiceCreator(tokenStr) + svc, err = source.ServiceCreator()(tokenStr) if err != nil { return nil, fmt.Errorf("error creating service from OAuth access token: %w", err) } } - name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/%s/%s", t.Project, t.Region, t.Dataset, storeID, resType, resID) + name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/%s/%s", source.Project(), source.Region(), source.DatasetID(), storeID, resType, resID) call := svc.Projects.Locations.Datasets.FhirStores.Fhir.Read(name) call.Header().Set("Content-Type", "application/fhir+json;charset=utf-8") resp, err := call.Do() @@ -204,10 +191,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudhealthcare/cloudhealthcaregetfhirstore/cloudhealthcaregetfhirstore.go b/internal/tools/cloudhealthcare/cloudhealthcaregetfhirstore/cloudhealthcaregetfhirstore.go index f2a3928a4d..1760579b35 100644 --- a/internal/tools/cloudhealthcare/cloudhealthcaregetfhirstore/cloudhealthcaregetfhirstore.go +++ b/internal/tools/cloudhealthcare/cloudhealthcaregetfhirstore/cloudhealthcaregetfhirstore.go @@ -53,11 +53,6 @@ type compatibleSource interface { UseClientAuthorization() bool } -// validate compatible sources are still compatible -var _ compatibleSource = &healthcareds.Source{} - -var compatibleSources = [...]string{healthcareds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -83,7 +78,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // verify the source is compatible s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } params := parameters.Parameters{} @@ -94,17 +89,10 @@ func (cfg Config) Initialize(srcs 
map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ - Config: cfg, - Parameters: params, - Project: s.Project(), - Region: s.Region(), - Dataset: s.DatasetID(), - AllowedStores: s.AllowedFHIRStores(), - UseClientOAuth: s.UseClientAuthorization(), - ServiceCreator: s.ServiceCreator(), - Service: s.Service(), - manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -114,41 +102,40 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool `yaml:"useClientOAuth"` - Parameters parameters.Parameters `yaml:"parameters"` - - Project, Region, Dataset string - AllowedStores map[string]struct{} - Service *healthcare.Service - ServiceCreator healthcareds.HealthcareServiceCreator - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores) +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) if err != nil { return nil, err } - svc := t.Service + storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores()) + if err != nil { + return nil, err + } + + svc := source.Service() // Initialize new service if using user OAuth token - if t.UseClientOAuth { + if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { return nil, fmt.Errorf("error parsing access token: %w", err) } - svc, err = t.ServiceCreator(tokenStr) + svc, err = source.ServiceCreator()(tokenStr) if err != nil { return nil, fmt.Errorf("error creating service from OAuth access token: %w", err) } } - storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", t.Project, t.Region, t.Dataset, storeID) + storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID) store, err := svc.Projects.Locations.Datasets.FhirStores.Get(storeName).Do() if err != nil { return nil, fmt.Errorf("failed to get FHIR store %q: %w", storeName, err) @@ -172,10 +159,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git 
a/internal/tools/cloudhealthcare/cloudhealthcaregetfhirstoremetrics/cloudhealthcaregetfhirstoremetrics.go b/internal/tools/cloudhealthcare/cloudhealthcaregetfhirstoremetrics/cloudhealthcaregetfhirstoremetrics.go index 867cd6cd65..29e1011da2 100644 --- a/internal/tools/cloudhealthcare/cloudhealthcaregetfhirstoremetrics/cloudhealthcaregetfhirstoremetrics.go +++ b/internal/tools/cloudhealthcare/cloudhealthcaregetfhirstoremetrics/cloudhealthcaregetfhirstoremetrics.go @@ -53,11 +53,6 @@ type compatibleSource interface { UseClientAuthorization() bool } -// validate compatible sources are still compatible -var _ compatibleSource = &healthcareds.Source{} - -var compatibleSources = [...]string{healthcareds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -83,7 +78,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // verify the source is compatible s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } params := parameters.Parameters{} @@ -94,17 +89,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ - Config: cfg, - Parameters: params, - Project: s.Project(), - Region: s.Region(), - Dataset: s.DatasetID(), - AllowedStores: s.AllowedFHIRStores(), - UseClientOAuth: s.UseClientAuthorization(), - ServiceCreator: s.ServiceCreator(), - Service: s.Service(), - manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -114,41 +102,40 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool `yaml:"useClientOAuth"` - Parameters parameters.Parameters `yaml:"parameters"` - - Project, Region, Dataset string - AllowedStores map[string]struct{} - Service *healthcare.Service - ServiceCreator healthcareds.HealthcareServiceCreator - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores) +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) if err != nil { return nil, err } - svc := t.Service + storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores()) + if err != nil { + return nil, err + } + + svc := source.Service() // Initialize new service if using user OAuth token - if t.UseClientOAuth { + if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { return nil, fmt.Errorf("error parsing access token: %w", err) } - svc, err = t.ServiceCreator(tokenStr) + svc, err = source.ServiceCreator()(tokenStr) if 
err != nil { return nil, fmt.Errorf("error creating service from OAuth access token: %w", err) } } - storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", t.Project, t.Region, t.Dataset, storeID) + storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID) store, err := svc.Projects.Locations.Datasets.FhirStores.GetFHIRStoreMetrics(storeName).Do() if err != nil { return nil, fmt.Errorf("failed to get metrics for FHIR store %q: %w", storeName, err) @@ -172,10 +159,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudhealthcare/cloudhealthcarelistdicomstores/cloudhealthcarelistdicomstores.go b/internal/tools/cloudhealthcare/cloudhealthcarelistdicomstores/cloudhealthcarelistdicomstores.go index 1709cfc59f..e180a8028f 100644 --- a/internal/tools/cloudhealthcare/cloudhealthcarelistdicomstores/cloudhealthcarelistdicomstores.go +++ b/internal/tools/cloudhealthcare/cloudhealthcarelistdicomstores/cloudhealthcarelistdicomstores.go @@ -53,11 +53,6 @@ type compatibleSource interface { UseClientAuthorization() bool } -// validate compatible sources are still compatible -var _ compatibleSource = &healthcareds.Source{} - -var compatibleSources = [...]string{healthcareds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -74,34 +69,15 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - params := parameters.Parameters{} mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil) // finish tool setup t := Tool{ - Config: cfg, - Parameters: params, - Project: s.Project(), - Region: s.Region(), - Dataset: s.DatasetID(), - AllowedStores: s.AllowedDICOMStores(), - UseClientOAuth: s.UseClientAuthorization(), - ServiceCreator: s.ServiceCreator(), - Service: s.Service(), - manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -111,45 +87,43 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool `yaml:"useClientOAuth"` - Parameters parameters.Parameters `yaml:"parameters"` - - Project, Region, Dataset 
string - AllowedStores map[string]struct{} - Service *healthcare.Service - ServiceCreator healthcareds.HealthcareServiceCreator - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - svc := t.Service - var err error +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + + svc := source.Service() // Initialize new service if using user OAuth token - if t.UseClientOAuth { + if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { return nil, fmt.Errorf("error parsing access token: %w", err) } - svc, err = t.ServiceCreator(tokenStr) + svc, err = source.ServiceCreator()(tokenStr) if err != nil { return nil, fmt.Errorf("error creating service from OAuth access token: %w", err) } } - datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", t.Project, t.Region, t.Dataset) + datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", source.Project(), source.Region(), source.DatasetID()) stores, err := svc.Projects.Locations.Datasets.DicomStores.List(datasetName).Do() if err != nil { return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err) } var filtered []*healthcare.DicomStore for _, store := range stores.DicomStores { - if len(t.AllowedStores) == 0 { + if len(source.AllowedDICOMStores()) == 0 { filtered = append(filtered, store) continue } @@ -157,7 +131,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT continue } parts := strings.Split(store.Name, "/") - if _, ok := t.AllowedStores[parts[len(parts)-1]]; ok { + if _, ok := source.AllowedDICOMStores()[parts[len(parts)-1]]; ok { filtered = append(filtered, store) } } @@ -180,10 +154,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudhealthcare/cloudhealthcarelistfhirstores/cloudhealthcarelistfhirstores.go b/internal/tools/cloudhealthcare/cloudhealthcarelistfhirstores/cloudhealthcarelistfhirstores.go index ef33387b53..5e9ea52359 100644 --- a/internal/tools/cloudhealthcare/cloudhealthcarelistfhirstores/cloudhealthcarelistfhirstores.go +++ b/internal/tools/cloudhealthcare/cloudhealthcarelistfhirstores/cloudhealthcarelistfhirstores.go @@ -53,11 +53,6 @@ type compatibleSource interface { UseClientAuthorization() bool } -// validate compatible sources are still compatible -var _ compatibleSource = &healthcareds.Source{} - -var compatibleSources = 
[...]string{healthcareds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -74,34 +69,15 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - params := parameters.Parameters{} mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil) // finish tool setup t := Tool{ - Config: cfg, - Parameters: params, - Project: s.Project(), - Region: s.Region(), - Dataset: s.DatasetID(), - AllowedStores: s.AllowedFHIRStores(), - UseClientOAuth: s.UseClientAuthorization(), - ServiceCreator: s.ServiceCreator(), - Service: s.Service(), - manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -111,45 +87,43 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool `yaml:"useClientOAuth"` - Parameters parameters.Parameters `yaml:"parameters"` - - Project, Region, Dataset string - AllowedStores map[string]struct{} - Service *healthcare.Service - ServiceCreator healthcareds.HealthcareServiceCreator - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - svc := t.Service - var err error +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + + svc := source.Service() // Initialize new service if using user OAuth token - if t.UseClientOAuth { + if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { return nil, fmt.Errorf("error parsing access token: %w", err) } - svc, err = t.ServiceCreator(tokenStr) + svc, err = source.ServiceCreator()(tokenStr) if err != nil { return nil, fmt.Errorf("error creating service from OAuth access token: %w", err) } } - datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", t.Project, t.Region, t.Dataset) + datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", source.Project(), source.Region(), source.DatasetID()) stores, err := svc.Projects.Locations.Datasets.FhirStores.List(datasetName).Do() if err != nil { return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err) } var filtered []*healthcare.FhirStore for _, store := range stores.FhirStores { - if len(t.AllowedStores) == 0 { + if len(source.AllowedFHIRStores()) == 0 { filtered = append(filtered, store) continue } @@ -157,7 +131,7 @@ func (t Tool) Invoke(ctx context.Context, 
params parameters.ParamValues, accessT continue } parts := strings.Split(store.Name, "/") - if _, ok := t.AllowedStores[parts[len(parts)-1]]; ok { + if _, ok := source.AllowedFHIRStores()[parts[len(parts)-1]]; ok { filtered = append(filtered, store) } } @@ -180,10 +154,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudhealthcare/cloudhealthcareretrieverendereddicominstance/cloudhealthcareretrieverendereddicominstance.go b/internal/tools/cloudhealthcare/cloudhealthcareretrieverendereddicominstance/cloudhealthcareretrieverendereddicominstance.go index 29eeb8234d..6272fda5df 100644 --- a/internal/tools/cloudhealthcare/cloudhealthcareretrieverendereddicominstance/cloudhealthcareretrieverendereddicominstance.go +++ b/internal/tools/cloudhealthcare/cloudhealthcareretrieverendereddicominstance/cloudhealthcareretrieverendereddicominstance.go @@ -61,11 +61,6 @@ type compatibleSource interface { UseClientAuthorization() bool } -// validate compatible sources are still compatible -var _ compatibleSource = &healthcareds.Source{} - -var compatibleSources = [...]string{healthcareds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -91,7 +86,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // verify the source is compatible s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } params := parameters.Parameters{ @@ -107,17 +102,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ - Config: cfg, - Parameters: params, - Project: s.Project(), - Region: s.Region(), - Dataset: s.DatasetID(), - AllowedStores: s.AllowedDICOMStores(), - UseClientOAuth: s.UseClientAuthorization(), - ServiceCreator: s.ServiceCreator(), - Service: s.Service(), - manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -127,35 +115,34 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool `yaml:"useClientOAuth"` - Parameters parameters.Parameters `yaml:"parameters"` - - Project, Region, Dataset string - AllowedStores map[string]struct{} - Service *healthcare.Service - ServiceCreator healthcareds.HealthcareServiceCreator - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest 
tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores) +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) if err != nil { return nil, err } - svc := t.Service + storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores()) + if err != nil { + return nil, err + } + + svc := source.Service() // Initialize new service if using user OAuth token - if t.UseClientOAuth { + if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { return nil, fmt.Errorf("error parsing access token: %w", err) } - svc, err = t.ServiceCreator(tokenStr) + svc, err = source.ServiceCreator()(tokenStr) if err != nil { return nil, fmt.Errorf("error creating service from OAuth access token: %w", err) } @@ -177,7 +164,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT if !ok { return nil, fmt.Errorf("invalid '%s' parameter; expected an integer", frameNumberKey) } - name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", t.Project, t.Region, t.Dataset, storeID) + name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID) dicomWebPath := fmt.Sprintf("studies/%s/series/%s/instances/%s/frames/%d/rendered", study, series, sop, frame) call := svc.Projects.Locations.Datasets.DicomStores.Studies.Series.Instances.Frames.RetrieveRendered(name, dicomWebPath) call.Header().Set("Accept", "image/jpeg") @@ -214,10 +201,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudhealthcare/cloudhealthcaresearchdicominstances/cloudhealthcaresearchdicominstances.go b/internal/tools/cloudhealthcare/cloudhealthcaresearchdicominstances/cloudhealthcaresearchdicominstances.go index 38b1d93413..afe0f4cc2e 100644 --- a/internal/tools/cloudhealthcare/cloudhealthcaresearchdicominstances/cloudhealthcaresearchdicominstances.go +++ b/internal/tools/cloudhealthcare/cloudhealthcaresearchdicominstances/cloudhealthcaresearchdicominstances.go @@ -68,11 +68,6 @@ type compatibleSource interface { UseClientAuthorization() bool } -// validate compatible sources are still compatible -var _ compatibleSource = &healthcareds.Source{} - -var compatibleSources = [...]string{healthcareds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -98,7 +93,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, 
error) // verify the source is compatible s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } params := parameters.Parameters{ @@ -121,17 +116,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ - Config: cfg, - Parameters: params, - Project: s.Project(), - Region: s.Region(), - Dataset: s.DatasetID(), - AllowedStores: s.AllowedDICOMStores(), - UseClientOAuth: s.UseClientAuthorization(), - ServiceCreator: s.ServiceCreator(), - Service: s.Service(), - manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -141,35 +129,34 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool `yaml:"useClientOAuth"` - Parameters parameters.Parameters `yaml:"parameters"` - - Project, Region, Dataset string - AllowedStores map[string]struct{} - Service *healthcare.Service - ServiceCreator healthcareds.HealthcareServiceCreator - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores) +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) if err != nil { return nil, err } - svc := t.Service + storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores()) + if err != nil { + return nil, err + } + + svc := source.Service() // Initialize new service if using user OAuth token - if t.UseClientOAuth { + if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { return nil, fmt.Errorf("error parsing access token: %w", err) } - svc, err = t.ServiceCreator(tokenStr) + svc, err = source.ServiceCreator()(tokenStr) if err != nil { return nil, fmt.Errorf("error creating service from OAuth access token: %w", err) } @@ -204,7 +191,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } } - name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", t.Project, t.Region, t.Dataset, storeID) + name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID) resp, err := svc.Projects.Locations.Datasets.DicomStores.SearchForInstances(name, dicomWebPath).Do(opts...) 
if err != nil { return nil, fmt.Errorf("failed to search dicom instances: %w", err) @@ -244,10 +231,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudhealthcare/cloudhealthcaresearchdicomseries/cloudhealthcaresearchdicomseries.go b/internal/tools/cloudhealthcare/cloudhealthcaresearchdicomseries/cloudhealthcaresearchdicomseries.go index 9425c9103e..0c888f8d9c 100644 --- a/internal/tools/cloudhealthcare/cloudhealthcaresearchdicomseries/cloudhealthcaresearchdicomseries.go +++ b/internal/tools/cloudhealthcare/cloudhealthcaresearchdicomseries/cloudhealthcaresearchdicomseries.go @@ -65,11 +65,6 @@ type compatibleSource interface { UseClientAuthorization() bool } -// validate compatible sources are still compatible -var _ compatibleSource = &healthcareds.Source{} - -var compatibleSources = [...]string{healthcareds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -95,7 +90,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // verify the source is compatible s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } params := parameters.Parameters{ @@ -117,17 +112,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ - Config: cfg, - Parameters: params, - Project: s.Project(), - Region: s.Region(), - Dataset: s.DatasetID(), - AllowedStores: s.AllowedDICOMStores(), - UseClientOAuth: s.UseClientAuthorization(), - ServiceCreator: s.ServiceCreator(), - Service: s.Service(), - manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -137,35 +125,34 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool `yaml:"useClientOAuth"` - Parameters parameters.Parameters `yaml:"parameters"` - - Project, Region, Dataset string - AllowedStores map[string]struct{} - Service *healthcare.Service - ServiceCreator healthcareds.HealthcareServiceCreator - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores) +func 
(t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) if err != nil { return nil, err } - svc := t.Service + storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores()) + if err != nil { + return nil, err + } + + svc := source.Service() // Initialize new service if using user OAuth token - if t.UseClientOAuth { + if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { return nil, fmt.Errorf("error parsing access token: %w", err) } - svc, err = t.ServiceCreator(tokenStr) + svc, err = source.ServiceCreator()(tokenStr) if err != nil { return nil, fmt.Errorf("error creating service from OAuth access token: %w", err) } @@ -187,7 +174,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } } - name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", t.Project, t.Region, t.Dataset, storeID) + name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID) resp, err := svc.Projects.Locations.Datasets.DicomStores.SearchForSeries(name, dicomWebPath).Do(opts...) if err != nil { return nil, fmt.Errorf("failed to search dicom series: %w", err) @@ -227,10 +214,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudhealthcare/cloudhealthcaresearchdicomstudies/cloudhealthcaresearchdicomstudies.go b/internal/tools/cloudhealthcare/cloudhealthcaresearchdicomstudies/cloudhealthcaresearchdicomstudies.go index 6951dfccd5..8a5e7ccf0d 100644 --- a/internal/tools/cloudhealthcare/cloudhealthcaresearchdicomstudies/cloudhealthcaresearchdicomstudies.go +++ b/internal/tools/cloudhealthcare/cloudhealthcaresearchdicomstudies/cloudhealthcaresearchdicomstudies.go @@ -63,11 +63,6 @@ type compatibleSource interface { UseClientAuthorization() bool } -// validate compatible sources are still compatible -var _ compatibleSource = &healthcareds.Source{} - -var compatibleSources = [...]string{healthcareds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -93,7 +88,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // verify the source is compatible s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } params := parameters.Parameters{ @@ -113,17 +108,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ - Config: cfg, - 
Parameters: params, - Project: s.Project(), - Region: s.Region(), - Dataset: s.DatasetID(), - AllowedStores: s.AllowedDICOMStores(), - UseClientOAuth: s.UseClientAuthorization(), - ServiceCreator: s.ServiceCreator(), - Service: s.Service(), - manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -133,35 +121,34 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool `yaml:"useClientOAuth"` - Parameters parameters.Parameters `yaml:"parameters"` - - Project, Region, Dataset string - AllowedStores map[string]struct{} - Service *healthcare.Service - ServiceCreator healthcareds.HealthcareServiceCreator - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - storeID, err := common.ValidateAndFetchStoreID(params, t.AllowedStores) +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) if err != nil { return nil, err } - svc := t.Service + storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores()) + if err != nil { + return nil, err + } + + svc := source.Service() // Initialize new service if using user OAuth token - if t.UseClientOAuth { + if source.UseClientAuthorization() { tokenStr, err := accessToken.ParseBearerToken() if err != nil { return nil, fmt.Errorf("error parsing access token: %w", err) } - svc, err = t.ServiceCreator(tokenStr) + svc, err = source.ServiceCreator()(tokenStr) if err != nil { return nil, fmt.Errorf("error creating service from OAuth access token: %w", err) } @@ -171,7 +158,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT if err != nil { return nil, err } - name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", t.Project, t.Region, t.Dataset, storeID) + name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID) resp, err := svc.Projects.Locations.Datasets.DicomStores.SearchForStudies(name, "studies").Do(opts...) 
if err != nil { return nil, fmt.Errorf("failed to search dicom studies: %w", err) @@ -211,10 +198,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudmonitoring/cloudmonitoring.go b/internal/tools/cloudmonitoring/cloudmonitoring.go index ccf70b4add..54c19f6774 100644 --- a/internal/tools/cloudmonitoring/cloudmonitoring.go +++ b/internal/tools/cloudmonitoring/cloudmonitoring.go @@ -23,7 +23,6 @@ import ( "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - cloudmonitoringsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudmonitoring" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -44,6 +43,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + BaseURL() string + Client() *http.Client + UserAgent() string +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -60,18 +65,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*cloudmonitoringsrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloudmonitoring`", kind) - } - // Define the parameters internally instead of from the config file. 
allParameters := parameters.Parameters{ parameters.NewStringParameterWithRequired("projectId", "The Id of the Google Cloud project.", true), @@ -83,9 +76,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, AllParams: allParameters, - BaseURL: s.BaseURL, - UserAgent: s.UserAgent, - Client: s.Client, manifest: tools.Manifest{Description: cfg.Description, Parameters: allParameters.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, }, nil @@ -97,9 +87,6 @@ var _ tools.Tool = Tool{} type Tool struct { Config AllParams parameters.Parameters `yaml:"allParams"` - BaseURL string `yaml:"baseURL"` - UserAgent string - Client *http.Client manifest tools.Manifest mcpManifest tools.McpManifest } @@ -108,7 +95,12 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() projectID, ok := paramsMap["projectId"].(string) if !ok { @@ -119,7 +111,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("query parameter not found or not a string") } - url := fmt.Sprintf("%s/v1/projects/%s/location/global/prometheus/api/v1/query", t.BaseURL, projectID) + url := fmt.Sprintf("%s/v1/projects/%s/location/global/prometheus/api/v1/query", source.BaseURL(), projectID) req, err := http.NewRequest(http.MethodGet, url, nil) if err != nil { @@ -130,9 +122,9 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT q.Add("query", query) req.URL.RawQuery = q.Encode() - req.Header.Set("User-Agent", t.UserAgent) + req.Header.Set("User-Agent", source.UserAgent()) - resp, err := t.Client.Do(req) + resp, err := source.Client().Do(req) if err != nil { return nil, err } @@ -175,10 +167,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return true } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudmonitoring/cloudmonitoring_test.go b/internal/tools/cloudmonitoring/cloudmonitoring_test.go index 4707adafec..51c4d00c21 100644 --- a/internal/tools/cloudmonitoring/cloudmonitoring_test.go +++ b/internal/tools/cloudmonitoring/cloudmonitoring_test.go @@ -81,22 +81,6 @@ func TestInitialize(t *testing.T) { AuthRequired: []string{"google-auth-service"}, }, }, - { - desc: "Error: source not found", - cfg: cloudmonitoring.Config{ - Name: "test-tool", - Source: "non-existent-source", - }, - wantErr: `no source named "non-existent-source" configured`, - }, - { - desc: "Error: incompatible source kind", - cfg: cloudmonitoring.Config{ - Name: "test-tool", - Source: "incompatible-source", - }, - wantErr: "invalid source for \"cloud-monitoring-query-prometheus\" tool", - }, } for _, tc := range testCases { diff --git 
a/internal/tools/cloudsql/cloudsqlcloneinstance/cloudsqlcloneinstance.go b/internal/tools/cloudsql/cloudsqlcloneinstance/cloudsqlcloneinstance.go new file mode 100644 index 0000000000..e8f7431f8b --- /dev/null +++ b/internal/tools/cloudsql/cloudsqlcloneinstance/cloudsqlcloneinstance.go @@ -0,0 +1,210 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cloudsqlcloneinstance + +import ( + "context" + "fmt" + + yaml "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util/parameters" + sqladmin "google.golang.org/api/sqladmin/v1" +) + +const kind string = "cloud-sql-clone-instance" + +func init() { + if !tools.Register(kind, newConfig) { + panic(fmt.Sprintf("tool kind %q already registered", kind)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { + actual := Config{Name: name} + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + return actual, nil +} + +type compatibleSource interface { + GetDefaultProject() string + GetService(context.Context, string) (*sqladmin.Service, error) + UseClientAuthorization() bool +} + +// Config defines the configuration for the clone-instance tool. +type Config struct { + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + Description string `yaml:"description"` + Source string `yaml:"source" validate:"required"` + AuthRequired []string `yaml:"authRequired"` +} + +// validate interface +var _ tools.ToolConfig = Config{} + +// ToolConfigKind returns the kind of the tool. +func (cfg Config) ToolConfigKind() string { + return kind +} + +// Initialize initializes the tool from the configuration. +func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { + rawS, ok := srcs[cfg.Source] + if !ok { + return nil, fmt.Errorf("no source named %q configured", cfg.Source) + } + s, ok := rawS.(compatibleSource) + if !ok { + return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind) + } + + project := s.GetDefaultProject() + var projectParam parameters.Parameter + if project != "" { + projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. 
This is pre-configured; do not ask for it unless the user explicitly provides a different one.") + } else { + projectParam = parameters.NewStringParameter("project", "The project ID") + } + + allParameters := parameters.Parameters{ + projectParam, + parameters.NewStringParameter("sourceInstanceName", "The name of the instance to be cloned."), + parameters.NewStringParameter("destinationInstanceName", "The name of the new instance that will be created by cloning the source instance."), + // point in time, preferred zone and preferred secondary zone are optional + parameters.NewStringParameterWithRequired("pointInTime", "The timestamp in RFC 3339 format to which the source instance should be cloned.", false), + parameters.NewStringParameterWithRequired("preferredZone", "The preferred zone for the new instance.", false), + parameters.NewStringParameterWithRequired("preferredSecondaryZone", "The preferred secondary zone for the new instance.", false), + } + paramManifest := allParameters.Manifest() + + description := cfg.Description + if description == "" { + description = "Clone an existing Cloud SQL instance into a new instance. The clone can be a direct copy of the source instance, or a point-in-time recovery (PITR) clone from a specific timestamp. The call returns a Cloud SQL Operation object. Call the wait_for_operation tool after this, and make sure to use a multiplier of 4 to poll the operation status until it is marked DONE." + } + + mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, allParameters, nil) + + return Tool{ + Config: cfg, + AllParams: allParameters, + manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, + }, nil +} + +// Tool represents the clone-instance tool. +type Tool struct { + Config + AllParams parameters.Parameters `yaml:"allParams"` + manifest tools.Manifest + mcpManifest tools.McpManifest +} + +func (t Tool) ToConfig() tools.ToolConfig { + return t.Config +} + +// Invoke executes the tool's logic. 
+func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() + + project, ok := paramsMap["project"].(string) + if !ok { + return nil, fmt.Errorf("error casting 'project' parameter: %v", paramsMap["project"]) + } + sourceInstanceName, ok := paramsMap["sourceInstanceName"].(string) + if !ok { + return nil, fmt.Errorf("error casting 'sourceInstanceName' parameter: %v", paramsMap["sourceInstanceName"]) + } + destinationInstanceName, ok := paramsMap["destinationInstanceName"].(string) + if !ok { + return nil, fmt.Errorf("error casting 'destinationInstanceName' parameter: %v", paramsMap["destinationInstanceName"]) + } + + cloneContext := &sqladmin.CloneContext{ + DestinationInstanceName: destinationInstanceName, + } + + pointInTime, ok := paramsMap["pointInTime"].(string) + if ok { + cloneContext.PointInTime = pointInTime + } + preferredZone, ok := paramsMap["preferredZone"].(string) + if ok { + cloneContext.PreferredZone = preferredZone + } + preferredSecondaryZone, ok := paramsMap["preferredSecondaryZone"].(string) + if ok { + cloneContext.PreferredSecondaryZone = preferredSecondaryZone + } + + rb := &sqladmin.InstancesCloneRequest{ + CloneContext: cloneContext, + } + + service, err := source.GetService(ctx, string(accessToken)) + if err != nil { + return nil, err + } + + resp, err := service.Instances.Clone(project, sourceInstanceName, rb).Do() + if err != nil { + return nil, fmt.Errorf("error cloning instance: %w", err) + } + + return resp, nil +} + +// ParseParams parses the parameters for the tool. +func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { + return parameters.ParseParams(t.AllParams, data, claims) +} + +// Manifest returns the tool's manifest. +func (t Tool) Manifest() tools.Manifest { + return t.manifest +} + +// McpManifest returns the tool's MCP manifest. +func (t Tool) McpManifest() tools.McpManifest { + return t.mcpManifest +} + +// Authorized checks if the tool is authorized. +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return true +} + +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + + return source.UseClientAuthorization(), nil +} + +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil +} diff --git a/internal/tools/cloudsql/cloudsqlcloneinstance/cloudsqlcloneinstance_test.go b/internal/tools/cloudsql/cloudsqlcloneinstance/cloudsqlcloneinstance_test.go new file mode 100644 index 0000000000..42fb94406d --- /dev/null +++ b/internal/tools/cloudsql/cloudsqlcloneinstance/cloudsqlcloneinstance_test.go @@ -0,0 +1,73 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cloudsqlcloneinstance_test + +import ( + //"context" + "testing" + + yaml "github.com/goccy/go-yaml" + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcloneinstance" +) + +func TestParseFromYaml(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example", + in: ` + tools: + clone-instance-tool: + kind: cloud-sql-clone-instance + description: a test description + source: a-source + `, + want: server.ToolConfigs{ + "clone-instance-tool": cloudsqlcloneinstance.Config{ + Name: "clone-instance-tool", + Kind: "cloud-sql-clone-instance", + Description: "a test description", + Source: "a-source", + AuthRequired: []string{}, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Tools server.ToolConfigs `yaml:"tools"` + }{} + // Parse contents + err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if diff := cmp.Diff(tc.want, got.Tools); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } +} diff --git a/internal/tools/cloudsql/cloudsqlcreatedatabase/cloudsqlcreatedatabase.go b/internal/tools/cloudsql/cloudsqlcreatedatabase/cloudsqlcreatedatabase.go index 08b19e013b..57b4cc06d6 100644 --- a/internal/tools/cloudsql/cloudsqlcreatedatabase/cloudsqlcreatedatabase.go +++ b/internal/tools/cloudsql/cloudsqlcreatedatabase/cloudsqlcreatedatabase.go @@ -20,7 +20,6 @@ import ( "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" sqladmin "google.golang.org/api/sqladmin/v1" @@ -42,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + GetDefaultProject() string + GetService(context.Context, string) (*sqladmin.Service, error) + UseClientAuthorization() bool +} + // Config defines the configuration for the create-database tool. type Config struct { Name string `yaml:"name" validate:"required"` @@ -65,12 +70,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) if !ok { return nil, fmt.Errorf("no source named %q configured", cfg.Source) } - s, ok := rawS.(*cloudsqladmin.Source) + s, ok := rawS.(compatibleSource) if !ok { return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind) } - project := s.DefaultProject + project := s.GetDefaultProject() var projectParam parameters.Parameter if project != "" { projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. 
This is pre-configured; do not ask for it unless the user explicitly provides a different one.") @@ -93,7 +98,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, - Source: s, AllParams: allParameters, manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -103,7 +107,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // Tool represents the create-database tool. type Tool struct { Config - Source *cloudsqladmin.Source AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest @@ -114,7 +117,12 @@ func (t Tool) ToConfig() tools.ToolConfig { } // Invoke executes the tool's logic. -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) @@ -136,7 +144,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT Instance: instance, } - service, err := t.Source.GetService(ctx, string(accessToken)) + service, err := source.GetService(ctx, string(accessToken)) if err != nil { return nil, err } @@ -169,10 +177,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return true } -func (t Tool) RequiresClientAuthorization() bool { - return t.Source.UseClientAuthorization() +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudsql/cloudsqlcreateusers/cloudsqlcreateusers.go b/internal/tools/cloudsql/cloudsqlcreateusers/cloudsqlcreateusers.go index ccc9b8cf58..148ccfeb6c 100644 --- a/internal/tools/cloudsql/cloudsqlcreateusers/cloudsqlcreateusers.go +++ b/internal/tools/cloudsql/cloudsqlcreateusers/cloudsqlcreateusers.go @@ -20,7 +20,6 @@ import ( "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" sqladmin "google.golang.org/api/sqladmin/v1" @@ -42,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + GetDefaultProject() string + GetService(context.Context, string) (*sqladmin.Service, error) + UseClientAuthorization() bool +} + // Config defines the configuration for the create-user tool. 
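Reviewer note: now that the concrete *cloudsqladmin.Source field is gone from Tool, Invoke only needs a value that satisfies the small per-tool compatibleSource interface, which makes it straightforward to unit-test. A minimal sketch of a fake source, assuming nothing beyond the three methods declared in that interface; the fakeAdminSource type, its endpoint field, and the package line are illustrative and not part of this change:

package cloudsqlcreateusers_test

import (
	"context"

	"google.golang.org/api/option"
	sqladmin "google.golang.org/api/sqladmin/v1"
)

// fakeAdminSource satisfies the tool's compatibleSource interface so Invoke can
// be exercised against a local test server instead of the real SQL Admin API.
type fakeAdminSource struct {
	project  string
	endpoint string // e.g. the URL of an httptest.Server
}

func (f *fakeAdminSource) GetDefaultProject() string { return f.project }

func (f *fakeAdminSource) UseClientAuthorization() bool { return false }

func (f *fakeAdminSource) GetService(ctx context.Context, accessToken string) (*sqladmin.Service, error) {
	// The access token is ignored; the client talks to the fake endpoint
	// without credentials.
	return sqladmin.NewService(ctx,
		option.WithoutAuthentication(),
		option.WithEndpoint(f.endpoint),
	)
}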
type Config struct { Name string `yaml:"name" validate:"required"` @@ -65,12 +70,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) if !ok { return nil, fmt.Errorf("no source named %q configured", cfg.Source) } - s, ok := rawS.(*cloudsqladmin.Source) + s, ok := rawS.(compatibleSource) if !ok { return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind) } - project := s.DefaultProject + project := s.GetDefaultProject() var projectParam parameters.Parameter if project != "" { projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.") @@ -95,7 +100,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, - Source: s, AllParams: allParameters, manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -105,7 +109,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // Tool represents the create-user tool. type Tool struct { Config - Source *cloudsqladmin.Source AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest @@ -116,7 +119,12 @@ func (t Tool) ToConfig() tools.ToolConfig { } // Invoke executes the tool's logic. -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) @@ -149,7 +157,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT user.Password = password } - service, err := t.Source.GetService(ctx, string(accessToken)) + service, err := source.GetService(ctx, string(accessToken)) if err != nil { return nil, err } @@ -182,10 +190,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return true } -func (t Tool) RequiresClientAuthorization() bool { - return t.Source.UseClientAuthorization() +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudsql/cloudsqlgetinstances/cloudsqlgetinstances.go b/internal/tools/cloudsql/cloudsqlgetinstances/cloudsqlgetinstances.go index 5fd2b41ef3..1fb40b67bc 100644 --- a/internal/tools/cloudsql/cloudsqlgetinstances/cloudsqlgetinstances.go +++ b/internal/tools/cloudsql/cloudsqlgetinstances/cloudsqlgetinstances.go @@ -20,9 +20,9 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" + 
"google.golang.org/api/sqladmin/v1" ) const kind string = "cloud-sql-get-instance" @@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + GetDefaultProject() string + GetService(context.Context, string) (*sqladmin.Service, error) + UseClientAuthorization() bool +} + // Config defines the configuration for the get-instances tool. type Config struct { Name string `yaml:"name" validate:"required"` @@ -65,12 +71,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return nil, fmt.Errorf("no source named %q configured", cfg.Source) } - s, ok := rawS.(*cloudsqladmin.Source) + s, ok := rawS.(compatibleSource) if !ok { return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind) } - project := s.DefaultProject + project := s.GetDefaultProject() var projectParam parameters.Parameter if project != "" { projectParam = parameters.NewStringParameterWithDefault("projectId", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.") @@ -92,7 +98,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, - Source: s, AllParams: allParameters, manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -102,7 +107,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // Tool represents the get-instances tool. type Tool struct { Config - Source *cloudsqladmin.Source AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest @@ -113,7 +117,12 @@ func (t Tool) ToConfig() tools.ToolConfig { } // Invoke executes the tool's logic. 
-func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() projectId, ok := paramsMap["projectId"].(string) @@ -125,7 +134,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("missing 'instanceId' parameter") } - service, err := t.Source.GetService(ctx, string(accessToken)) + service, err := source.GetService(ctx, string(accessToken)) if err != nil { return nil, err } @@ -158,10 +167,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return true } -func (t Tool) RequiresClientAuthorization() bool { - return t.Source.UseClientAuthorization() +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudsql/cloudsqllistdatabases/cloudsqllistdatabases.go b/internal/tools/cloudsql/cloudsqllistdatabases/cloudsqllistdatabases.go index 03e3763d01..ba54380631 100644 --- a/internal/tools/cloudsql/cloudsqllistdatabases/cloudsqllistdatabases.go +++ b/internal/tools/cloudsql/cloudsqllistdatabases/cloudsqllistdatabases.go @@ -20,9 +20,9 @@ import ( "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - cloudsqladminsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" + "google.golang.org/api/sqladmin/v1" ) const kind string = "cloud-sql-list-databases" @@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + GetDefaultProject() string + GetService(context.Context, string) (*sqladmin.Service, error) + UseClientAuthorization() bool +} + // Config defines the configuration for the list-databases tool. type Config struct { Name string `yaml:"name" validate:"required"` @@ -64,12 +70,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) if !ok { return nil, fmt.Errorf("no source named %q configured", cfg.Source) } - s, ok := rawS.(*cloudsqladminsrc.Source) + s, ok := rawS.(compatibleSource) if !ok { return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind) } - project := s.DefaultProject + project := s.GetDefaultProject() var projectParam parameters.Parameter if project != "" { projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. 
This is pre-configured; do not ask for it unless the user explicitly provides a different one.") @@ -91,7 +97,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, - Source: s, AllParams: allParameters, manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -102,7 +107,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) type Tool struct { Config AllParams parameters.Parameters `yaml:"allParams"` - Source *cloudsqladminsrc.Source manifest tools.Manifest mcpManifest tools.McpManifest } @@ -112,7 +116,12 @@ func (t Tool) ToConfig() tools.ToolConfig { } // Invoke executes the tool's logic. -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) @@ -124,7 +133,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("missing 'instance' parameter") } - service, err := t.Source.GetService(ctx, string(accessToken)) + service, err := source.GetService(ctx, string(accessToken)) if err != nil { return nil, err } @@ -176,10 +185,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return true } -func (t Tool) RequiresClientAuthorization() bool { - return t.Source.UseClientAuthorization() +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudsql/cloudsqllistinstances/cloudsqllistinstances.go b/internal/tools/cloudsql/cloudsqllistinstances/cloudsqllistinstances.go index c06c79f961..11ccd91bad 100644 --- a/internal/tools/cloudsql/cloudsqllistinstances/cloudsqllistinstances.go +++ b/internal/tools/cloudsql/cloudsqllistinstances/cloudsqllistinstances.go @@ -20,9 +20,9 @@ import ( "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - cloudsqladminsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" + "google.golang.org/api/sqladmin/v1" ) const kind string = "cloud-sql-list-instances" @@ -41,6 +41,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + GetDefaultProject() string + GetService(context.Context, string) (*sqladmin.Service, error) + UseClientAuthorization() bool +} + // Config defines the configuration for the list-instance tool. 
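The generic helper these hunks call, tools.GetCompatibleSource, is not part of this excerpt, so the following is only a guess at its behaviour inferred from the call sites: look the source up by name on the provider, assert it to the tool-specific interface, and return descriptive errors. The GetSource accessor on SourceProvider is an assumption made for this sketch:

// NOT the actual tools.GetCompatibleSource; a sketch of what the call sites in
// this diff appear to expect. SourceProvider and its GetSource method are
// assumed to exist in this package; GetSource is assumed to return an interface
// value (e.g. sources.Source) that can be asserted to the tool's interface.
package tools

import "fmt"

func GetCompatibleSource[T any](provider SourceProvider, sourceName, toolName, toolKind string) (T, error) {
	var zero T
	raw, err := provider.GetSource(sourceName) // assumed accessor on SourceProvider
	if err != nil {
		return zero, fmt.Errorf("tool %q: no source named %q configured: %w", toolName, sourceName, err)
	}
	s, ok := raw.(T) // assert the generic source to the tool's compatibleSource interface
	if !ok {
		return zero, fmt.Errorf("tool %q (kind %q): source %q is not a compatible source kind", toolName, toolKind, sourceName)
	}
	return s, nil
}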
type Config struct { Name string `yaml:"name" validate:"required"` @@ -64,12 +70,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) if !ok { return nil, fmt.Errorf("no source named %q configured", cfg.Source) } - s, ok := rawS.(*cloudsqladminsrc.Source) + s, ok := rawS.(compatibleSource) if !ok { return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind) } - project := s.DefaultProject + project := s.GetDefaultProject() var projectParam parameters.Parameter if project != "" { projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.") @@ -90,7 +96,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, - source: s, AllParams: allParameters, manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -101,7 +106,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) type Tool struct { Config AllParams parameters.Parameters `yaml:"allParams"` - source *cloudsqladminsrc.Source manifest tools.Manifest mcpManifest tools.McpManifest } @@ -111,7 +115,12 @@ func (t Tool) ToConfig() tools.ToolConfig { } // Invoke executes the tool's logic. -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) @@ -119,7 +128,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("missing 'project' parameter") } - service, err := t.source.GetService(ctx, string(accessToken)) + service, err := source.GetService(ctx, string(accessToken)) if err != nil { return nil, err } @@ -169,10 +178,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return true } -func (t Tool) RequiresClientAuthorization() bool { - return t.source.UseClientAuthorization() +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudsql/cloudsqlwaitforoperation/cloudsqlwaitforoperation.go b/internal/tools/cloudsql/cloudsqlwaitforoperation/cloudsqlwaitforoperation.go index cbb184fe45..672f999282 100644 --- a/internal/tools/cloudsql/cloudsqlwaitforoperation/cloudsqlwaitforoperation.go +++ b/internal/tools/cloudsql/cloudsqlwaitforoperation/cloudsqlwaitforoperation.go @@ -25,9 +25,9 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin" "github.com/googleapis/genai-toolbox/internal/tools" 
"github.com/googleapis/genai-toolbox/internal/util/parameters" + "google.golang.org/api/sqladmin/v1" ) const kind string = "cloud-sql-wait-for-operation" @@ -87,6 +87,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + GetDefaultProject() string + GetService(context.Context, string) (*sqladmin.Service, error) + UseClientAuthorization() bool +} + // Config defines the configuration for the wait-for-operation tool. type Config struct { Name string `yaml:"name" validate:"required"` @@ -118,12 +124,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return nil, fmt.Errorf("no source named %q configured", cfg.Source) } - s, ok := rawS.(*cloudsqladmin.Source) + s, ok := rawS.(compatibleSource) if !ok { return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind) } - project := s.DefaultProject + project := s.GetDefaultProject() var projectParam parameters.Parameter if project != "" { projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.") @@ -177,7 +183,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, - Source: s, AllParams: allParameters, manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -191,17 +196,15 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // Tool represents the wait-for-operation tool. type Tool struct { Config - Source *cloudsqladmin.Source - AllParams parameters.Parameters `yaml:"allParams"` + AllParams parameters.Parameters `yaml:"allParams"` + manifest tools.Manifest + mcpManifest tools.McpManifest // Polling configuration Delay time.Duration MaxDelay time.Duration Multiplier float64 MaxRetries int - - manifest tools.Manifest - mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -209,7 +212,12 @@ func (t Tool) ToConfig() tools.ToolConfig { } // Invoke executes the tool's logic. 
-func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) @@ -221,7 +229,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("missing 'operation' parameter") } - service, err := t.Source.GetService(ctx, string(accessToken)) + service, err := source.GetService(ctx, string(accessToken)) if err != nil { return nil, err } @@ -267,7 +275,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("could not unmarshal operation: %w", err) } - if msg, ok := t.generateCloudSQLConnectionMessage(data); ok { + if msg, ok := t.generateCloudSQLConnectionMessage(source, data); ok { return msg, nil } return string(opBytes), nil @@ -305,11 +313,15 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return true } -func (t Tool) RequiresClientAuthorization() bool { - return t.Source.UseClientAuthorization() +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) generateCloudSQLConnectionMessage(opResponse map[string]any) (string, bool) { +func (t Tool) generateCloudSQLConnectionMessage(source compatibleSource, opResponse map[string]any) (string, bool) { operationType, ok := opResponse["operationType"].(string) if !ok || operationType != "CREATE_DATABASE" { return "", false @@ -329,7 +341,7 @@ func (t Tool) generateCloudSQLConnectionMessage(opResponse map[string]any) (stri instance := matches[2] database := matches[3] - instanceData, err := t.fetchInstanceData(context.Background(), project, instance) + instanceData, err := t.fetchInstanceData(context.Background(), source, project, instance) if err != nil { fmt.Printf("error fetching instance data: %v\n", err) return "", false @@ -385,8 +397,8 @@ func (t Tool) generateCloudSQLConnectionMessage(opResponse map[string]any) (stri return b.String(), true } -func (t Tool) fetchInstanceData(ctx context.Context, project, instance string) (map[string]any, error) { - service, err := t.Source.GetService(ctx, "") +func (t Tool) fetchInstanceData(ctx context.Context, source compatibleSource, project, instance string) (map[string]any, error) { + service, err := source.GetService(ctx, "") if err != nil { return nil, err } @@ -408,6 +420,6 @@ func (t Tool) fetchInstanceData(ctx context.Context, project, instance string) ( return data, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudsqlmssql/cloudsqlmssqlcreateinstance/cloudsqlmssqlcreateinstance.go b/internal/tools/cloudsqlmssql/cloudsqlmssqlcreateinstance/cloudsqlmssqlcreateinstance.go index 89e75c34b4..78bc77d6fa 100644 --- a/internal/tools/cloudsqlmssql/cloudsqlmssqlcreateinstance/cloudsqlmssqlcreateinstance.go +++ 
b/internal/tools/cloudsqlmssql/cloudsqlmssqlcreateinstance/cloudsqlmssqlcreateinstance.go @@ -21,7 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" sqladmin "google.golang.org/api/sqladmin/v1" @@ -43,6 +42,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + GetDefaultProject() string + GetService(context.Context, string) (*sqladmin.Service, error) + UseClientAuthorization() bool +} + // Config defines the configuration for the create-instances tool. type Config struct { Name string `yaml:"name" validate:"required"` @@ -66,12 +71,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) if !ok { return nil, fmt.Errorf("no source named %q configured", cfg.Source) } - s, ok := rawS.(*cloudsqladmin.Source) + s, ok := rawS.(compatibleSource) if !ok { return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind) } - project := s.DefaultProject + project := s.GetDefaultProject() var projectParam parameters.Parameter if project != "" { projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.") @@ -96,7 +101,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, - Source: s, AllParams: allParameters, manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -106,7 +110,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // Tool represents the create-instances tool. type Tool struct { Config - Source *cloudsqladmin.Source AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest @@ -117,7 +120,12 @@ func (t Tool) ToConfig() tools.ToolConfig { } // Invoke executes the tool's logic. 
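For context on what the create-instance tools in these packages ultimately do: they assemble a sqladmin.DatabaseInstance and submit it with Instances.Insert, then hand the returned operation to the wait-for-operation tool. A standalone sketch; the region, database version, tier and password values are placeholders, not defaults taken from this change:

package example

import (
	"context"

	sqladmin "google.golang.org/api/sqladmin/v1"
)

// createInstance issues the kind of request the create-instance tools build.
// All field values here are illustrative placeholders.
func createInstance(ctx context.Context, svc *sqladmin.Service, project, name string) (*sqladmin.Operation, error) {
	instance := &sqladmin.DatabaseInstance{
		Name:            name,
		Project:         project,
		Region:          "us-central1",
		DatabaseVersion: "SQLSERVER_2022_STANDARD",
		RootPassword:    "change-me", // SQL Server instances require a root password
		Settings: &sqladmin.Settings{
			Tier: "db-custom-2-8192",
		},
	}
	return svc.Instances.Insert(project, instance).Context(ctx).Do()
}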
-func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) @@ -167,7 +175,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT Project: project, } - service, err := t.Source.GetService(ctx, string(accessToken)) + service, err := source.GetService(ctx, string(accessToken)) if err != nil { return nil, err } @@ -200,10 +208,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return true } -func (t Tool) RequiresClientAuthorization() bool { - return t.Source.UseClientAuthorization() +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudsqlmysql/cloudsqlmysqlcreateinstance/cloudsqlmysqlcreateinstance.go b/internal/tools/cloudsqlmysql/cloudsqlmysqlcreateinstance/cloudsqlmysqlcreateinstance.go index e4fa0c7c5b..165a057c35 100644 --- a/internal/tools/cloudsqlmysql/cloudsqlmysqlcreateinstance/cloudsqlmysqlcreateinstance.go +++ b/internal/tools/cloudsqlmysql/cloudsqlmysqlcreateinstance/cloudsqlmysqlcreateinstance.go @@ -21,7 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" sqladmin "google.golang.org/api/sqladmin/v1" @@ -43,6 +42,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + GetDefaultProject() string + GetService(context.Context, string) (*sqladmin.Service, error) + UseClientAuthorization() bool +} + // Config defines the configuration for the create-instances tool. type Config struct { Name string `yaml:"name" validate:"required"` @@ -66,12 +71,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) if !ok { return nil, fmt.Errorf("no source named %q configured", cfg.Source) } - s, ok := rawS.(*cloudsqladmin.Source) + s, ok := rawS.(compatibleSource) if !ok { return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind) } - project := s.DefaultProject + project := s.GetDefaultProject() var projectParam parameters.Parameter if project != "" { projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. 
This is pre-configured; do not ask for it unless the user explicitly provides a different one.") @@ -96,7 +101,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, - Source: s, AllParams: allParameters, manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -106,7 +110,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // Tool represents the create-instances tool. type Tool struct { Config - Source *cloudsqladmin.Source AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest @@ -117,7 +120,12 @@ func (t Tool) ToConfig() tools.ToolConfig { } // Invoke executes the tool's logic. -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) @@ -167,7 +175,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT Project: project, } - service, err := t.Source.GetService(ctx, string(accessToken)) + service, err := source.GetService(ctx, string(accessToken)) if err != nil { return nil, err } @@ -200,10 +208,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return true } -func (t Tool) RequiresClientAuthorization() bool { - return t.Source.UseClientAuthorization() +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudsqlpg/cloudsqlpgcreateinstances/cloudsqlpgcreateinstances.go b/internal/tools/cloudsqlpg/cloudsqlpgcreateinstances/cloudsqlpgcreateinstances.go index 2f56d032e2..224cc3700c 100644 --- a/internal/tools/cloudsqlpg/cloudsqlpgcreateinstances/cloudsqlpgcreateinstances.go +++ b/internal/tools/cloudsqlpg/cloudsqlpgcreateinstances/cloudsqlpgcreateinstances.go @@ -21,7 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" sqladmin "google.golang.org/api/sqladmin/v1" @@ -43,6 +42,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + GetDefaultProject() string + GetService(context.Context, string) (*sqladmin.Service, error) + UseClientAuthorization() bool +} + // Config defines the configuration for the create-instances tool. 
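One knock-on effect of the signature change repeated above: RequiresClientAuthorization and GetAuthTokenHeaderName now return an error, because the source is resolved lazily and that resolution can fail. A hedged sketch of how a caller inside this repository might handle the new signatures; the forwardClientAuth helper and the header plumbing are illustrative, not the server's actual code:

package example

import (
	"fmt"
	"net/http"
	"strings"

	"github.com/googleapis/genai-toolbox/internal/tools"
)

// forwardClientAuth decides whether the end user's token must be forwarded to
// the tool and, if so, extracts it from the incoming request headers.
func forwardClientAuth(t tools.Tool, resources tools.SourceProvider, incoming http.Header) (tools.AccessToken, error) {
	required, err := t.RequiresClientAuthorization(resources)
	if err != nil {
		return "", err
	}
	if !required {
		return "", nil
	}
	headerName, err := t.GetAuthTokenHeaderName(resources)
	if err != nil {
		return "", err
	}
	token := strings.TrimPrefix(incoming.Get(headerName), "Bearer ")
	if token == "" {
		return "", fmt.Errorf("tool requires client authorization but the %q header is empty", headerName)
	}
	return tools.AccessToken(token), nil
}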
type Config struct { Name string `yaml:"name" validate:"required"` @@ -66,12 +71,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) if !ok { return nil, fmt.Errorf("no source named %q configured", cfg.Source) } - s, ok := rawS.(*cloudsqladmin.Source) + s, ok := rawS.(compatibleSource) if !ok { return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind) } - project := s.DefaultProject + project := s.GetDefaultProject() var projectParam parameters.Parameter if project != "" { projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.") @@ -96,7 +101,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, - Source: s, AllParams: allParameters, manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -106,7 +110,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // Tool represents the create-instances tool. type Tool struct { Config - Source *cloudsqladmin.Source AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest @@ -117,7 +120,12 @@ func (t Tool) ToConfig() tools.ToolConfig { } // Invoke executes the tool's logic. -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) @@ -167,7 +175,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT Project: project, } - service, err := t.Source.GetService(ctx, string(accessToken)) + service, err := source.GetService(ctx, string(accessToken)) if err != nil { return nil, err } @@ -200,10 +208,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return true } -func (t Tool) RequiresClientAuthorization() bool { - return t.Source.UseClientAuthorization() +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/cloudsqlpg/cloudsqlpgupgradeprecheck/cloudsqlpgupgradeprecheck.go b/internal/tools/cloudsqlpg/cloudsqlpgupgradeprecheck/cloudsqlpgupgradeprecheck.go index 1491cf86f5..156d648e93 100644 --- a/internal/tools/cloudsqlpg/cloudsqlpgupgradeprecheck/cloudsqlpgupgradeprecheck.go +++ b/internal/tools/cloudsqlpg/cloudsqlpgupgradeprecheck/cloudsqlpgupgradeprecheck.go @@ -21,7 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin" "github.com/googleapis/genai-toolbox/internal/tools" 
"github.com/googleapis/genai-toolbox/internal/util/parameters" sqladmin "google.golang.org/api/sqladmin/v1" @@ -43,6 +42,11 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + GetService(context.Context, string) (*sqladmin.Service, error) + UseClientAuthorization() bool +} + // Config defines the configuration for the precheck-upgrade tool. type Config struct { Name string `yaml:"name" validate:"required"` @@ -62,15 +66,6 @@ func (cfg Config) ToolConfigKind() string { // Initialize initializes the tool from the configuration. func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - s, ok := rawS.(*cloudsqladmin.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-sql-admin`", kind) - } - allParameters := parameters.Parameters{ parameters.NewStringParameter("project", "The project ID"), parameters.NewStringParameter("instance", "The name of the instance to check"), @@ -88,28 +83,19 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, allParameters, nil) return Tool{ - Name: cfg.Name, - Kind: kind, - AuthRequired: cfg.AuthRequired, - Source: s, - AllParams: allParameters, - manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + AllParams: allParameters, + manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, }, nil } // Tool represents the precheck-upgrade tool. type Tool struct { - Name string `yaml:"name"` - Kind string `yaml:"kind"` - Description string `yaml:"description"` - AuthRequired []string `yaml:"authRequired"` - - Source *cloudsqladmin.Source + Config AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest - Config } // PreCheckResultItem holds the details of a single check result. @@ -145,7 +131,12 @@ func (t Tool) ToConfig() tools.ToolConfig { } // Invoke executes the tool's logic. 
-func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() project, ok := paramsMap["project"].(string) @@ -162,7 +153,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("missing or empty 'targetDatabaseVersion' parameter") } - service, err := t.Source.GetService(ctx, string(accessToken)) + service, err := source.GetService(ctx, string(accessToken)) if err != nil { return nil, fmt.Errorf("failed to get HTTP client from source: %w", err) } @@ -234,10 +225,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return true } -func (t Tool) RequiresClientAuthorization() bool { - return t.Source.UseClientAuthorization() +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/couchbase/couchbase.go b/internal/tools/couchbase/couchbase.go index 402948ebb2..481c9f6b22 100644 --- a/internal/tools/couchbase/couchbase.go +++ b/internal/tools/couchbase/couchbase.go @@ -22,7 +22,6 @@ import ( "github.com/couchbase/gocb/v2" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/couchbase" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -48,11 +47,6 @@ type compatibleSource interface { CouchbaseQueryScanConsistency() uint } -// validate compatible sources are still compatible -var _ compatibleSource = &couchbase.Source{} - -var compatibleSources = [...]string{couchbase.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -72,18 +66,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters) if err != nil { return nil, err @@ -92,12 +74,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) // finish tool setup t := Tool{ - Config: cfg, - AllParams: allParameters, - Scope: s.CouchbaseScope(), - QueryScanConsistency: s.CouchbaseQueryScanConsistency(), - manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, 
AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + AllParams: allParameters, + manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -107,19 +87,21 @@ var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters `yaml:"allParams"` - - Scope *gocb.Scope - QueryScanConsistency uint - manifest tools.Manifest - mcpManifest tools.McpManifest + AllParams parameters.Parameters `yaml:"allParams"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + namedParamsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, namedParamsMap) if err != nil { @@ -130,8 +112,8 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT if err != nil { return nil, fmt.Errorf("unable to extract standard params %w", err) } - results, err := t.Scope.Query(newStatement, &gocb.QueryOptions{ - ScanConsistency: gocb.QueryScanConsistency(t.QueryScanConsistency), + results, err := source.CouchbaseScope().Query(newStatement, &gocb.QueryOptions{ + ScanConsistency: gocb.QueryScanConsistency(source.CouchbaseQueryScanConsistency()), NamedParameters: newParams.AsMap(), }) if err != nil { @@ -166,10 +148,10 @@ func (t Tool) Authorized(verifiedAuthSources []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthSources) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/dataform/dataformcompilelocal/dataformcompilelocal.go b/internal/tools/dataform/dataformcompilelocal/dataformcompilelocal.go index fbab4bed08..daf6d4f29d 100644 --- a/internal/tools/dataform/dataformcompilelocal/dataformcompilelocal.go +++ b/internal/tools/dataform/dataformcompilelocal/dataformcompilelocal.go @@ -85,7 +85,7 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { paramsMap := params.AsMap() projectDir, ok := paramsMap["project_dir"].(string) @@ -118,10 +118,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr 
tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/dataplex/dataplexlookupentry/dataplexlookupentry.go b/internal/tools/dataplex/dataplexlookupentry/dataplexlookupentry.go index 4c8b8e09a6..78915c7b96 100644 --- a/internal/tools/dataplex/dataplexlookupentry/dataplexlookupentry.go +++ b/internal/tools/dataplex/dataplexlookupentry/dataplexlookupentry.go @@ -22,7 +22,6 @@ import ( dataplexpb "cloud.google.com/go/dataplex/apiv1/dataplexpb" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - dataplexds "github.com/googleapis/genai-toolbox/internal/sources/dataplex" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -47,11 +46,6 @@ type compatibleSource interface { CatalogClient() *dataplexapi.CatalogClient } -// validate compatible sources are still compatible -var _ compatibleSource = &dataplexds.Source{} - -var compatibleSources = [...]string{dataplexds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -69,17 +63,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // Initialize the search configuration with the provided sources - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - viewDesc := ` ## Argument: view @@ -104,9 +87,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil) t := Tool{ - Config: cfg, - Parameters: params, - CatalogClient: s.CatalogClient(), + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -119,17 +101,21 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) type Tool struct { Config - Parameters parameters.Parameters - CatalogClient *dataplexapi.CatalogClient - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() viewMap := map[int]dataplexpb.EntryView{ 1: dataplexpb.EntryView_BASIC, @@ -153,7 +139,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT Entry: entry, } - result, err := t.CatalogClient.LookupEntry(ctx, req) + result, err := source.CatalogClient().LookupEntry(ctx, req) if err != nil { return nil, err } @@ -179,10 +165,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) 
RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/dataplex/dataplexsearchaspecttypes/dataplexsearchaspecttypes.go b/internal/tools/dataplex/dataplexsearchaspecttypes/dataplexsearchaspecttypes.go index f33215a12e..37f44cf9ea 100644 --- a/internal/tools/dataplex/dataplexsearchaspecttypes/dataplexsearchaspecttypes.go +++ b/internal/tools/dataplex/dataplexsearchaspecttypes/dataplexsearchaspecttypes.go @@ -23,7 +23,6 @@ import ( "github.com/cenkalti/backoff/v5" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - dataplexds "github.com/googleapis/genai-toolbox/internal/sources/dataplex" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -49,11 +48,6 @@ type compatibleSource interface { ProjectID() string } -// validate compatible sources are still compatible -var _ compatibleSource = &dataplexds.Source{} - -var compatibleSources = [...]string{dataplexds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -70,17 +64,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // Initialize the search configuration with the provided sources - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - query := parameters.NewStringParameter("query", "The query against which aspect type should be matched.") pageSize := parameters.NewIntParameterWithDefault("pageSize", 5, "Number of returned aspect types in the search page.") orderBy := parameters.NewStringParameterWithDefault("orderBy", "relevance", "Specifies the ordering of results. 
Supported values are: relevance, last_modified_timestamp, last_modified_timestamp asc") @@ -89,10 +72,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil) t := Tool{ - Config: cfg, - Parameters: params, - CatalogClient: s.CatalogClient(), - ProjectID: s.ProjectID(), + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -105,18 +86,21 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) type Tool struct { Config - Parameters parameters.Parameters - CatalogClient *dataplexapi.CatalogClient - ProjectID string - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + // Invoke the tool with the provided parameters paramsMap := params.AsMap() query, _ := paramsMap["query"].(string) @@ -126,16 +110,16 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT // Create SearchEntriesRequest with the provided parameters req := &dataplexpb.SearchEntriesRequest{ Query: query + " type=projects/dataplex-types/locations/global/entryTypes/aspecttype", - Name: fmt.Sprintf("projects/%s/locations/global", t.ProjectID), + Name: fmt.Sprintf("projects/%s/locations/global", source.ProjectID()), PageSize: pageSize, OrderBy: orderBy, SemanticSearch: true, } // Perform the search using the CatalogClient - this will return an iterator - it := t.CatalogClient.SearchEntries(ctx, req) + it := source.CatalogClient().SearchEntries(ctx, req) if it == nil { - return nil, fmt.Errorf("failed to create search entries iterator for project %q", t.ProjectID) + return nil, fmt.Errorf("failed to create search entries iterator for project %q", source.ProjectID()) } // Create an instance of exponential backoff with default values for retrying GetAspectType calls @@ -155,7 +139,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } operation := func() (*dataplexpb.AspectType, error) { - aspectType, err := t.CatalogClient.GetAspectType(ctx, getAspectTypeReq) + aspectType, err := source.CatalogClient().GetAspectType(ctx, getAspectTypeReq) if err != nil { return nil, fmt.Errorf("failed to get aspect type for entry %q: %w", resourceName, err) } @@ -192,10 +176,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/dataplex/dataplexsearchentries/dataplexsearchentries.go 
b/internal/tools/dataplex/dataplexsearchentries/dataplexsearchentries.go index c4c04e406e..76c3208bbf 100644 --- a/internal/tools/dataplex/dataplexsearchentries/dataplexsearchentries.go +++ b/internal/tools/dataplex/dataplexsearchentries/dataplexsearchentries.go @@ -22,7 +22,6 @@ import ( dataplexpb "cloud.google.com/go/dataplex/apiv1/dataplexpb" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - dataplexds "github.com/googleapis/genai-toolbox/internal/sources/dataplex" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -48,11 +47,6 @@ type compatibleSource interface { ProjectID() string } -// validate compatible sources are still compatible -var _ compatibleSource = &dataplexds.Source{} - -var compatibleSources = [...]string{dataplexds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -69,17 +63,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // Initialize the search configuration with the provided sources - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - query := parameters.NewStringParameter("query", "The query against which entries in scope should be matched.") pageSize := parameters.NewIntParameterWithDefault("pageSize", 5, "Number of results in the search page.") orderBy := parameters.NewStringParameterWithDefault("orderBy", "relevance", "Specifies the ordering of results. 
Supported values are: relevance, last_modified_timestamp, last_modified_timestamp asc") @@ -88,10 +71,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil) t := Tool{ - Config: cfg, - Parameters: params, - CatalogClient: s.CatalogClient(), - ProjectID: s.ProjectID(), + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -104,18 +85,21 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) type Tool struct { Config - Parameters parameters.Parameters - CatalogClient *dataplexapi.CatalogClient - ProjectID string - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() query, _ := paramsMap["query"].(string) pageSize := int32(paramsMap["pageSize"].(int)) @@ -123,15 +107,15 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT req := &dataplexpb.SearchEntriesRequest{ Query: query, - Name: fmt.Sprintf("projects/%s/locations/global", t.ProjectID), + Name: fmt.Sprintf("projects/%s/locations/global", source.ProjectID()), PageSize: pageSize, OrderBy: orderBy, SemanticSearch: true, } - it := t.CatalogClient.SearchEntries(ctx, req) + it := source.CatalogClient().SearchEntries(ctx, req) if it == nil { - return nil, fmt.Errorf("failed to create search entries iterator for project %q", t.ProjectID) + return nil, fmt.Errorf("failed to create search entries iterator for project %q", source.ProjectID()) } var results []*dataplexpb.SearchEntriesResult @@ -163,10 +147,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/dgraph/dgraph.go b/internal/tools/dgraph/dgraph.go index 4a9b5a4fd0..beef9f86a5 100644 --- a/internal/tools/dgraph/dgraph.go +++ b/internal/tools/dgraph/dgraph.go @@ -46,11 +46,6 @@ type compatibleSource interface { DgraphClient() *dgraph.DgraphClient } -// validate compatible sources are still compatible -var _ compatibleSource = &dgraph.Source{} - -var compatibleSources = [...]string{dgraph.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -71,26 +66,13 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named 
%q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, cfg.Parameters, nil) // finish tool setup t := Tool{ - Config: cfg, - DgraphClient: s.DgraphClient(), - manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -100,19 +82,23 @@ var _ tools.Tool = Tool{} type Tool struct { Config - DgraphClient *dgraph.DgraphClient - manifest tools.Manifest - mcpManifest tools.McpManifest + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMapWithDollarPrefix() - resp, err := t.DgraphClient.ExecuteQuery(t.Statement, paramsMap, t.IsQuery, t.Timeout) + resp, err := source.DgraphClient().ExecuteQuery(t.Statement, paramsMap, t.IsQuery, t.Timeout) if err != nil { return nil, err } @@ -148,10 +134,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/elasticsearch/elasticsearchesql/elasticsearchesql.go b/internal/tools/elasticsearch/elasticsearchesql/elasticsearchesql.go index 2b898da4e1..d7cbb35722 100644 --- a/internal/tools/elasticsearch/elasticsearchesql/elasticsearchesql.go +++ b/internal/tools/elasticsearch/elasticsearchesql/elasticsearchesql.go @@ -43,10 +43,6 @@ type compatibleSource interface { ElasticsearchClient() es.EsClient } -var _ compatibleSource = &es.Source{} - -var compatibleSources = [...]string{es.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -77,29 +73,15 @@ type Tool struct { Config manifest tools.Manifest mcpManifest tools.McpManifest - EsClient es.EsClient } var _ tools.Tool = Tool{} func (c Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - src, ok := srcs[c.Source] - if !ok { - return nil, fmt.Errorf("source %q not found", c.Source) - } - - // verify the source is compatible - s, ok := src.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - mcpManifest := tools.GetMcpManifest(c.Name, c.Description, c.AuthRequired, c.Parameters, nil) return 
Tool{ Config: c, - EsClient: s.ElasticsearchClient(), manifest: tools.Manifest{Description: c.Description, Parameters: c.Parameters.Manifest(), AuthRequired: c.AuthRequired}, mcpManifest: mcpManifest, }, nil @@ -119,7 +101,12 @@ type esqlResult struct { Values [][]any `json:"values"` } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + var cancel context.CancelFunc if t.Timeout > 0 { ctx, cancel = context.WithTimeout(ctx, time.Duration(t.Timeout)*time.Second) @@ -164,8 +151,8 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT Body: bytes.NewReader(body), Format: t.Format, FilterPath: []string{"columns", "values"}, - Instrument: t.EsClient.InstrumentationEnabled(), - }.Do(ctx, t.EsClient) + Instrument: source.ElasticsearchClient().InstrumentationEnabled(), + }.Do(ctx, source.ElasticsearchClient()) if err != nil { return nil, err @@ -230,10 +217,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/firebird/firebirdexecutesql/firebirdexecutesql.go b/internal/tools/firebird/firebirdexecutesql/firebirdexecutesql.go index 1e2531289f..28c8d0fb63 100644 --- a/internal/tools/firebird/firebirdexecutesql/firebirdexecutesql.go +++ b/internal/tools/firebird/firebirdexecutesql/firebirdexecutesql.go @@ -21,7 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/firebird" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -47,10 +46,6 @@ type compatibleSource interface { FirebirdDB() *sql.DB } -var _ compatibleSource = &firebird.Source{} - -var compatibleSources = [...]string{firebird.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -66,16 +61,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - sqlParameter := parameters.NewStringParameter("sql", "The sql to execute.") params := parameters.Parameters{sqlParameter} @@ -84,7 +69,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, Parameters: params, - Db: s.FirebirdDB(), manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: 
cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -95,9 +79,7 @@ var _ tools.Tool = Tool{} type Tool struct { Config - Parameters parameters.Parameters `yaml:"parameters"` - - Db *sql.DB + Parameters parameters.Parameters `yaml:"parameters"` manifest tools.Manifest mcpManifest tools.McpManifest } @@ -106,7 +88,12 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() sql, ok := paramsMap["sql"].(string) if !ok { @@ -120,7 +107,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sql)) - rows, err := t.Db.QueryContext(ctx, sql) + rows, err := source.FirebirdDB().QueryContext(ctx, sql) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -180,10 +167,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/firebird/firebirdsql/firebirdsql.go b/internal/tools/firebird/firebirdsql/firebirdsql.go index 07604840cf..9dd040dcd7 100644 --- a/internal/tools/firebird/firebirdsql/firebirdsql.go +++ b/internal/tools/firebird/firebirdsql/firebirdsql.go @@ -22,7 +22,6 @@ import ( "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/firebird" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -47,11 +46,6 @@ type compatibleSource interface { FirebirdDB() *sql.DB } -// validate compatible sources are still compatible -var _ compatibleSource = &firebird.Source{} - -var compatibleSources = [...]string{firebird.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -71,18 +65,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters) if err != nil { return nil, err @@ -94,7 +76,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, AllParams: allParameters, - Db: s.FirebirdDB(), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, 
AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -106,9 +87,7 @@ var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters `yaml:"allParams"` - - Db *sql.DB + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } @@ -117,7 +96,12 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() statement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { @@ -142,7 +126,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } } - rows, err := t.Db.QueryContext(ctx, statement, namedArgs...) + rows, err := source.FirebirdDB().QueryContext(ctx, statement, namedArgs...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -204,10 +188,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/firestore/firestoreadddocuments/firestoreadddocuments.go b/internal/tools/firestore/firestoreadddocuments/firestoreadddocuments.go index f4f2efd589..a1cf8b5bd8 100644 --- a/internal/tools/firestore/firestoreadddocuments/firestoreadddocuments.go +++ b/internal/tools/firestore/firestoreadddocuments/firestoreadddocuments.go @@ -21,7 +21,6 @@ import ( firestoreapi "cloud.google.com/go/firestore" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/firestore/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -50,11 +49,6 @@ type compatibleSource interface { FirestoreClient() *firestoreapi.Client } -// validate compatible sources are still compatible -var _ compatibleSource = &firestoreds.Source{} - -var compatibleSources = [...]string{firestoreds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -71,18 +65,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - // Create parameters collectionPathParameter := parameters.NewStringParameter( collectionPathKey, @@ -124,7 +106,6 
@@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, Parameters: params, - Client: s.FirestoreClient(), manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -136,9 +117,7 @@ var _ tools.Tool = Tool{} type Tool struct { Config - Parameters parameters.Parameters `yaml:"parameters"` - - Client *firestoreapi.Client + Parameters parameters.Parameters `yaml:"parameters"` manifest tools.Manifest mcpManifest tools.McpManifest } @@ -147,7 +126,12 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + mapParams := params.AsMap() // Get collection path @@ -169,7 +153,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT // Convert the document data from JSON format to Firestore format // The client is passed to handle referenceValue types - documentData, err := util.JSONToFirestoreValue(documentDataRaw, t.Client) + documentData, err := util.JSONToFirestoreValue(documentDataRaw, source.FirestoreClient()) if err != nil { return nil, fmt.Errorf("failed to convert document data: %w", err) } @@ -181,7 +165,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } // Get the collection reference - collection := t.Client.Collection(collectionPath) + collection := source.FirestoreClient().Collection(collectionPath) // Add the document to the collection docRef, writeResult, err := collection.Add(ctx, documentData) @@ -221,10 +205,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/firestore/firestoredeletedocuments/firestoredeletedocuments.go b/internal/tools/firestore/firestoredeletedocuments/firestoredeletedocuments.go index 65ad5bbf96..00dfffccd3 100644 --- a/internal/tools/firestore/firestoredeletedocuments/firestoredeletedocuments.go +++ b/internal/tools/firestore/firestoredeletedocuments/firestoredeletedocuments.go @@ -21,7 +21,6 @@ import ( firestoreapi "cloud.google.com/go/firestore" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/firestore/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -48,11 +47,6 @@ type compatibleSource interface { FirestoreClient() *firestoreapi.Client } -// validate compatible sources are still compatible -var _ compatibleSource = &firestoreds.Source{} - -var compatibleSources = [...]string{firestoreds.SourceKind} - type Config 
struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -69,18 +63,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - documentPathsParameter := parameters.NewArrayParameter(documentPathsKey, "Array of relative document paths to delete from Firestore (e.g., 'users/userId' or 'users/userId/posts/postId'). Note: These are relative paths, NOT absolute paths like 'projects/{project_id}/databases/{database_id}/documents/...'", parameters.NewStringParameter("item", "Relative document path")) params := parameters.Parameters{documentPathsParameter} @@ -90,7 +72,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, Parameters: params, - Client: s.FirestoreClient(), manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -102,9 +83,7 @@ var _ tools.Tool = Tool{} type Tool struct { Config - Parameters parameters.Parameters `yaml:"parameters"` - - Client *firestoreapi.Client + Parameters parameters.Parameters `yaml:"parameters"` manifest tools.Manifest mcpManifest tools.McpManifest } @@ -113,7 +92,12 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + mapParams := params.AsMap() documentPathsRaw, ok := mapParams[documentPathsKey].([]any) if !ok { @@ -143,14 +127,14 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } // Create a BulkWriter to handle multiple deletions efficiently - bulkWriter := t.Client.BulkWriter(ctx) + bulkWriter := source.FirestoreClient().BulkWriter(ctx) // Keep track of jobs for each document jobs := make([]*firestoreapi.BulkWriterJob, len(documentPaths)) // Add all delete operations to the BulkWriter for i, path := range documentPaths { - docRef := t.Client.Doc(path) + docRef := source.FirestoreClient().Doc(path) job, err := bulkWriter.Delete(docRef) if err != nil { return nil, fmt.Errorf("failed to add delete operation for document %q: %w", path, err) @@ -198,10 +182,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/firestore/firestoregetdocuments/firestoregetdocuments.go 
b/internal/tools/firestore/firestoregetdocuments/firestoregetdocuments.go index 6b59c13e18..9b8c253f5e 100644 --- a/internal/tools/firestore/firestoregetdocuments/firestoregetdocuments.go +++ b/internal/tools/firestore/firestoregetdocuments/firestoregetdocuments.go @@ -21,7 +21,6 @@ import ( firestoreapi "cloud.google.com/go/firestore" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/firestore/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -48,11 +47,6 @@ type compatibleSource interface { FirestoreClient() *firestoreapi.Client } -// validate compatible sources are still compatible -var _ compatibleSource = &firestoreds.Source{} - -var compatibleSources = [...]string{firestoreds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -69,18 +63,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - documentPathsParameter := parameters.NewArrayParameter(documentPathsKey, "Array of relative document paths to retrieve from Firestore (e.g., 'users/userId' or 'users/userId/posts/postId'). Note: These are relative paths, NOT absolute paths like 'projects/{project_id}/databases/{database_id}/documents/...'", parameters.NewStringParameter("item", "Relative document path")) params := parameters.Parameters{documentPathsParameter} @@ -90,7 +72,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, Parameters: params, - Client: s.FirestoreClient(), manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -102,9 +83,7 @@ var _ tools.Tool = Tool{} type Tool struct { Config - Parameters parameters.Parameters `yaml:"parameters"` - - Client *firestoreapi.Client + Parameters parameters.Parameters `yaml:"parameters"` manifest tools.Manifest mcpManifest tools.McpManifest } @@ -113,7 +92,12 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + mapParams := params.AsMap() documentPathsRaw, ok := mapParams[documentPathsKey].([]any) if !ok { @@ -145,11 +129,11 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT // Create document references from paths docRefs := make([]*firestoreapi.DocumentRef, len(documentPaths)) for i, path := range documentPaths { - docRefs[i] = t.Client.Doc(path) + docRefs[i] = source.FirestoreClient().Doc(path) } // Get all documents - snapshots, err := 
t.Client.GetAll(ctx, docRefs) + snapshots, err := source.FirestoreClient().GetAll(ctx, docRefs) if err != nil { return nil, fmt.Errorf("failed to get documents: %w", err) } @@ -190,10 +174,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/firestore/firestoregetrules/firestoregetrules.go b/internal/tools/firestore/firestoregetrules/firestoregetrules.go index 0cb37b9801..b05f6ff878 100644 --- a/internal/tools/firestore/firestoregetrules/firestoregetrules.go +++ b/internal/tools/firestore/firestoregetrules/firestoregetrules.go @@ -20,7 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "google.golang.org/api/firebaserules/v1" @@ -48,11 +47,6 @@ type compatibleSource interface { GetDatabaseId() string } -// validate compatible sources are still compatible -var _ compatibleSource = &firestoreds.Source{} - -var compatibleSources = [...]string{firestoreds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -69,18 +63,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - // No parameters needed for this tool params := parameters.Parameters{} @@ -90,9 +72,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, Parameters: params, - RulesClient: s.FirebaseRulesClient(), - ProjectId: s.GetProjectId(), - DatabaseId: s.GetDatabaseId(), manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -104,11 +83,7 @@ var _ tools.Tool = Tool{} type Tool struct { Config - Parameters parameters.Parameters `yaml:"parameters"` - - RulesClient *firebaserules.Service - ProjectId string - DatabaseId string + Parameters parameters.Parameters `yaml:"parameters"` manifest tools.Manifest mcpManifest tools.McpManifest } @@ -117,20 +92,25 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + // Get the latest release for Firestore - releaseName 
:= fmt.Sprintf("projects/%s/releases/cloud.firestore/%s", t.ProjectId, t.DatabaseId) - release, err := t.RulesClient.Projects.Releases.Get(releaseName).Context(ctx).Do() + releaseName := fmt.Sprintf("projects/%s/releases/cloud.firestore/%s", source.GetProjectId(), source.GetDatabaseId()) + release, err := source.FirebaseRulesClient().Projects.Releases.Get(releaseName).Context(ctx).Do() if err != nil { return nil, fmt.Errorf("failed to get latest Firestore release: %w", err) } if release.RulesetName == "" { - return nil, fmt.Errorf("no active Firestore rules were found in project '%s' and database '%s'", t.ProjectId, t.DatabaseId) + return nil, fmt.Errorf("no active Firestore rules were found in project '%s' and database '%s'", source.GetProjectId(), source.GetDatabaseId()) } // Get the ruleset content - ruleset, err := t.RulesClient.Projects.Rulesets.Get(release.RulesetName).Context(ctx).Do() + ruleset, err := source.FirebaseRulesClient().Projects.Rulesets.Get(release.RulesetName).Context(ctx).Do() if err != nil { return nil, fmt.Errorf("failed to get ruleset content: %w", err) } @@ -158,10 +138,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/firestore/firestorelistcollections/firestorelistcollections.go b/internal/tools/firestore/firestorelistcollections/firestorelistcollections.go index 92fe1c4bfd..af3df39dfa 100644 --- a/internal/tools/firestore/firestorelistcollections/firestorelistcollections.go +++ b/internal/tools/firestore/firestorelistcollections/firestorelistcollections.go @@ -21,7 +21,6 @@ import ( firestoreapi "cloud.google.com/go/firestore" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/firestore/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -48,11 +47,6 @@ type compatibleSource interface { FirestoreClient() *firestoreapi.Client } -// validate compatible sources are still compatible -var _ compatibleSource = &firestoreds.Source{} - -var compatibleSources = [...]string{firestoreds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -69,18 +63,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - emptyString := "" parentPathParameter := parameters.NewStringParameterWithDefault(parentPathKey, emptyString, "Relative parent document path to list subcollections from (e.g., 'users/userId'). If not provided, lists root collections. 
Note: This is a relative path, NOT an absolute path like 'projects/{project_id}/databases/{database_id}/documents/...'") params := parameters.Parameters{parentPathParameter} @@ -91,7 +73,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, Parameters: params, - Client: s.FirestoreClient(), manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -103,9 +84,7 @@ var _ tools.Tool = Tool{} type Tool struct { Config - Parameters parameters.Parameters `yaml:"parameters"` - - Client *firestoreapi.Client + Parameters parameters.Parameters `yaml:"parameters"` manifest tools.Manifest mcpManifest tools.McpManifest } @@ -114,11 +93,15 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + mapParams := params.AsMap() var collectionRefs []*firestoreapi.CollectionRef - var err error // Check if parentPath is provided parentPath, hasParent := mapParams[parentPathKey].(string) @@ -130,14 +113,14 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } // List subcollections of the specified document - docRef := t.Client.Doc(parentPath) + docRef := source.FirestoreClient().Doc(parentPath) collectionRefs, err = docRef.Collections(ctx).GetAll() if err != nil { return nil, fmt.Errorf("failed to list subcollections of document %q: %w", parentPath, err) } } else { // List root collections - collectionRefs, err = t.Client.Collections(ctx).GetAll() + collectionRefs, err = source.FirestoreClient().Collections(ctx).GetAll() if err != nil { return nil, fmt.Errorf("failed to list root collections: %w", err) } @@ -177,10 +160,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/firestore/firestorequery/firestorequery.go b/internal/tools/firestore/firestorequery/firestorequery.go index a7319cbf6b..9434e57171 100644 --- a/internal/tools/firestore/firestorequery/firestorequery.go +++ b/internal/tools/firestore/firestorequery/firestorequery.go @@ -24,7 +24,6 @@ import ( firestoreapi "cloud.google.com/go/firestore" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/firestore/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -52,12 +51,9 @@ var validOperators = map[string]bool{ // Error messages const ( - errFilterParseFailed = "failed to parse filters: %w" - errQueryExecutionFailed = "failed to execute query: %w" 
- errTemplateParseFailed = "failed to parse template: %w" - errTemplateExecFailed = "failed to execute template: %w" - errLimitParseFailed = "failed to parse limit value '%s': %w" - errSelectFieldParseFailed = "failed to parse select field: %w" + errFilterParseFailed = "failed to parse filters: %w" + errQueryExecutionFailed = "failed to execute query: %w" + errLimitParseFailed = "failed to parse limit value '%s': %w" ) func init() { @@ -79,11 +75,6 @@ type compatibleSource interface { FirestoreClient() *firestoreapi.Client } -// validate compatible sources are still compatible -var _ compatibleSource = &firestoreds.Source{} - -var compatibleSources = [...]string{firestoreds.SourceKind} - // Config represents the configuration for the Firestore query tool type Config struct { Name string `yaml:"name" validate:"required"` @@ -114,18 +105,6 @@ func (cfg Config) ToolConfigKind() string { // Initialize creates a new Tool instance from the configuration func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - // Set default limit if not specified if cfg.Limit == "" { cfg.Limit = fmt.Sprintf("%d", defaultLimit) @@ -137,7 +116,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ Config: cfg, - Client: s.FirestoreClient(), manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -200,7 +178,12 @@ type QueryResponse struct { } // Invoke executes the Firestore query based on the provided parameters -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() // Process collection path with template substitution @@ -210,7 +193,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } // Build the query - query, err := t.buildQuery(collectionPath, paramsMap) + query, err := t.buildQuery(source, collectionPath, paramsMap) if err != nil { return nil, err } @@ -220,8 +203,8 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } // buildQuery constructs the Firestore query from parameters -func (t Tool) buildQuery(collectionPath string, params map[string]any) (*firestoreapi.Query, error) { - collection := t.Client.Collection(collectionPath) +func (t Tool) buildQuery(source compatibleSource, collectionPath string, params map[string]any) (*firestoreapi.Query, error) { + collection := source.FirestoreClient().Collection(collectionPath) query := collection.Query // Process and apply filters if template is provided @@ -239,7 +222,7 @@ func (t Tool) buildQuery(collectionPath string, params map[string]any) (*firesto } // Convert simplified filter to Firestore filter - if filter := t.convertToFirestoreFilter(simplifiedFilter); filter != nil { + if 
filter := t.convertToFirestoreFilter(source, simplifiedFilter); filter != nil { query = query.WhereEntity(filter) } } @@ -280,12 +263,12 @@ func (t Tool) buildQuery(collectionPath string, params map[string]any) (*firesto } // convertToFirestoreFilter converts simplified filter format to Firestore EntityFilter -func (t Tool) convertToFirestoreFilter(filter SimplifiedFilter) firestoreapi.EntityFilter { +func (t Tool) convertToFirestoreFilter(source compatibleSource, filter SimplifiedFilter) firestoreapi.EntityFilter { // Handle AND filters if len(filter.And) > 0 { filters := make([]firestoreapi.EntityFilter, 0, len(filter.And)) for _, f := range filter.And { - if converted := t.convertToFirestoreFilter(f); converted != nil { + if converted := t.convertToFirestoreFilter(source, f); converted != nil { filters = append(filters, converted) } } @@ -299,7 +282,7 @@ func (t Tool) convertToFirestoreFilter(filter SimplifiedFilter) firestoreapi.Ent if len(filter.Or) > 0 { filters := make([]firestoreapi.EntityFilter, 0, len(filter.Or)) for _, f := range filter.Or { - if converted := t.convertToFirestoreFilter(f); converted != nil { + if converted := t.convertToFirestoreFilter(source, f); converted != nil { filters = append(filters, converted) } } @@ -313,7 +296,7 @@ func (t Tool) convertToFirestoreFilter(filter SimplifiedFilter) firestoreapi.Ent if filter.Field != "" && filter.Op != "" && filter.Value != nil { if validOperators[filter.Op] { // Convert the value using the Firestore native JSON converter - convertedValue, err := util.JSONToFirestoreValue(filter.Value, t.Client) + convertedValue, err := util.JSONToFirestoreValue(filter.Value, source.FirestoreClient()) if err != nil { // If conversion fails, use the original value convertedValue = filter.Value @@ -525,10 +508,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/firestore/firestorequerycollection/firestorequerycollection.go b/internal/tools/firestore/firestorequerycollection/firestorequerycollection.go index a16e3c16eb..9601ecc099 100644 --- a/internal/tools/firestore/firestorequerycollection/firestorequerycollection.go +++ b/internal/tools/firestore/firestorequerycollection/firestorequerycollection.go @@ -23,7 +23,6 @@ import ( firestoreapi "cloud.google.com/go/firestore" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/firestore/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -92,11 +91,6 @@ type compatibleSource interface { FirestoreClient() *firestoreapi.Client } -// validate compatible sources are still compatible -var _ compatibleSource = &firestoreds.Source{} - -var compatibleSources = [...]string{firestoreds.SourceKind} - // Config represents the configuration for the Firestore query collection tool type Config struct { Name string `yaml:"name" validate:"required"` @@ -116,18 +110,6 @@ func (cfg Config) 
ToolConfigKind() string { // Initialize creates a new Tool instance from the configuration func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - // Create parameters params := createParameters() @@ -137,7 +119,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, Parameters: params, - Client: s.FirestoreClient(), manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -199,9 +180,7 @@ var _ tools.Tool = Tool{} // Tool represents the Firestore query collection tool type Tool struct { Config - Parameters parameters.Parameters `yaml:"parameters"` - - Client *firestoreapi.Client + Parameters parameters.Parameters `yaml:"parameters"` manifest tools.Manifest mcpManifest tools.McpManifest } @@ -265,7 +244,12 @@ type QueryResponse struct { } // Invoke executes the Firestore query based on the provided parameters -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + // Parse parameters queryParams, err := t.parseQueryParameters(params) if err != nil { @@ -273,7 +257,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } // Build the query - query, err := t.buildQuery(queryParams) + query, err := t.buildQuery(source, queryParams) if err != nil { return nil, err } @@ -396,8 +380,8 @@ func (t Tool) parseOrderBy(orderByRaw interface{}) (*OrderByConfig, error) { } // buildQuery constructs the Firestore query from parameters -func (t Tool) buildQuery(params *queryParameters) (*firestoreapi.Query, error) { - collection := t.Client.Collection(params.CollectionPath) +func (t Tool) buildQuery(source compatibleSource, params *queryParameters) (*firestoreapi.Query, error) { + collection := source.FirestoreClient().Collection(params.CollectionPath) query := collection.Query // Apply filters @@ -531,10 +515,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/firestore/firestoreupdatedocument/firestoreupdatedocument.go b/internal/tools/firestore/firestoreupdatedocument/firestoreupdatedocument.go index d26a7a7895..d08fdb9458 100644 --- a/internal/tools/firestore/firestoreupdatedocument/firestoreupdatedocument.go +++ b/internal/tools/firestore/firestoreupdatedocument/firestoreupdatedocument.go @@ -22,7 +22,6 @@ import ( firestoreapi 
"cloud.google.com/go/firestore" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/firestore/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -52,11 +51,6 @@ type compatibleSource interface { FirestoreClient() *firestoreapi.Client } -// validate compatible sources are still compatible -var _ compatibleSource = &firestoreds.Source{} - -var compatibleSources = [...]string{firestoreds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -73,18 +67,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - // Create parameters documentPathParameter := parameters.NewStringParameter( documentPathKey, @@ -134,7 +116,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, Parameters: params, - Client: s.FirestoreClient(), manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -146,9 +127,7 @@ var _ tools.Tool = Tool{} type Tool struct { Config - Parameters parameters.Parameters `yaml:"parameters"` - - Client *firestoreapi.Client + Parameters parameters.Parameters `yaml:"parameters"` manifest tools.Manifest mcpManifest tools.McpManifest } @@ -157,7 +136,12 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + mapParams := params.AsMap() // Get document path @@ -200,7 +184,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } // Get the document reference - docRef := t.Client.Doc(documentPath) + docRef := source.FirestoreClient().Doc(documentPath) // Prepare update data var writeResult *firestoreapi.WriteResult @@ -211,7 +195,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT updates := make([]firestoreapi.Update, 0, len(updatePaths)) // Convert document data without delete markers - dataMap, err := util.JSONToFirestoreValue(documentDataRaw, t.Client) + dataMap, err := util.JSONToFirestoreValue(documentDataRaw, source.FirestoreClient()) if err != nil { return nil, fmt.Errorf("failed to convert document data: %w", err) } @@ -239,7 +223,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT writeResult, writeErr = docRef.Update(ctx, updates) } else { // Update all fields in the document data (merge) - documentData, err := util.JSONToFirestoreValue(documentDataRaw, t.Client) + documentData, err := 
util.JSONToFirestoreValue(documentDataRaw, source.FirestoreClient()) if err != nil { return nil, fmt.Errorf("failed to convert document data: %w", err) } @@ -314,10 +298,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/firestore/firestoreupdatedocument/firestoreupdatedocument_test.go b/internal/tools/firestore/firestoreupdatedocument/firestoreupdatedocument_test.go index de2e3be40f..3311aeb86e 100644 --- a/internal/tools/firestore/firestoreupdatedocument/firestoreupdatedocument_test.go +++ b/internal/tools/firestore/firestoreupdatedocument/firestoreupdatedocument_test.go @@ -132,32 +132,6 @@ func TestConfig_Initialize(t *testing.T) { }, wantErr: false, }, - { - name: "source not found", - config: Config{ - Name: "test-update-document", - Kind: "firestore-update-document", - Source: "missing-source", - Description: "Update a document", - }, - sources: map[string]sources.Source{}, - wantErr: true, - errMsg: "no source named \"missing-source\" configured", - }, - { - name: "incompatible source", - config: Config{ - Name: "test-update-document", - Kind: "firestore-update-document", - Source: "wrong-source", - Description: "Update a document", - }, - sources: map[string]sources.Source{ - "wrong-source": &mockIncompatibleSource{}, - }, - wantErr: true, - errMsg: "invalid source for \"firestore-update-document\" tool", - }, } for _, tt := range tests { @@ -464,14 +438,3 @@ func TestGetFieldValue(t *testing.T) { }) } } - -// mockIncompatibleSource is a mock source that doesn't implement compatibleSource -type mockIncompatibleSource struct{} - -func (m *mockIncompatibleSource) SourceKind() string { - return "mock" -} - -func (m *mockIncompatibleSource) ToConfig() sources.SourceConfig { - return nil -} diff --git a/internal/tools/firestore/firestorevalidaterules/firestorevalidaterules.go b/internal/tools/firestore/firestorevalidaterules/firestorevalidaterules.go index 21bf1df9e4..69cbee4aa4 100644 --- a/internal/tools/firestore/firestorevalidaterules/firestorevalidaterules.go +++ b/internal/tools/firestore/firestorevalidaterules/firestorevalidaterules.go @@ -21,7 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "google.golang.org/api/firebaserules/v1" @@ -53,11 +52,6 @@ type compatibleSource interface { GetProjectId() string } -// validate compatible sources are still compatible -var _ compatibleSource = &firestoreds.Source{} - -var compatibleSources = [...]string{firestoreds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -74,18 +68,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", 
cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - // Create parameters params := createParameters() mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil) @@ -94,8 +76,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, Parameters: params, - RulesClient: s.FirebaseRulesClient(), - ProjectId: s.GetProjectId(), manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -117,10 +97,7 @@ var _ tools.Tool = Tool{} type Tool struct { Config - Parameters parameters.Parameters `yaml:"parameters"` - - RulesClient *firebaserules.Service - ProjectId string + Parameters parameters.Parameters `yaml:"parameters"` manifest tools.Manifest mcpManifest tools.McpManifest } @@ -153,12 +130,17 @@ type ValidationResult struct { RawIssues []Issue `json:"rawIssues,omitempty"` } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + mapParams := params.AsMap() // Get source parameter - source, ok := mapParams[sourceKey].(string) - if !ok || source == "" { + sourceParam, ok := mapParams[sourceKey].(string) + if !ok || sourceParam == "" { return nil, fmt.Errorf("invalid or missing '%s' parameter", sourceKey) } @@ -168,7 +150,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT Files: []*firebaserules.File{ { Name: "firestore.rules", - Content: source, + Content: sourceParam, }, }, }, @@ -179,14 +161,14 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } // Call the test API - projectName := fmt.Sprintf("projects/%s", t.ProjectId) - response, err := t.RulesClient.Projects.Test(projectName, testRequest).Context(ctx).Do() + projectName := fmt.Sprintf("projects/%s", source.GetProjectId()) + response, err := source.FirebaseRulesClient().Projects.Test(projectName, testRequest).Context(ctx).Do() if err != nil { return nil, fmt.Errorf("failed to validate rules: %w", err) } // Process the response - result := t.processValidationResponse(response, source) + result := t.processValidationResponse(response, sourceParam) return result, nil } @@ -287,10 +269,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/http/http.go b/internal/tools/http/http.go index e7b463e542..9e838b8b73 100644 --- a/internal/tools/http/http.go +++ b/internal/tools/http/http.go @@ -29,7 +29,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - httpsrc 
"github.com/googleapis/genai-toolbox/internal/sources/http" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -50,6 +49,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + HttpDefaultHeaders() map[string]string + HttpBaseURL() string + HttpQueryParams() map[string]string + Client() *http.Client +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -81,7 +87,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) } // verify the source is compatible - s, ok := rawS.(*httpsrc.Source) + s, ok := rawS.(compatibleSource) if !ok { return nil, fmt.Errorf("invalid source for %q tool: source kind must be `http`", kind) } @@ -89,7 +95,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // Combine Source and Tool headers. // In case of conflict, Tool header overrides Source header combinedHeaders := make(map[string]string) - maps.Copy(combinedHeaders, s.DefaultHeaders) + maps.Copy(combinedHeaders, s.HttpDefaultHeaders()) maps.Copy(combinedHeaders, cfg.Headers) // Create a slice for all parameters @@ -113,14 +119,11 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup return Tool{ - Config: cfg, - BaseURL: s.BaseURL, - Headers: combinedHeaders, - DefaultQueryParams: s.QueryParams, - Client: s.Client, - AllParams: allParameters, - manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + Headers: combinedHeaders, + AllParams: allParameters, + manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, }, nil } @@ -129,12 +132,8 @@ var _ tools.Tool = Tool{} type Tool struct { Config - BaseURL string `yaml:"baseURL"` - Headers map[string]string `yaml:"headers"` - DefaultQueryParams map[string]string `yaml:"defaultQueryParams"` - AllParams parameters.Parameters `yaml:"allParams"` - - Client *http.Client + Headers map[string]string `yaml:"headers"` + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } @@ -228,7 +227,12 @@ func getHeaders(headerParams parameters.Parameters, defaultHeaders map[string]st return allHeaders, nil } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() // Calculate request body @@ -238,7 +242,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } // Calculate URL - urlString, err := getURL(t.BaseURL, t.Path, t.PathParams, t.QueryParams, t.DefaultQueryParams, paramsMap) + urlString, err := getURL(source.HttpBaseURL(), t.Path, t.PathParams, t.QueryParams, source.HttpQueryParams(), paramsMap) if err != nil { return nil, fmt.Errorf("error populating path parameters: %s", err) } @@ -256,7 +260,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } // Make 
request and fetch response - resp, err := t.Client.Do(req) + resp, err := source.Client().Do(req) if err != nil { return nil, fmt.Errorf("error making HTTP request: %s", err) } @@ -295,10 +299,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/looker/lookeradddashboardelement/lookeradddashboardelement.go b/internal/tools/looker/lookeradddashboardelement/lookeradddashboardelement.go index 0ad1961198..8c2417157b 100644 --- a/internal/tools/looker/lookeradddashboardelement/lookeradddashboardelement.go +++ b/internal/tools/looker/lookeradddashboardelement/lookeradddashboardelement.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -45,6 +44,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -62,18 +68,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - params := lookercommon.GetQueryParameters() dashIdParameter := parameters.NewStringParameter("dashboard_id", "The id of the dashboard where this tile will exist") @@ -86,17 +80,31 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) "", ) params = append(params, vizParameter) + dashFilters := parameters.NewArrayParameterWithRequired("dashboard_filters", + `An array of dashboard filters like [{"dashboard_filter_name": "name", "field": "view_name.field_name"}, ...]`, + false, + parameters.NewMapParameterWithDefault("dashboard_filter", + map[string]any{}, + `A dashboard filter like {"dashboard_filter_name": "name", "field": "view_name.field_name"}`, + "", + ), + ) + params = append(params, dashFilters) - mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, cfg.Annotations) + annotations := cfg.Annotations + if annotations == nil { + readOnlyHint := false + annotations = &tools.ToolAnnotations{ + ReadOnlyHint: &readOnlyHint, + } + } + + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, annotations) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: 
s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -111,13 +119,9 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { @@ -129,12 +133,19 @@ var ( visType string = "vis" ) -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) } + logger.DebugContext(ctx, "params = ", params) + wq, err := lookercommon.ProcessQueryArgs(ctx, params) if err != nil { return nil, fmt.Errorf("error building query request: %w", err) @@ -147,23 +158,64 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT visConfig := paramsMap["vis_config"].(map[string]any) wq.VisConfig = &visConfig - qrespFields := "id" - - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } - qresp, err := sdk.CreateQuery(*wq, qrespFields, t.ApiSettings) + qresp, err := sdk.CreateQuery(*wq, "id", source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making create query request: %w", err) } + dashFilters := []any{} + if v, ok := paramsMap["dashboard_filters"]; ok { + if v != nil { + dashFilters = paramsMap["dashboard_filters"].([]any) + } + } + + var filterables []v4.ResultMakerFilterables + for _, m := range dashFilters { + f := m.(map[string]any) + name, ok := f["dashboard_filter_name"].(string) + if !ok { + return nil, fmt.Errorf("error processing dashboard filter: %w", err) + } + field, ok := f["field"].(string) + if !ok { + return nil, fmt.Errorf("error processing dashboard filter: %w", err) + } + listener := v4.ResultMakerFilterablesListen{ + DashboardFilterName: &name, + Field: &field, + } + listeners := []v4.ResultMakerFilterablesListen{listener} + + filter := v4.ResultMakerFilterables{ + Listen: &listeners, + } + + filterables = append(filterables, filter) + } + + if len(filterables) == 0 { + filterables = nil + } + + wrm := v4.WriteResultMakerWithIdVisConfigAndDynamicFields{ + Query: wq, + VisConfig: &visConfig, + Filterables: &filterables, + } wde := v4.WriteDashboardElement{ DashboardId: &dashboard_id, Title: &title, + ResultMaker: &wrm, + Query: wq, QueryId: qresp.Id, } + switch len(visConfig) { case 0: wde.Type = &dataType @@ -178,7 +230,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT Fields: &fields, } - resp, err := sdk.CreateDashboardElement(req, 
t.ApiSettings) + resp, err := sdk.CreateDashboardElement(req, source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making create dashboard element request: %w", err) } @@ -203,14 +255,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookeradddashboardfilter/lookeradddashboardfilter.go b/internal/tools/looker/lookeradddashboardfilter/lookeradddashboardfilter.go new file mode 100644 index 0000000000..bc01526aaa --- /dev/null +++ b/internal/tools/looker/lookeradddashboardfilter/lookeradddashboardfilter.go @@ -0,0 +1,242 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+package lookeradddashboardfilter + +import ( + "context" + "fmt" + + yaml "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" + "github.com/googleapis/genai-toolbox/internal/util" + "github.com/googleapis/genai-toolbox/internal/util/parameters" + + "github.com/looker-open-source/sdk-codegen/go/rtl" + v4 "github.com/looker-open-source/sdk-codegen/go/sdk/v4" +) + +const kind string = "looker-add-dashboard-filter" + +func init() { + if !tools.Register(kind, newConfig) { + panic(fmt.Sprintf("tool kind %q already registered", kind)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { + actual := Config{Name: name} + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + return actual, nil +} + +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description" validate:"required"` + AuthRequired []string `yaml:"authRequired"` + Annotations *tools.ToolAnnotations `yaml:"annotations,omitempty"` +} + +// validate interface +var _ tools.ToolConfig = Config{} + +func (cfg Config) ToolConfigKind() string { + return kind +} + +func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { + params := parameters.Parameters{} + + dashIdParameter := parameters.NewStringParameter("dashboard_id", "The id of the dashboard where this filter will exist") + params = append(params, dashIdParameter) + nameParameter := parameters.NewStringParameter("name", "The name of the Dashboard Filter") + params = append(params, nameParameter) + titleParameter := parameters.NewStringParameter("title", "The title of the Dashboard Filter") + params = append(params, titleParameter) + filterTypeParameter := parameters.NewStringParameterWithDefault("filter_type", "field_filter", "The filter_type of the Dashboard Filter: date_filter, number_filter, string_filter, field_filter (default field_filter)") + params = append(params, filterTypeParameter) + defaultParameter := parameters.NewStringParameterWithRequired("default_value", "The default_value of the Dashboard Filter (optional)", false) + params = append(params, defaultParameter) + modelParameter := parameters.NewStringParameterWithRequired("model", "The model of a field type Dashboard Filter (required if type field)", false) + params = append(params, modelParameter) + exploreParameter := parameters.NewStringParameterWithRequired("explore", "The explore of a field type Dashboard Filter (required if type field)", false) + params = append(params, exploreParameter) + dimensionParameter := parameters.NewStringParameterWithRequired("dimension", "The dimension of a field type Dashboard Filter (required if type field)", false) + params = append(params, dimensionParameter) + multiValueParameter := parameters.NewBooleanParameterWithDefault("allow_multiple_values", true, "The Dashboard Filter should allow multiple values (default true)") + params = append(params, multiValueParameter) + requiredParameter := parameters.NewBooleanParameterWithDefault("required", false, "The Dashboard Filter is required to run 
dashboard (default false)") + params = append(params, requiredParameter) + + annotations := cfg.Annotations + if annotations == nil { + readOnlyHint := false + annotations = &tools.ToolAnnotations{ + ReadOnlyHint: &readOnlyHint, + } + } + + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, annotations) + + // finish tool setup + return Tool{ + Config: cfg, + Parameters: params, + manifest: tools.Manifest{ + Description: cfg.Description, + Parameters: params.Manifest(), + AuthRequired: cfg.AuthRequired, + }, + mcpManifest: mcpManifest, + }, nil +} + +// validate interface +var _ tools.Tool = Tool{} + +type Tool struct { + Config + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest +} + +func (t Tool) ToConfig() tools.ToolConfig { + return t.Config +} + +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + + logger, err := util.LoggerFromContext(ctx) + if err != nil { + return nil, fmt.Errorf("unable to get logger from ctx: %s", err) + } + logger.DebugContext(ctx, "params = ", params) + + paramsMap := params.AsMap() + dashboard_id := paramsMap["dashboard_id"].(string) + name := paramsMap["name"].(string) + title := paramsMap["title"].(string) + filterType := paramsMap["filter_type"].(string) + switch filterType { + case "date_filter": + case "number_filter": + case "string_filter": + case "field_filter": + default: + return nil, fmt.Errorf("invalid filter type: %s. Must be one of date_filter, number_filter, string_filter, field_filter", filterType) + } + allowMultipleValues := paramsMap["allow_multiple_values"].(bool) + required := paramsMap["required"].(bool) + + req := v4.WriteCreateDashboardFilter{ + DashboardId: dashboard_id, + Name: name, + Title: title, + Type: filterType, + AllowMultipleValues: &allowMultipleValues, + Required: &required, + } + + if v, ok := paramsMap["default_value"]; ok { + if v != nil { + defaultValue := paramsMap["default_value"].(string) + req.DefaultValue = &defaultValue + } + } + + if filterType == "field_filter" { + model, ok := paramsMap["model"].(string) + if !ok || model == "" { + return nil, fmt.Errorf("model must be specified for field_filter type") + } + explore, ok := paramsMap["explore"].(string) + if !ok || explore == "" { + return nil, fmt.Errorf("explore must be specified for field_filter type") + } + dimension, ok := paramsMap["dimension"].(string) + if !ok || dimension == "" { + return nil, fmt.Errorf("dimension must be specified for field_filter type") + } + + req.Model = &model + req.Explore = &explore + req.Dimension = &dimension + } + + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) + if err != nil { + return nil, fmt.Errorf("error getting sdk: %w", err) + } + + resp, err := sdk.CreateDashboardFilter(req, "name", source.LookerApiSettings()) + if err != nil { + return nil, fmt.Errorf("error making create dashboard filter request: %s", err) + } + logger.DebugContext(ctx, "resp = %v", resp) + + data := make(map[string]any) + + data["result"] = fmt.Sprintf("Dashboard filter \"%s\" added to dashboard %s", *resp.Name, dashboard_id) + + return data, nil +} + +func (t Tool) ParseParams(data map[string]any, claims 
map[string]map[string]any) (parameters.ParamValues, error) { + return parameters.ParseParams(t.Parameters, data, claims) +} + +func (t Tool) Manifest() tools.Manifest { + return t.manifest +} + +func (t Tool) McpManifest() tools.McpManifest { + return t.mcpManifest +} + +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +} + +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil +} diff --git a/internal/tools/looker/lookeradddashboardfilter/lookeradddashboardfilter_test.go b/internal/tools/looker/lookeradddashboardfilter/lookeradddashboardfilter_test.go new file mode 100644 index 0000000000..43f43dc6c6 --- /dev/null +++ b/internal/tools/looker/lookeradddashboardfilter/lookeradddashboardfilter_test.go @@ -0,0 +1,116 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package lookeradddashboardfilter_test + +import ( + "strings" + "testing" + + yaml "github.com/goccy/go-yaml" + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" + lkr "github.com/googleapis/genai-toolbox/internal/tools/looker/lookeradddashboardfilter" +) + +func TestParseFromYamlLookerAddDashboardFilter(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example", + in: ` + tools: + example_tool: + kind: looker-add-dashboard-filter + source: my-instance + description: some description + `, + want: server.ToolConfigs{ + "example_tool": lkr.Config{ + Name: "example_tool", + Kind: "looker-add-dashboard-filter", + Source: "my-instance", + Description: "some description", + AuthRequired: []string{}, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Tools server.ToolConfigs `yaml:"tools"` + }{} + // Parse contents + err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if diff := cmp.Diff(tc.want, got.Tools); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } + +} + +func TestFailParseFromYamlLookerAddDashboardFilter(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + err string + }{ + { + desc: "Invalid method", + in: ` + tools: + example_tool: + kind: looker-add-dashboard-filter + source: my-instance + method: GOT + description: some description + `, + err: "unable to parse tool \"example_tool\" as kind \"looker-add-dashboard-filter\": [4:1] unknown field \"method\"\n 1 | authRequired: []\n 2 | description: some description\n 3 | kind: looker-add-dashboard-filter\n> 4 | method: GOT\n ^\n 5 | source: my-instance", + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Tools server.ToolConfigs `yaml:"tools"` + }{} + // Parse contents + err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got) + if err == nil { + t.Fatalf("expect parsing to fail") + } + errStr := err.Error() + if !strings.Contains(errStr, tc.err) { + t.Fatalf("unexpected error string: got %q, want substring %q", errStr, tc.err) + } + }) + } + +} diff --git a/internal/tools/looker/lookerconversationalanalytics/lookerconversationalanalytics.go b/internal/tools/looker/lookerconversationalanalytics/lookerconversationalanalytics.go index 35b90aed8b..ba09f4b6a6 100644 --- a/internal/tools/looker/lookerconversationalanalytics/lookerconversationalanalytics.go +++ b/internal/tools/looker/lookerconversationalanalytics/lookerconversationalanalytics.go @@ -26,7 +26,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookerds "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -56,12 +55,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T } type compatibleSource interface { - GetApiSettings() *rtl.ApiSettings GoogleCloudTokenSourceWithScope(ctx context.Context, scope string) (oauth2.TokenSource, 
error) GoogleCloudProject() string GoogleCloudLocation() string UseClientAuthorization() bool GetAuthTokenHeaderName() string + LookerApiSettings() *rtl.ApiSettings } // Structs for building the JSON payload @@ -124,11 +123,6 @@ type CAPayload struct { ClientIdEnum string `json:"clientIdEnum"` } -// validate compatible sources are still compatible -var _ compatibleSource = &lookerds.Source{} - -var compatibleSources = [...]string{lookerds.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -155,7 +149,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // verify the source is compatible s, ok := rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } if s.GoogleCloudProject() == "" { @@ -177,7 +171,15 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) params := parameters.Parameters{userQueryParameter, exploreRefsParameter} - mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, cfg.Annotations) + annotations := cfg.Annotations + if annotations == nil { + readOnlyHint := true + annotations = &tools.ToolAnnotations{ + ReadOnlyHint: &readOnlyHint, + } + } + + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, annotations) // Get cloud-platform token source for Gemini Data Analytics API during initialization ctx := context.Background() @@ -188,16 +190,11 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ - Config: cfg, - ApiSettings: s.GetApiSettings(), - Project: s.GoogleCloudProject(), - Location: s.GoogleCloudLocation(), - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - TokenSource: ts, - manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, - mcpManifest: mcpManifest, + Config: cfg, + Parameters: params, + TokenSource: ts, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, } return t, nil } @@ -207,24 +204,23 @@ var _ tools.Tool = Tool{} type Tool struct { Config - ApiSettings *rtl.ApiSettings - UseClientOAuth bool `yaml:"useClientOAuth"` - AuthTokenHeaderName string - Parameters parameters.Parameters `yaml:"parameters"` - Project string - Location string - TokenSource oauth2.TokenSource - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + TokenSource oauth2.TokenSource + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + var tokenStr string - var err error // Get credentials for the API call // Use cloud-platform token source for Gemini Data Analytics API @@ 
-245,16 +241,16 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT ler := make([]LookerExploreReference, 0) for _, er := range exploreReferences { ler = append(ler, LookerExploreReference{ - LookerInstanceUri: t.ApiSettings.BaseUrl, + LookerInstanceUri: source.LookerApiSettings().BaseUrl, LookmlModel: er.(map[string]any)["model"].(string), Explore: er.(map[string]any)["explore"].(string), }) } oauth_creds := OAuthCredentials{} - if t.UseClientOAuth { + if source.UseClientAuthorization() { oauth_creds.Token = TokenBased{AccessToken: string(accessToken)} } else { - oauth_creds.Secret = SecretBased{ClientId: t.ApiSettings.ClientId, ClientSecret: t.ApiSettings.ClientSecret} + oauth_creds.Secret = SecretBased{ClientId: source.LookerApiSettings().ClientId, ClientSecret: source.LookerApiSettings().ClientSecret} } lers := LookerExploreReferences{ @@ -265,8 +261,8 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } // Construct URL, headers, and payload - projectID := t.Project - location := t.Location + projectID := source.GoogleCloudProject() + location := source.GoogleCloudLocation() caURL := fmt.Sprintf("https://geminidataanalytics.googleapis.com/v1beta/projects/%s/locations/%s:chat", url.PathEscape(projectID), url.PathEscape(location)) headers := map[string]string{ @@ -307,12 +303,16 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } -func (t Tool) Authorized(verifiedAuthServices []string) bool { - return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } // StreamMessage represents a single message object from the streaming API response. 
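Illustrative note (not part of the patch): the hunk above maps each entry of the explore-references list parameter (its exact parameter name is defined earlier in this file and not shown here) into a LookerExploreReference, filling LookerInstanceUri from the source's API settings. The sketch below uses hypothetical model/explore names and a hypothetical base URL to show the shape Invoke expects and the entries it produces; the struct is redeclared locally (without the file's JSON tags) just to keep the sketch self-contained.

package main

import "fmt"

// Mirrors (roughly) the payload struct used by lookerconversationalanalytics.go.
type LookerExploreReference struct {
	LookerInstanceUri string
	LookmlModel       string
	Explore           string
}

func main() {
	// Hypothetical parameter value: a list of maps with "model" and "explore" keys.
	exploreRefs := []any{
		map[string]any{"model": "thelook", "explore": "order_items"},
	}
	refs := make([]LookerExploreReference, 0, len(exploreRefs))
	for _, er := range exploreRefs {
		m := er.(map[string]any)
		refs = append(refs, LookerExploreReference{
			LookerInstanceUri: "https://example.looker.com", // stands in for source.LookerApiSettings().BaseUrl
			LookmlModel:       m["model"].(string),
			Explore:           m["explore"].(string),
		})
	}
	fmt.Printf("%+v\n", refs)
}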
@@ -555,6 +555,10 @@ func appendMessage(messages []map[string]any, newMessage map[string]any) []map[s return append(messages, newMessage) } -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookercreateprojectfile/lookercreateprojectfile.go b/internal/tools/looker/lookercreateprojectfile/lookercreateprojectfile.go index 70c9f05ab6..ddf53b94f4 100644 --- a/internal/tools/looker/lookercreateprojectfile/lookercreateprojectfile.go +++ b/internal/tools/looker/lookercreateprojectfile/lookercreateprojectfile.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -44,6 +43,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -61,33 +67,25 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - projectIdParameter := parameters.NewStringParameter("project_id", "The id of the project containing the files") filePathParameter := parameters.NewStringParameter("file_path", "The path of the file within the project") fileContentParameter := parameters.NewStringParameter("file_content", "The content of the file") params := parameters.Parameters{projectIdParameter, filePathParameter, fileContentParameter} - mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, cfg.Annotations) + annotations := cfg.Annotations + if annotations == nil { + readOnlyHint := false + annotations = &tools.ToolAnnotations{ + ReadOnlyHint: &readOnlyHint, + } + } + + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, annotations) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -102,21 +100,22 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters 
parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } @@ -140,7 +139,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT Content: fileContent, } - err = lookercommon.CreateProjectFile(sdk, projectId, req, t.ApiSettings) + err = lookercommon.CreateProjectFile(sdk, projectId, req, source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making create_project_file request: %s", err) } @@ -164,14 +163,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookerdeleteprojectfile/lookerdeleteprojectfile.go b/internal/tools/looker/lookerdeleteprojectfile/lookerdeleteprojectfile.go index 76b7a0f6e5..5c20c95635 100644 --- a/internal/tools/looker/lookerdeleteprojectfile/lookerdeleteprojectfile.go +++ b/internal/tools/looker/lookerdeleteprojectfile/lookerdeleteprojectfile.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -44,6 +43,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -61,32 +67,26 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := 
srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - projectIdParameter := parameters.NewStringParameter("project_id", "The id of the project containing the files") filePathParameter := parameters.NewStringParameter("file_path", "The path of the file within the project") params := parameters.Parameters{projectIdParameter, filePathParameter} - mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, cfg.Annotations) + annotations := cfg.Annotations + if annotations == nil { + readOnlyHint := false + destructiveHint := true + annotations = &tools.ToolAnnotations{ + ReadOnlyHint: &readOnlyHint, + DestructiveHint: &destructiveHint, + } + } + + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, annotations) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -101,21 +101,22 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } @@ -130,7 +131,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("'file_path' must be a string, got %T", mapParams["file_path"]) } - err = lookercommon.DeleteProjectFile(sdk, projectId, filePath, t.ApiSettings) + err = lookercommon.DeleteProjectFile(sdk, projectId, filePath, source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making delete_project_file request: %s", err) } @@ -154,14 +155,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() 
bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookerdevmode/lookerdevmode.go b/internal/tools/looker/lookerdevmode/lookerdevmode.go index dcf7735413..d33ed9c457 100644 --- a/internal/tools/looker/lookerdevmode/lookerdevmode.go +++ b/internal/tools/looker/lookerdevmode/lookerdevmode.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -45,6 +44,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -62,38 +68,29 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - devModeParameter := parameters.NewBooleanParameterWithDefault("devMode", true, "Whether to set Dev Mode.") params := parameters.Parameters{devModeParameter} - mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, cfg.Annotations) + annotations := cfg.Annotations + if annotations == nil { + readOnlyHint := true + annotations = &tools.ToolAnnotations{ + ReadOnlyHint: &readOnlyHint, + } + } + + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, annotations) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired, }, - mcpManifest: mcpManifest, - ShowHiddenExplores: s.ShowHiddenExplores, + mcpManifest: mcpManifest, }, nil } @@ -102,21 +99,21 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest - ShowHiddenExplores bool + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { 
+func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) @@ -127,7 +124,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("'devMode' must be a boolean, got %T", mapParams["devMode"]) } - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } @@ -140,7 +137,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT req := v4.WriteApiSession{ WorkspaceId: &devModeString, } - resp, err := sdk.UpdateSession(req, t.ApiSettings) + resp, err := sdk.UpdateSession(req, source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error setting/resetting dev mode: %w", err) } @@ -161,14 +158,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookergenerateembedurl/lookergenerateembedurl.go b/internal/tools/looker/lookergenerateembedurl/lookergenerateembedurl.go index 20d92af198..8dbc4a1557 100644 --- a/internal/tools/looker/lookergenerateembedurl/lookergenerateembedurl.go +++ b/internal/tools/looker/lookergenerateembedurl/lookergenerateembedurl.go @@ -20,7 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -46,6 +45,14 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings + LookerSessionLength() int64 +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -63,18 +70,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - 
rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - typeParameter := parameters.NewStringParameterWithDefault("type", "", "Type of Looker content to embed (ie. dashboards, looks, query-visualization)") idParameter := parameters.NewStringParameterWithDefault("id", "", "The ID of the content to embed.") params := parameters.Parameters{ @@ -82,23 +77,26 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) idParameter, } - mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, cfg.Annotations) + annotations := cfg.Annotations + if annotations == nil { + readOnlyHint := true + annotations = &tools.ToolAnnotations{ + ReadOnlyHint: &readOnlyHint, + } + } + + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, annotations) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired, }, - mcpManifest: mcpManifest, - SessionLength: s.SessionLength, + mcpManifest: mcpManifest, }, nil } @@ -107,22 +105,21 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - AuthRequired []string `yaml:"authRequired"` - Parameters parameters.Parameters - manifest tools.Manifest - mcpManifest tools.McpManifest - SessionLength int64 + Parameters parameters.Parameters + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) @@ -139,16 +136,16 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT contentId_ptr = nil } - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } forceLogoutLogin := true - + sessionLength := source.LookerSessionLength() req := v4.EmbedParams{ - TargetUrl: fmt.Sprintf("%s/embed/%s/%s", t.ApiSettings.BaseUrl, *embedType_ptr, *contentId_ptr), - SessionLength: &t.SessionLength, + TargetUrl: fmt.Sprintf("%s/embed/%s/%s", source.LookerApiSettings().BaseUrl, *embedType_ptr, *contentId_ptr), + SessionLength: &sessionLength, ForceLogoutLogin: &forceLogoutLogin, } logger.ErrorContext(ctx, "Making request %v", req) @@ -173,14 +170,22 @@ func (t Tool) McpManifest() tools.McpManifest { 
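For orientation (illustrative only, not part of the patch): the TargetUrl assembled above is simply the instance base URL plus /embed/<type>/<id>, with the session length now read from the source at invoke time. A minimal sketch with a hypothetical base URL and content id:

package main

import "fmt"

func main() {
	// Hypothetical values; the base URL normally comes from source.LookerApiSettings().BaseUrl.
	baseURL := "https://example.looker.com"
	embedType := "dashboards" // one of the documented types: dashboards, looks, query-visualization
	contentID := "42"         // hypothetical content id
	targetURL := fmt.Sprintf("%s/embed/%s/%s", baseURL, embedType, contentID)
	fmt.Println(targetURL) // https://example.looker.com/embed/dashboards/42
}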
return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookergetconnectiondatabases/lookergetconnectiondatabases.go b/internal/tools/looker/lookergetconnectiondatabases/lookergetconnectiondatabases.go index e73fcc6ec4..c637b92260 100644 --- a/internal/tools/looker/lookergetconnectiondatabases/lookergetconnectiondatabases.go +++ b/internal/tools/looker/lookergetconnectiondatabases/lookergetconnectiondatabases.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -44,6 +43,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -61,31 +67,23 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - connParameter := parameters.NewStringParameter("conn", "The connection containing the databases.") params := parameters.Parameters{connParameter} - mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, cfg.Annotations) + annotations := cfg.Annotations + if annotations == nil { + readOnlyHint := true + annotations = &tools.ToolAnnotations{ + ReadOnlyHint: &readOnlyHint, + } + } + + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, annotations) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -100,31 +98,32 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - 
ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + mapParams := params.AsMap() conn, ok := mapParams["conn"].(string) if !ok { return nil, fmt.Errorf("'conn' must be a string, got %T", mapParams["conn"]) } - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } - resp, err := sdk.ConnectionDatabases(conn, t.ApiSettings) + resp, err := sdk.ConnectionDatabases(conn, source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making get_connection_databases request: %s", err) } @@ -145,14 +144,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookergetconnections/lookergetconnections.go b/internal/tools/looker/lookergetconnections/lookergetconnections.go index b4e307418b..75b4622a56 100644 --- a/internal/tools/looker/lookergetconnections/lookergetconnections.go +++ b/internal/tools/looker/lookergetconnections/lookergetconnections.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -45,6 +44,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -62,30 +68,22 @@ func (cfg Config) ToolConfigKind() string { } func 
(cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - params := parameters.Parameters{} - mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, cfg.Annotations) + annotations := cfg.Annotations + if annotations == nil { + readOnlyHint := true + annotations = &tools.ToolAnnotations{ + ReadOnlyHint: &readOnlyHint, + } + } + + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, annotations) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -100,30 +98,31 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) } - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } - resp, err := sdk.AllConnections("name, dialect(name), database, schema", t.ApiSettings) + resp, err := sdk.AllConnections("name, dialect(name), database, schema", source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making get_connections request: %s", err) } @@ -139,7 +138,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT if v.Schema != nil { vMap["schema"] = *v.Schema } - conn, err := sdk.ConnectionFeatures(*v.Name, "multiple_databases", t.ApiSettings) + conn, err := sdk.ConnectionFeatures(*v.Name, "multiple_databases", source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making get_connection_features request: %s", err) } @@ -164,14 +163,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil 
+} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookergetconnectionschemas/lookergetconnectionschemas.go b/internal/tools/looker/lookergetconnectionschemas/lookergetconnectionschemas.go index dc471bc9b0..6ceac7a205 100644 --- a/internal/tools/looker/lookergetconnectionschemas/lookergetconnectionschemas.go +++ b/internal/tools/looker/lookergetconnectionschemas/lookergetconnectionschemas.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -44,6 +43,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -61,32 +67,24 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - connParameter := parameters.NewStringParameter("conn", "The connection containing the schemas.") dbParameter := parameters.NewStringParameterWithRequired("db", "The optional database to search", false) params := parameters.Parameters{connParameter, dbParameter} - mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, cfg.Annotations) + annotations := cfg.Annotations + if annotations == nil { + readOnlyHint := true + annotations = &tools.ToolAnnotations{ + ReadOnlyHint: &readOnlyHint, + } + } + + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, annotations) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -101,20 +99,21 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest 
tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + mapParams := params.AsMap() conn, ok := mapParams["conn"].(string) if !ok { @@ -122,7 +121,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } db, _ := mapParams["db"].(string) - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } @@ -132,7 +131,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT if db != "" { req.Database = &db } - resp, err := sdk.ConnectionSchemas(req, t.ApiSettings) + resp, err := sdk.ConnectionSchemas(req, source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making get_connection_schemas request: %s", err) } @@ -151,14 +150,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookergetconnectiontablecolumns/lookergetconnectiontablecolumns.go b/internal/tools/looker/lookergetconnectiontablecolumns/lookergetconnectiontablecolumns.go index 0d881919e6..4b1991cacf 100644 --- a/internal/tools/looker/lookergetconnectiontablecolumns/lookergetconnectiontablecolumns.go +++ b/internal/tools/looker/lookergetconnectiontablecolumns/lookergetconnectiontablecolumns.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -45,6 +44,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ 
-62,34 +68,26 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - connParameter := parameters.NewStringParameter("conn", "The connection containing the tables.") dbParameter := parameters.NewStringParameterWithRequired("db", "The optional database to search", false) schemaParameter := parameters.NewStringParameter("schema", "The schema containing the tables.") tablesParameter := parameters.NewStringParameter("tables", "A comma separated list of tables containing the columns.") params := parameters.Parameters{connParameter, dbParameter, schemaParameter, tablesParameter} - mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, cfg.Annotations) + annotations := cfg.Annotations + if annotations == nil { + readOnlyHint := true + annotations = &tools.ToolAnnotations{ + ReadOnlyHint: &readOnlyHint, + } + } + + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, annotations) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -104,20 +102,21 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) @@ -137,7 +136,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("'tables' must be a string, got %T", mapParams["tables"]) } - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } @@ -149,7 +148,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT if db != "" { req.Database = &db } - resp, err := sdk.ConnectionColumns(req, t.ApiSettings) + resp, err := sdk.ConnectionColumns(req, source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making get_connection_table_columns request: %s", 
err) } @@ -188,14 +187,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookergetconnectiontables/lookergetconnectiontables.go b/internal/tools/looker/lookergetconnectiontables/lookergetconnectiontables.go index 853296c32d..1fd9df6515 100644 --- a/internal/tools/looker/lookergetconnectiontables/lookergetconnectiontables.go +++ b/internal/tools/looker/lookergetconnectiontables/lookergetconnectiontables.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -45,6 +44,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -62,33 +68,25 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - connParameter := parameters.NewStringParameter("conn", "The connection containing the tables.") dbParameter := parameters.NewStringParameterWithRequired("db", "The optional database to search", false) schemaParameter := parameters.NewStringParameter("schema", "The schema containing the tables.") params := parameters.Parameters{connParameter, dbParameter, schemaParameter} - mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, cfg.Annotations) + annotations := cfg.Annotations + if annotations == nil { + readOnlyHint := true + annotations = &tools.ToolAnnotations{ + ReadOnlyHint: &readOnlyHint, + } + } + + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, annotations) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: 
cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -103,20 +101,21 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) @@ -132,7 +131,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("'schema' must be a string, got %T", mapParams["schema"]) } - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } @@ -143,7 +142,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT if db != "" { req.Database = &db } - resp, err := sdk.ConnectionTables(req, t.ApiSettings) + resp, err := sdk.ConnectionTables(req, source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making get_connection_tables request: %s", err) } @@ -179,14 +178,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookergetdashboards/lookergetdashboards.go b/internal/tools/looker/lookergetdashboards/lookergetdashboards.go index 4438468adb..6ef5be2f45 100644 --- a/internal/tools/looker/lookergetdashboards/lookergetdashboards.go +++ b/internal/tools/looker/lookergetdashboards/lookergetdashboards.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" 
"github.com/googleapis/genai-toolbox/internal/util" @@ -45,6 +44,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -62,18 +68,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - titleParameter := parameters.NewStringParameterWithDefault("title", "", "The title of the dashboard.") descParameter := parameters.NewStringParameterWithDefault("desc", "", "The description of the dashboard.") limitParameter := parameters.NewIntParameterWithDefault("limit", 100, "The number of dashboards to fetch. Default 100") @@ -85,16 +79,20 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) offsetParameter, } - mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, cfg.Annotations) + annotations := cfg.Annotations + if annotations == nil { + readOnlyHint := true + annotations = &tools.ToolAnnotations{ + ReadOnlyHint: &readOnlyHint, + } + } + + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, annotations) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -109,20 +107,21 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) @@ -141,7 +140,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT limit := int64(paramsMap["limit"].(int)) offset := int64(paramsMap["offset"].(int)) - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return 
nil, fmt.Errorf("error getting sdk: %w", err) } @@ -152,7 +151,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT Offset: &offset, } logger.ErrorContext(ctx, "Making request %v", req) - resp, err := sdk.SearchDashboards(req, t.ApiSettings) + resp, err := sdk.SearchDashboards(req, source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making get_dashboards request: %s", err) } @@ -190,14 +189,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookergetdimensions/lookergetdimensions.go b/internal/tools/looker/lookergetdimensions/lookergetdimensions.go index 58f1f833b3..92c795dfb2 100644 --- a/internal/tools/looker/lookergetdimensions/lookergetdimensions.go +++ b/internal/tools/looker/lookergetdimensions/lookergetdimensions.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -45,6 +44,14 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings + LookerShowHiddenFields() bool +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -62,37 +69,28 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - params := lookercommon.GetFieldParameters() - mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, cfg.Annotations) + annotations := cfg.Annotations + if annotations == nil { + readOnlyHint := true + annotations = &tools.ToolAnnotations{ + ReadOnlyHint: &readOnlyHint, + } + } + + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, annotations) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: 
s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired, }, - mcpManifest: mcpManifest, - ShowHiddenFields: s.ShowHiddenFields, + mcpManifest: mcpManifest, }, nil } @@ -101,21 +99,21 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest - ShowHiddenFields bool + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) @@ -125,7 +123,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("error processing model or explore: %w", err) } - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } @@ -135,7 +133,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT ExploreName: *explore, Fields: &fields, } - resp, err := sdk.LookmlModelExplore(req, t.ApiSettings) + resp, err := sdk.LookmlModelExplore(req, source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making get_dimensions request: %w", err) } @@ -144,7 +142,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("error processing get_dimensions response: %w", err) } - data, err := lookercommon.ExtractLookerFieldProperties(ctx, resp.Fields.Dimensions, t.ShowHiddenFields) + data, err := lookercommon.ExtractLookerFieldProperties(ctx, resp.Fields.Dimensions, source.LookerShowHiddenFields()) if err != nil { return nil, fmt.Errorf("error extracting get_dimensions response: %w", err) } @@ -165,14 +163,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, 
t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookergetexplores/lookergetexplores.go b/internal/tools/looker/lookergetexplores/lookergetexplores.go index ebc10ba765..75eaf9485a 100644 --- a/internal/tools/looker/lookergetexplores/lookergetexplores.go +++ b/internal/tools/looker/lookergetexplores/lookergetexplores.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -45,6 +44,14 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings + LookerShowHiddenExplores() bool +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -62,38 +69,29 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - modelParameter := parameters.NewStringParameter("model", "The model containing the explores.") params := parameters.Parameters{modelParameter} - mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, cfg.Annotations) + annotations := cfg.Annotations + if annotations == nil { + readOnlyHint := true + annotations = &tools.ToolAnnotations{ + ReadOnlyHint: &readOnlyHint, + } + } + + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, annotations) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired, }, - mcpManifest: mcpManifest, - ShowHiddenExplores: s.ShowHiddenExplores, + mcpManifest: mcpManifest, }, nil } @@ -102,21 +100,21 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest - ShowHiddenExplores bool + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, 
t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) @@ -127,11 +125,11 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("'model' must be a string, got %T", mapParams["model"]) } - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } - resp, err := sdk.LookmlModel(model, "explores(name,description,label,group_label,hidden)", t.ApiSettings) + resp, err := sdk.LookmlModel(model, "explores(name,description,label,group_label,hidden)", source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making get_explores request: %s", err) } @@ -139,7 +137,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT var data []any for _, v := range *resp.Explores { logger.DebugContext(ctx, "Got response element of %v\n", v) - if !t.ShowHiddenExplores && v.Hidden != nil && *v.Hidden { + if !source.LookerShowHiddenExplores() && v.Hidden != nil && *v.Hidden { continue } vMap := make(map[string]any) @@ -175,14 +173,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookergetfilters/lookergetfilters.go b/internal/tools/looker/lookergetfilters/lookergetfilters.go index 8204f004ca..413874886b 100644 --- a/internal/tools/looker/lookergetfilters/lookergetfilters.go +++ b/internal/tools/looker/lookergetfilters/lookergetfilters.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -45,6 +44,14 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings + LookerShowHiddenFields() bool +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -62,37 +69,28 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) 
(tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - params := lookercommon.GetFieldParameters() - mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, cfg.Annotations) + annotations := cfg.Annotations + if annotations == nil { + readOnlyHint := true + annotations = &tools.ToolAnnotations{ + ReadOnlyHint: &readOnlyHint, + } + } + + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, annotations) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired, }, - mcpManifest: mcpManifest, - ShowHiddenFields: s.ShowHiddenFields, + mcpManifest: mcpManifest, }, nil } @@ -101,21 +99,21 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest - ShowHiddenFields bool + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) @@ -126,7 +124,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } fields := lookercommon.FiltersFields - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } @@ -135,7 +133,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT ExploreName: *explore, Fields: &fields, } - resp, err := sdk.LookmlModelExplore(req, t.ApiSettings) + resp, err := sdk.LookmlModelExplore(req, source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making get_filters request: %w", err) } @@ -144,7 +142,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("error processing get_filters response: %w", err) } - data, err := lookercommon.ExtractLookerFieldProperties(ctx, resp.Fields.Filters, t.ShowHiddenFields) + data, err := lookercommon.ExtractLookerFieldProperties(ctx, resp.Fields.Filters, source.LookerShowHiddenFields()) if err != nil { return nil, fmt.Errorf("error extracting get_filters 
response: %w", err) } @@ -165,14 +163,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookergetlooks/lookergetlooks.go b/internal/tools/looker/lookergetlooks/lookergetlooks.go index 33b4135f09..b52bc059b4 100644 --- a/internal/tools/looker/lookergetlooks/lookergetlooks.go +++ b/internal/tools/looker/lookergetlooks/lookergetlooks.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -45,6 +44,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -62,18 +68,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - titleParameter := parameters.NewStringParameterWithDefault("title", "", "The title of the look.") descParameter := parameters.NewStringParameterWithDefault("desc", "", "The description of the look.") limitParameter := parameters.NewIntParameterWithDefault("limit", 100, "The number of looks to fetch. 
Default 100") @@ -85,16 +79,20 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) offsetParameter, } - mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, cfg.Annotations) + annotations := cfg.Annotations + if annotations == nil { + readOnlyHint := true + annotations = &tools.ToolAnnotations{ + ReadOnlyHint: &readOnlyHint, + } + } + + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, annotations) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -109,20 +107,21 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) @@ -141,7 +140,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT limit := int64(paramsMap["limit"].(int)) offset := int64(paramsMap["offset"].(int)) - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } @@ -151,7 +150,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT Limit: &limit, Offset: &offset, } - resp, err := sdk.SearchLooks(req, t.ApiSettings) + resp, err := sdk.SearchLooks(req, source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making get_looks request: %s", err) } @@ -190,14 +189,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err 
!= nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookergetmeasures/lookergetmeasures.go b/internal/tools/looker/lookergetmeasures/lookergetmeasures.go index de1bef032e..56b810126b 100644 --- a/internal/tools/looker/lookergetmeasures/lookergetmeasures.go +++ b/internal/tools/looker/lookergetmeasures/lookergetmeasures.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -45,6 +44,14 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings + LookerShowHiddenFields() bool +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -62,37 +69,28 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - params := lookercommon.GetFieldParameters() - mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, cfg.Annotations) + annotations := cfg.Annotations + if annotations == nil { + readOnlyHint := true + annotations = &tools.ToolAnnotations{ + ReadOnlyHint: &readOnlyHint, + } + } + + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, annotations) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired, }, - mcpManifest: mcpManifest, - ShowHiddenFields: s.ShowHiddenFields, + mcpManifest: mcpManifest, }, nil } @@ -101,21 +99,21 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest - ShowHiddenFields bool + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, 
fmt.Errorf("unable to get logger from ctx: %s", err) @@ -126,7 +124,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } fields := lookercommon.MeasuresFields - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } @@ -135,7 +133,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT ExploreName: *explore, Fields: &fields, } - resp, err := sdk.LookmlModelExplore(req, t.ApiSettings) + resp, err := sdk.LookmlModelExplore(req, source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making get_measures request: %w", err) } @@ -144,7 +142,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("error processing get_measures response: %w", err) } - data, err := lookercommon.ExtractLookerFieldProperties(ctx, resp.Fields.Measures, t.ShowHiddenFields) + data, err := lookercommon.ExtractLookerFieldProperties(ctx, resp.Fields.Measures, source.LookerShowHiddenFields()) if err != nil { return nil, fmt.Errorf("error extracting get_measures response: %w", err) } @@ -165,14 +163,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookergetmodels/lookergetmodels.go b/internal/tools/looker/lookergetmodels/lookergetmodels.go index cfa027384d..5c4f70f6b1 100644 --- a/internal/tools/looker/lookergetmodels/lookergetmodels.go +++ b/internal/tools/looker/lookergetmodels/lookergetmodels.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -45,6 +44,14 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings + LookerShowHiddenModels() bool +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -62,37 +69,28 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify 
source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - params := parameters.Parameters{} - mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, cfg.Annotations) + annotations := cfg.Annotations + if annotations == nil { + readOnlyHint := true + annotations = &tools.ToolAnnotations{ + ReadOnlyHint: &readOnlyHint, + } + } + + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, annotations) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired, }, - mcpManifest: mcpManifest, - ShowHiddenModels: s.ShowHiddenModels, + mcpManifest: mcpManifest, }, nil } @@ -101,31 +99,31 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest - ShowHiddenModels bool + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) } excludeEmpty := false - excludeHidden := !t.ShowHiddenModels + excludeHidden := !source.LookerShowHiddenModels() includeInternal := true - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } @@ -134,7 +132,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT ExcludeHidden: &excludeHidden, IncludeInternal: &includeInternal, } - resp, err := sdk.AllLookmlModels(req, t.ApiSettings) + resp, err := sdk.AllLookmlModels(req, source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making get_models request: %s", err) } @@ -167,14 +165,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, 
verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookergetparameters/lookergetparameters.go b/internal/tools/looker/lookergetparameters/lookergetparameters.go index d84ec07726..2333cfb892 100644 --- a/internal/tools/looker/lookergetparameters/lookergetparameters.go +++ b/internal/tools/looker/lookergetparameters/lookergetparameters.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -45,6 +44,14 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings + LookerShowHiddenFields() bool +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -62,37 +69,28 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - params := lookercommon.GetFieldParameters() - mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, cfg.Annotations) + annotations := cfg.Annotations + if annotations == nil { + readOnlyHint := true + annotations = &tools.ToolAnnotations{ + ReadOnlyHint: &readOnlyHint, + } + } + + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, annotations) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired, }, - mcpManifest: mcpManifest, - ShowHiddenFields: s.ShowHiddenFields, + mcpManifest: mcpManifest, }, nil } @@ -101,21 +99,21 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest - ShowHiddenFields bool + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken 
tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) @@ -126,7 +124,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } fields := lookercommon.ParametersFields - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } @@ -135,7 +133,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT ExploreName: *explore, Fields: &fields, } - resp, err := sdk.LookmlModelExplore(req, t.ApiSettings) + resp, err := sdk.LookmlModelExplore(req, source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making get_parameters request: %w", err) } @@ -144,7 +142,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("error processing get_parameters response: %w", err) } - data, err := lookercommon.ExtractLookerFieldProperties(ctx, resp.Fields.Parameters, t.ShowHiddenFields) + data, err := lookercommon.ExtractLookerFieldProperties(ctx, resp.Fields.Parameters, source.LookerShowHiddenFields()) if err != nil { return nil, fmt.Errorf("error extracting get_parameters response: %w", err) } @@ -165,14 +163,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookergetprojectfile/lookergetprojectfile.go b/internal/tools/looker/lookergetprojectfile/lookergetprojectfile.go index 6d91d09b14..6d3fd015d3 100644 --- a/internal/tools/looker/lookergetprojectfile/lookergetprojectfile.go +++ b/internal/tools/looker/lookergetprojectfile/lookergetprojectfile.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -45,6 +44,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource 
interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -62,32 +68,24 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - projectIdParameter := parameters.NewStringParameter("project_id", "The id of the project containing the files") filePathParameter := parameters.NewStringParameter("file_path", "The path of the file within the project") params := parameters.Parameters{projectIdParameter, filePathParameter} - mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, cfg.Annotations) + annotations := cfg.Annotations + if annotations == nil { + readOnlyHint := true + annotations = &tools.ToolAnnotations{ + ReadOnlyHint: &readOnlyHint, + } + } + + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, annotations) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -102,26 +100,27 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) } - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } @@ -136,7 +135,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("'file_path' must be a string, got %T", mapParams["file_path"]) } - resp, err := lookercommon.GetProjectFileContent(sdk, projectId, filePath, t.ApiSettings) + resp, err := lookercommon.GetProjectFileContent(sdk, projectId, filePath, source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making get_project_file request: %s", err) } @@ -161,14 
+160,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookergetprojectfiles/lookergetprojectfiles.go b/internal/tools/looker/lookergetprojectfiles/lookergetprojectfiles.go index 0705449536..78f3182246 100644 --- a/internal/tools/looker/lookergetprojectfiles/lookergetprojectfiles.go +++ b/internal/tools/looker/lookergetprojectfiles/lookergetprojectfiles.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -45,6 +44,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -62,31 +68,23 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - projectIdParameter := parameters.NewStringParameter("project_id", "The id of the project containing the files") params := parameters.Parameters{projectIdParameter} - mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, cfg.Annotations) + annotations := cfg.Annotations + if annotations == nil { + readOnlyHint := true + annotations = &tools.ToolAnnotations{ + ReadOnlyHint: &readOnlyHint, + } + } + + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, annotations) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -101,26 +99,27 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK 
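The `compatibleSource` interfaces above require the looker source to expose `LookerClient()`, `LookerApiSettings()`, and the `LookerShowHidden*()` accessors in addition to the pre-existing `UseClientAuthorization()` and `GetAuthTokenHeaderName()` methods. The source-side hunks are not in this part of the diff; the accesses the patch removes (`s.Client`, `s.ApiSettings`, `s.ShowHiddenFields`, `s.ShowHiddenModels`) suggest accessors roughly like the following sketch, with the field names assumed from those old direct accesses.

```go
// Sketch of the looker source accessors implied by the compatibleSource
// interfaces; the actual source-side change is not shown in this hunk, and the
// struct fields are assumed from the accesses removed in the tool packages.
func (s *Source) LookerClient() *v4.LookerSDK         { return s.Client }
func (s *Source) LookerApiSettings() *rtl.ApiSettings { return s.ApiSettings }
func (s *Source) LookerShowHiddenFields() bool        { return s.ShowHiddenFields }
func (s *Source) LookerShowHiddenModels() bool        { return s.ShowHiddenModels }
```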
- ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) } - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } @@ -131,7 +130,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("'project_id' must be a string, got %T", mapParams["project_id"]) } - resp, err := sdk.AllProjectFiles(projectId, "", t.ApiSettings) + resp, err := sdk.AllProjectFiles(projectId, "", source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making get_project_files request: %s", err) } @@ -178,14 +177,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookergetprojects/lookergetprojects.go b/internal/tools/looker/lookergetprojects/lookergetprojects.go index 2e4a2d23a3..5756413662 100644 --- a/internal/tools/looker/lookergetprojects/lookergetprojects.go +++ b/internal/tools/looker/lookergetprojects/lookergetprojects.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -45,6 +44,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string 
`yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -62,30 +68,22 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - params := parameters.Parameters{} - mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, cfg.Annotations) + annotations := cfg.Annotations + if annotations == nil { + readOnlyHint := true + annotations = &tools.ToolAnnotations{ + ReadOnlyHint: &readOnlyHint, + } + } + + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, annotations) // finish tool setup return Tool{ - Config: cfg, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Parameters: params, - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -100,31 +98,32 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) } - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } - resp, err := sdk.AllProjects("id,name", t.ApiSettings) + resp, err := sdk.AllProjects("id,name", source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making get_models request: %s", err) } @@ -155,14 +154,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr 
tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookerhealthanalyze/lookerhealthanalyze.go b/internal/tools/looker/lookerhealthanalyze/lookerhealthanalyze.go index a0aad8335f..0675b4dee5 100644 --- a/internal/tools/looker/lookerhealthanalyze/lookerhealthanalyze.go +++ b/internal/tools/looker/lookerhealthanalyze/lookerhealthanalyze.go @@ -22,7 +22,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -50,6 +49,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -67,16 +73,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - actionParameter := parameters.NewStringParameterWithRequired("action", "The analysis to run. 
Can be 'projects', 'models', or 'explores'.", true) projectParameter := parameters.NewStringParameterWithRequired("project", "The Looker project to analyze (optional).", false) modelParameter := parameters.NewStringParameterWithRequired("model", "The Looker model to analyze (optional).", false) @@ -93,15 +89,19 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) minQueriesParameter, } - mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, cfg.Annotations) + annotations := cfg.Annotations + if annotations == nil { + readOnlyHint := true + annotations = &tools.ToolAnnotations{ + ReadOnlyHint: &readOnlyHint, + } + } + + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, annotations) return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -115,26 +115,27 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) } - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } @@ -203,12 +204,16 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } -func (t Tool) Authorized(verifiedAuthServices []string) bool { - return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } // ================================================================================================================= @@ -558,6 +563,10 @@ func (t *analyzeTool) explores(ctx context.Context, model, explore string) ([]ma // END LOOKER HEALTH ANALYZE CORE LOGIC // ================================================================================================================= -func (t Tool) GetAuthTokenHeaderName() string { - return 
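Another change repeated across these files: when `cfg.Annotations` is nil, each tool now fills in a default MCP `ReadOnlyHint`, true for the read/query/health tools and false for the dashboard- and look-creating tools later in the patch. The inline block is identical everywhere; if it were factored out, it could look like the sketch below (the helper name is hypothetical, not something this patch adds).

```go
// Hypothetical helper capturing the repeated default-annotation block; the
// patch itself inlines this logic in every Initialize.
func defaultAnnotations(existing *tools.ToolAnnotations, readOnly bool) *tools.ToolAnnotations {
	if existing != nil {
		return existing // an explicit annotations block in the config wins
	}
	return &tools.ToolAnnotations{ReadOnlyHint: &readOnly}
}
```

Under that sketch, the read-oriented tools would call `defaultAnnotations(cfg.Annotations, true)`, while `lookermakedashboard` and `lookermakelook` would pass `false`.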
t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookerhealthpulse/lookerhealthpulse.go b/internal/tools/looker/lookerhealthpulse/lookerhealthpulse.go index e0dd560e79..45307b5011 100644 --- a/internal/tools/looker/lookerhealthpulse/lookerhealthpulse.go +++ b/internal/tools/looker/lookerhealthpulse/lookerhealthpulse.go @@ -21,7 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -50,6 +49,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -67,32 +73,26 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - actionParameter := parameters.NewStringParameterWithRequired("action", "The health check to run. 
Can be either: `check_db_connections`, `check_dashboard_performance`,`check_dashboard_errors`,`check_explore_performance`,`check_schedule_failures`, or `check_legacy_features`", true) params := parameters.Parameters{ actionParameter, } - mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, cfg.Annotations) + annotations := cfg.Annotations + if annotations == nil { + readOnlyHint := true + annotations = &tools.ToolAnnotations{ + ReadOnlyHint: &readOnlyHint, + } + } + + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, annotations) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -106,32 +106,33 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) } - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } pulseTool := &pulseTool{ - ApiSettings: t.ApiSettings, + ApiSettings: source.LookerApiSettings(), SdkClient: sdk, } @@ -145,7 +146,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT Action: action, } - result, err := pulseTool.RunPulse(ctx, pulseParams) + result, err := pulseTool.RunPulse(ctx, source, pulseParams) if err != nil { return nil, fmt.Errorf("error running pulse: %w", err) } @@ -167,12 +168,16 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } -func (t Tool) Authorized(verifiedAuthServices []string) bool { - return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } // ================================================================================================================= @@ -193,27 +198,27 @@ type 
pulseTool struct { SdkClient *v4.LookerSDK } -func (t *pulseTool) RunPulse(ctx context.Context, params PulseParams) (interface{}, error) { +func (t *pulseTool) RunPulse(ctx context.Context, source compatibleSource, params PulseParams) (interface{}, error) { switch params.Action { case "check_db_connections": - return t.checkDBConnections(ctx) + return t.checkDBConnections(ctx, source) case "check_dashboard_performance": - return t.checkDashboardPerformance(ctx) + return t.checkDashboardPerformance(ctx, source) case "check_dashboard_errors": - return t.checkDashboardErrors(ctx) + return t.checkDashboardErrors(ctx, source) case "check_explore_performance": - return t.checkExplorePerformance(ctx) + return t.checkExplorePerformance(ctx, source) case "check_schedule_failures": - return t.checkScheduleFailures(ctx) + return t.checkScheduleFailures(ctx, source) case "check_legacy_features": - return t.checkLegacyFeatures(ctx) + return t.checkLegacyFeatures(ctx, source) default: return nil, fmt.Errorf("unknown action: %s", params.Action) } } // Check DB connections and run tests -func (t *pulseTool) checkDBConnections(ctx context.Context) (interface{}, error) { +func (t *pulseTool) checkDBConnections(ctx context.Context, source compatibleSource) (interface{}, error) { logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) @@ -227,7 +232,7 @@ func (t *pulseTool) checkDBConnections(ctx context.Context) (interface{}, error) "looker__ilooker": {}, } - connections, err := t.SdkClient.AllConnections("", t.ApiSettings) + connections, err := t.SdkClient.AllConnections("", source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error fetching connections: %w", err) } @@ -246,7 +251,7 @@ func (t *pulseTool) checkDBConnections(ctx context.Context) (interface{}, error) for _, conn := range filteredConnections { var errors []string // Test connection (simulate test_connection endpoint) - resp, err := t.SdkClient.TestConnection(*conn.Name, nil, t.ApiSettings) + resp, err := t.SdkClient.TestConnection(*conn.Name, nil, source.LookerApiSettings()) if err != nil { errors = append(errors, "API JSONDecode Error") } else { @@ -270,7 +275,7 @@ func (t *pulseTool) checkDBConnections(ctx context.Context) (interface{}, error) }, Limit: &limit, } - raw, err := lookercommon.RunInlineQuery(ctx, t.SdkClient, query, "json", t.ApiSettings) + raw, err := lookercommon.RunInlineQuery(ctx, t.SdkClient, query, "json", source.LookerApiSettings()) if err != nil { return nil, err } @@ -291,7 +296,7 @@ func (t *pulseTool) checkDBConnections(ctx context.Context) (interface{}, error) return results, nil } -func (t *pulseTool) checkDashboardPerformance(ctx context.Context) (interface{}, error) { +func (t *pulseTool) checkDashboardPerformance(ctx context.Context, source compatibleSource) (interface{}, error) { logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) @@ -312,7 +317,7 @@ func (t *pulseTool) checkDashboardPerformance(ctx context.Context) (interface{}, Sorts: &[]string{"query.count desc"}, Limit: &limit, } - raw, err := lookercommon.RunInlineQuery(ctx, t.SdkClient, query, "json", t.ApiSettings) + raw, err := lookercommon.RunInlineQuery(ctx, t.SdkClient, query, "json", source.LookerApiSettings()) if err != nil { return nil, err } @@ -323,7 +328,7 @@ func (t *pulseTool) checkDashboardPerformance(ctx context.Context) (interface{}, return dashboards, nil } -func (t *pulseTool) 
checkDashboardErrors(ctx context.Context) (interface{}, error) { +func (t *pulseTool) checkDashboardErrors(ctx context.Context, source compatibleSource) (interface{}, error) { logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) @@ -344,7 +349,7 @@ func (t *pulseTool) checkDashboardErrors(ctx context.Context) (interface{}, erro Sorts: &[]string{"history.query_run_count desc"}, Limit: &limit, } - raw, err := lookercommon.RunInlineQuery(ctx, t.SdkClient, query, "json", t.ApiSettings) + raw, err := lookercommon.RunInlineQuery(ctx, t.SdkClient, query, "json", source.LookerApiSettings()) if err != nil { return nil, err } @@ -355,7 +360,7 @@ func (t *pulseTool) checkDashboardErrors(ctx context.Context) (interface{}, erro return dashboards, nil } -func (t *pulseTool) checkExplorePerformance(ctx context.Context) (interface{}, error) { +func (t *pulseTool) checkExplorePerformance(ctx context.Context, source compatibleSource) (interface{}, error) { logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) @@ -374,7 +379,7 @@ func (t *pulseTool) checkExplorePerformance(ctx context.Context) (interface{}, e Sorts: &[]string{"history.average_runtime desc"}, Limit: &limit, } - raw, err := lookercommon.RunInlineQuery(ctx, t.SdkClient, query, "json", t.ApiSettings) + raw, err := lookercommon.RunInlineQuery(ctx, t.SdkClient, query, "json", source.LookerApiSettings()) if err != nil { return nil, err } @@ -385,7 +390,7 @@ func (t *pulseTool) checkExplorePerformance(ctx context.Context) (interface{}, e // Average query runtime query.Fields = &[]string{"history.average_runtime"} - rawAvg, err := lookercommon.RunInlineQuery(ctx, t.SdkClient, query, "json", t.ApiSettings) + rawAvg, err := lookercommon.RunInlineQuery(ctx, t.SdkClient, query, "json", source.LookerApiSettings()) if err != nil { return nil, err } @@ -400,7 +405,7 @@ func (t *pulseTool) checkExplorePerformance(ctx context.Context) (interface{}, e return explores, nil } -func (t *pulseTool) checkScheduleFailures(ctx context.Context) (interface{}, error) { +func (t *pulseTool) checkScheduleFailures(ctx context.Context, source compatibleSource) (interface{}, error) { logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) @@ -419,7 +424,7 @@ func (t *pulseTool) checkScheduleFailures(ctx context.Context) (interface{}, err Sorts: &[]string{"scheduled_job.count desc"}, Limit: &limit, } - raw, err := lookercommon.RunInlineQuery(ctx, t.SdkClient, query, "json", t.ApiSettings) + raw, err := lookercommon.RunInlineQuery(ctx, t.SdkClient, query, "json", source.LookerApiSettings()) if err != nil { return nil, err } @@ -430,14 +435,14 @@ func (t *pulseTool) checkScheduleFailures(ctx context.Context) (interface{}, err return schedules, nil } -func (t *pulseTool) checkLegacyFeatures(ctx context.Context) (interface{}, error) { +func (t *pulseTool) checkLegacyFeatures(ctx context.Context, source compatibleSource) (interface{}, error) { logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) } logger.InfoContext(ctx, "Test 6/6: Checking for enabled legacy features") - features, err := t.SdkClient.AllLegacyFeatures(t.ApiSettings) + features, err := t.SdkClient.AllLegacyFeatures(source.LookerApiSettings()) if err != nil { if strings.Contains(err.Error(), "Unsupported in Looker (Google Cloud core)") { 
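Taken together, the `Invoke` rewrites in this patch share one shape: resolve the source from the provider, build the SDK via `lookercommon.GetLookerSDK` using the source's client-auth flag, API settings, cached client, and the caller's access token, then pass `source.LookerApiSettings()` to every SDK call. Condensed, with the tool-specific request building and response handling elided:

```go
// Condensed view of the Invoke pattern repeated across the Looker tools in
// this patch; per-tool request/response handling is elided.
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
	// The source is resolved per call now, not captured at Initialize time.
	source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
	if err != nil {
		return nil, err
	}

	// GetLookerSDK receives the client-auth flag, API settings, the cached
	// client, and the caller's access token.
	sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
	if err != nil {
		return nil, fmt.Errorf("error getting sdk: %w", err)
	}

	// ... tool-specific SDK calls, each passing source.LookerApiSettings() ...
	_ = sdk
	return nil, nil
}
```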
return []map[string]string{{"Feature": "Unsupported in Looker (Google Cloud core)"}}, nil @@ -458,6 +463,10 @@ func (t *pulseTool) checkLegacyFeatures(ctx context.Context) (interface{}, error // END LOOKER HEALTH PULSE CORE LOGIC // ================================================================================================================= -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookerhealthvacuum/lookerhealthvacuum.go b/internal/tools/looker/lookerhealthvacuum/lookerhealthvacuum.go index 39bf4cfb21..d1d55a2fd0 100644 --- a/internal/tools/looker/lookerhealthvacuum/lookerhealthvacuum.go +++ b/internal/tools/looker/lookerhealthvacuum/lookerhealthvacuum.go @@ -22,7 +22,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -50,6 +49,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -67,16 +73,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - actionParameter := parameters.NewStringParameterWithRequired("action", "The vacuum action to run. 
Can be 'models', or 'explores'.", true) projectParameter := parameters.NewStringParameterWithDefault("project", "", "The Looker project to vacuum (optional).") modelParameter := parameters.NewStringParameterWithDefault("model", "", "The Looker model to vacuum (optional).") @@ -93,15 +89,19 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) minQueriesParameter, } - mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, cfg.Annotations) + annotations := cfg.Annotations + if annotations == nil { + readOnlyHint := true + annotations = &tools.ToolAnnotations{ + ReadOnlyHint: &readOnlyHint, + } + } + + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, annotations) return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -115,21 +115,22 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } @@ -181,12 +182,16 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } -func (t Tool) Authorized(verifiedAuthServices []string) bool { - return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } // ================================================================================================================= @@ -462,6 +467,10 @@ func (t *vacuumTool) getUsedExploreFields(ctx context.Context, model, explore st // END LOOKER HEALTH VACUUM CORE LOGIC // ================================================================================================================= -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := 
tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookermakedashboard/lookermakedashboard.go b/internal/tools/looker/lookermakedashboard/lookermakedashboard.go index 961461add9..2930d6e993 100644 --- a/internal/tools/looker/lookermakedashboard/lookermakedashboard.go +++ b/internal/tools/looker/lookermakedashboard/lookermakedashboard.go @@ -21,7 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -47,6 +46,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -64,18 +70,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - params := parameters.Parameters{} titleParameter := parameters.NewStringParameter("title", "The title of the Dashboard") @@ -83,16 +77,20 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) descParameter := parameters.NewStringParameterWithDefault("description", "", "The description of the Dashboard") params = append(params, descParameter) - mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, cfg.Annotations) + annotations := cfg.Annotations + if annotations == nil { + readOnlyHint := false + annotations = &tools.ToolAnnotations{ + ReadOnlyHint: &readOnlyHint, + } + } + + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, annotations) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -107,32 +105,33 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken 
tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) } logger.DebugContext(ctx, "params = ", params) - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } mrespFields := "id,personal_folder_id" - mresp, err := sdk.Me(mrespFields, t.ApiSettings) + mresp, err := sdk.Me(mrespFields, source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making me request: %s", err) } @@ -145,7 +144,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("user does not have a personal folder. cannot continue") } - dashs, err := sdk.FolderDashboards(*mresp.PersonalFolderId, "title", t.ApiSettings) + dashs, err := sdk.FolderDashboards(*mresp.PersonalFolderId, "title", source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error getting existing dashboards in folder: %s", err) } @@ -164,13 +163,13 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT Description: &description, FolderId: mresp.PersonalFolderId, } - resp, err := sdk.CreateDashboard(wd, t.ApiSettings) + resp, err := sdk.CreateDashboard(wd, source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making create dashboard request: %s", err) } logger.DebugContext(ctx, "resp = %v", resp) - setting, err := sdk.GetSetting("host_url", t.ApiSettings) + setting, err := sdk.GetSetting("host_url", source.LookerApiSettings()) if err != nil { logger.ErrorContext(ctx, "error getting settings: %s", err) } @@ -203,14 +202,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookermakelook/lookermakelook.go b/internal/tools/looker/lookermakelook/lookermakelook.go index c4180ea4f0..7244c5d6fe 100644 --- a/internal/tools/looker/lookermakelook/lookermakelook.go +++ b/internal/tools/looker/lookermakelook/lookermakelook.go @@ -21,7 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" 
"github.com/googleapis/genai-toolbox/internal/util" @@ -47,6 +46,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -64,18 +70,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - params := lookercommon.GetQueryParameters() titleParameter := parameters.NewStringParameter("title", "The title of the Look") @@ -89,16 +83,20 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) ) params = append(params, vizParameter) - mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, cfg.Annotations) + annotations := cfg.Annotations + if annotations == nil { + readOnlyHint := false + annotations = &tools.ToolAnnotations{ + ReadOnlyHint: &readOnlyHint, + } + } + + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, annotations) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -113,20 +111,21 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) @@ -137,12 +136,12 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("error building query request: %w", err) } - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } mrespFields := "id,personal_folder_id" - mresp, err := sdk.Me(mrespFields, t.ApiSettings) + mresp, err := sdk.Me(mrespFields, 
source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making me request: %s", err) } @@ -151,7 +150,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT title := paramsMap["title"].(string) description := paramsMap["description"].(string) - looks, err := sdk.FolderLooks(*mresp.PersonalFolderId, "title", t.ApiSettings) + looks, err := sdk.FolderLooks(*mresp.PersonalFolderId, "title", source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error getting existing looks in folder: %s", err) } @@ -169,7 +168,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT wq.VisConfig = &visConfig qrespFields := "id" - qresp, err := sdk.CreateQuery(*wq, qrespFields, t.ApiSettings) + qresp, err := sdk.CreateQuery(*wq, qrespFields, source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making create query request: %s", err) } @@ -181,13 +180,13 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT QueryId: qresp.Id, FolderId: mresp.PersonalFolderId, } - resp, err := sdk.CreateLook(wlwq, "", t.ApiSettings) + resp, err := sdk.CreateLook(wlwq, "", source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making create look request: %s", err) } logger.DebugContext(ctx, "resp = %v", resp) - setting, err := sdk.GetSetting("host_url", t.ApiSettings) + setting, err := sdk.GetSetting("host_url", source.LookerApiSettings()) if err != nil { logger.ErrorContext(ctx, "error getting settings: %s", err) } @@ -220,14 +219,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookerquery/lookerquery.go b/internal/tools/looker/lookerquery/lookerquery.go index 7d0a4c5209..7f37d71c76 100644 --- a/internal/tools/looker/lookerquery/lookerquery.go +++ b/internal/tools/looker/lookerquery/lookerquery.go @@ -20,7 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -46,6 +45,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -63,30 
+69,22 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - params := lookercommon.GetQueryParameters() - mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, cfg.Annotations) + annotations := cfg.Annotations + if annotations == nil { + readOnlyHint := true + annotations = &tools.ToolAnnotations{ + ReadOnlyHint: &readOnlyHint, + } + } + + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, annotations) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -101,20 +99,21 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) @@ -123,11 +122,11 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT if err != nil { return nil, fmt.Errorf("error building WriteQuery request: %w", err) } - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } - resp, err := lookercommon.RunInlineQuery(ctx, sdk, wq, "json", t.ApiSettings) + resp, err := lookercommon.RunInlineQuery(ctx, sdk, wq, "json", source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making query request: %s", err) } @@ -157,14 +156,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return 
t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookerquerysql/lookerquerysql.go b/internal/tools/looker/lookerquerysql/lookerquerysql.go index f3c3c0e07c..648894d8ed 100644 --- a/internal/tools/looker/lookerquerysql/lookerquerysql.go +++ b/internal/tools/looker/lookerquerysql/lookerquerysql.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -45,6 +44,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -62,30 +68,22 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - params := lookercommon.GetQueryParameters() - mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, cfg.Annotations) + annotations := cfg.Annotations + if annotations == nil { + readOnlyHint := true + annotations = &tools.ToolAnnotations{ + ReadOnlyHint: &readOnlyHint, + } + } + + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, annotations) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -100,20 +98,21 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } 
+ logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) @@ -122,11 +121,11 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT if err != nil { return nil, fmt.Errorf("error building query request: %w", err) } - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } - resp, err := lookercommon.RunInlineQuery(ctx, sdk, wq, "sql", t.ApiSettings) + resp, err := lookercommon.RunInlineQuery(ctx, sdk, wq, "sql", source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making query request: %s", err) } @@ -147,14 +146,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookerqueryurl/lookerqueryurl.go b/internal/tools/looker/lookerqueryurl/lookerqueryurl.go index 61ae4904a1..f76e0014a2 100644 --- a/internal/tools/looker/lookerqueryurl/lookerqueryurl.go +++ b/internal/tools/looker/lookerqueryurl/lookerqueryurl.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -45,6 +44,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -62,18 +68,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - params := lookercommon.GetQueryParameters() vizParameter := parameters.NewMapParameterWithDefault("vis_config", @@ -83,16 +77,20 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, 
error) ) params = append(params, vizParameter) - mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, cfg.Annotations) + annotations := cfg.Annotations + if annotations == nil { + readOnlyHint := true + annotations = &tools.ToolAnnotations{ + ReadOnlyHint: &readOnlyHint, + } + } + + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, annotations) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -107,20 +105,21 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) @@ -135,12 +134,12 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT visConfig := paramsMap["vis_config"].(map[string]any) wq.VisConfig = &visConfig - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } respFields := "id,slug,share_url,expanded_share_url" - resp, err := sdk.CreateQuery(*wq, respFields, t.ApiSettings) + resp, err := sdk.CreateQuery(*wq, respFields, source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making query request: %s", err) } @@ -176,14 +175,22 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil +} + func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth -} - -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git 
a/internal/tools/looker/lookerrundashboard/lookerrundashboard.go b/internal/tools/looker/lookerrundashboard/lookerrundashboard.go index c7ac9831b6..6a27a77e3a 100644 --- a/internal/tools/looker/lookerrundashboard/lookerrundashboard.go +++ b/internal/tools/looker/lookerrundashboard/lookerrundashboard.go @@ -21,7 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -47,6 +46,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -64,34 +70,26 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - dashboardidParameter := parameters.NewStringParameter("dashboard_id", "The id of the dashboard to run.") params := parameters.Parameters{ dashboardidParameter, } - mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, cfg.Annotations) + annotations := cfg.Annotations + if annotations == nil { + readOnlyHint := true + annotations = &tools.ToolAnnotations{ + ReadOnlyHint: &readOnlyHint, + } + } + + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, annotations) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -106,20 +104,21 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) @@ -129,11 +128,11 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT 
dashboard_id := paramsMap["dashboard_id"].(string) - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } - dashboard, err := sdk.Dashboard(dashboard_id, "", t.ApiSettings) + dashboard, err := sdk.Dashboard(dashboard_id, "", source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error getting dashboard: %w", err) } @@ -149,7 +148,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT channels := make([]<-chan map[string]any, len(*dashboard.DashboardElements)) for i, element := range *dashboard.DashboardElements { - channels[i] = tileQueryWorker(ctx, sdk, t.ApiSettings, i, element) + channels[i] = tileQueryWorker(ctx, sdk, source.LookerApiSettings(), i, element) } for resp := range merge(channels...) { @@ -173,12 +172,16 @@ func (t Tool) McpManifest() tools.McpManifest { return t.mcpManifest } -func (t Tool) Authorized(verifiedAuthServices []string) bool { - return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } func tileQueryWorker(ctx context.Context, sdk *v4.LookerSDK, options *rtl.ApiSettings, index int, element v4.DashboardElement) <-chan map[string]any { @@ -270,6 +273,10 @@ func merge(channels ...<-chan map[string]any) <-chan map[string]any { return out } -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookerrunlook/lookerrunlook.go b/internal/tools/looker/lookerrunlook/lookerrunlook.go index c8cdf1ecc9..9c7136b6c2 100644 --- a/internal/tools/looker/lookerrunlook/lookerrunlook.go +++ b/internal/tools/looker/lookerrunlook/lookerrunlook.go @@ -20,7 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -46,6 +45,13 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -63,18 +69,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source 
exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - lookidParameter := parameters.NewStringParameter("look_id", "The id of the look to run.") limitParameter := parameters.NewIntParameterWithDefault("limit", 500, "The row limit. Default 500") @@ -83,16 +77,20 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) limitParameter, } - mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, cfg.Annotations) + annotations := cfg.Annotations + if annotations == nil { + readOnlyHint := true + annotations = &tools.ToolAnnotations{ + ReadOnlyHint: &readOnlyHint, + } + } + + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, annotations) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -107,20 +105,21 @@ var _ tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("unable to get logger from ctx: %s", err) @@ -132,12 +131,12 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT limit := int64(paramsMap["limit"].(int)) limitStr := fmt.Sprintf("%d", limit) - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } - look, err := sdk.Look(look_id, "", t.ApiSettings) + look, err := sdk.Look(look_id, "", source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error getting look definition: %s", err) } @@ -153,7 +152,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT Limit: &limitStr, } - resp, err := lookercommon.RunInlineQuery(ctx, sdk, &wq, "json", t.ApiSettings) + resp, err := lookercommon.RunInlineQuery(ctx, sdk, &wq, "json", source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making run_look request: %s", err) } @@ -186,10 +185,18 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) 
} -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/looker/lookerupdateprojectfile/lookerupdateprojectfile.go b/internal/tools/looker/lookerupdateprojectfile/lookerupdateprojectfile.go index 957e96667c..2981f24270 100644 --- a/internal/tools/looker/lookerupdateprojectfile/lookerupdateprojectfile.go +++ b/internal/tools/looker/lookerupdateprojectfile/lookerupdateprojectfile.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - lookersrc "github.com/googleapis/genai-toolbox/internal/sources/looker" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -44,6 +43,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + UseClientAuthorization() bool + GetAuthTokenHeaderName() string + LookerClient() *v4.LookerSDK + LookerApiSettings() *rtl.ApiSettings +} type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -61,33 +66,27 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*lookersrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `looker`", kind) - } - projectIdParameter := parameters.NewStringParameter("project_id", "The id of the project containing the files") filePathParameter := parameters.NewStringParameter("file_path", "The path of the file within the project") fileContentParameter := parameters.NewStringParameter("file_content", "The content of the file") params := parameters.Parameters{projectIdParameter, filePathParameter, fileContentParameter} - mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, cfg.Annotations) + annotations := cfg.Annotations + if annotations == nil { + readOnlyHint := false + destructiveHint := true + annotations = &tools.ToolAnnotations{ + ReadOnlyHint: &readOnlyHint, + DestructiveHint: &destructiveHint, + } + } + + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, annotations) // finish tool setup return Tool{ - Config: cfg, - Parameters: params, - UseClientOAuth: s.UseClientAuthorization(), - AuthTokenHeaderName: s.GetAuthTokenHeaderName(), - Client: s.Client, - ApiSettings: s.ApiSettings, + Config: cfg, + Parameters: params, manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -102,21 +101,22 @@ var _ 
tools.Tool = Tool{} type Tool struct { Config - UseClientOAuth bool - AuthTokenHeaderName string - Client *v4.LookerSDK - ApiSettings *rtl.ApiSettings - Parameters parameters.Parameters `yaml:"parameters"` - manifest tools.Manifest - mcpManifest tools.McpManifest + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - sdk, err := lookercommon.GetLookerSDK(t.UseClientOAuth, t.ApiSettings, t.Client, accessToken) +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + + sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken) if err != nil { return nil, fmt.Errorf("error getting sdk: %w", err) } @@ -140,7 +140,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT Content: fileContent, } - err = lookercommon.UpdateProjectFile(sdk, projectId, req, t.ApiSettings) + err = lookercommon.UpdateProjectFile(sdk, projectId, req, source.LookerApiSettings()) if err != nil { return nil, fmt.Errorf("error making update_project_file request: %s", err) } @@ -168,10 +168,18 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return t.UseClientOAuth +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return false, err + } + return source.UseClientAuthorization(), nil } -func (t Tool) GetAuthTokenHeaderName() string { - return t.AuthTokenHeaderName +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return "", err + } + return source.GetAuthTokenHeaderName(), nil } diff --git a/internal/tools/mindsdb/mindsdbexecutesql/mindsdbexecutesql.go b/internal/tools/mindsdb/mindsdbexecutesql/mindsdbexecutesql.go index b7dbcb7823..51f2952177 100644 --- a/internal/tools/mindsdb/mindsdbexecutesql/mindsdbexecutesql.go +++ b/internal/tools/mindsdb/mindsdbexecutesql/mindsdbexecutesql.go @@ -21,7 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/mindsdb" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -47,11 +46,6 @@ type compatibleSource interface { MindsDBPool() *sql.DB } -// validate compatible sources are still compatible -var _ compatibleSource = &mindsdb.Source{} - -var compatibleSources = [...]string{mindsdb.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -68,18 +62,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs 
map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - sqlParameter := parameters.NewStringParameter("sql", "The sql to execute.") params := parameters.Parameters{sqlParameter} @@ -94,7 +76,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, Parameters: params, - Pool: s.MindsDBPool(), manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -106,9 +87,7 @@ var _ tools.Tool = Tool{} type Tool struct { Config - Parameters parameters.Parameters `yaml:"parameters"` - - Pool *sql.DB + Parameters parameters.Parameters `yaml:"parameters"` manifest tools.Manifest mcpManifest tools.McpManifest } @@ -117,14 +96,19 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() sql, ok := paramsMap["sql"].(string) if !ok { return nil, fmt.Errorf("unable to get cast %s", paramsMap["sql"]) } - results, err := t.Pool.QueryContext(ctx, sql) + results, err := source.MindsDBPool().QueryContext(ctx, sql) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -193,10 +177,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/mindsdb/mindsdbsql/mindsdbsql.go b/internal/tools/mindsdb/mindsdbsql/mindsdbsql.go index 58adaf433d..c247f4d4dc 100644 --- a/internal/tools/mindsdb/mindsdbsql/mindsdbsql.go +++ b/internal/tools/mindsdb/mindsdbsql/mindsdbsql.go @@ -21,7 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/mindsdb" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -47,11 +46,6 @@ type compatibleSource interface { MindsDBPool() *sql.DB } -// validate compatible sources are still compatible -var _ compatibleSource = &mindsdb.Source{} - -var compatibleSources = [...]string{mindsdb.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -71,18 +65,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, 
error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters) if err != nil { return nil, err @@ -100,7 +82,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, AllParams: allParameters, - Pool: s.MindsDBPool(), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -112,14 +93,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters `yaml:"allParams"` - - Pool *sql.DB + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { @@ -134,7 +118,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT sliceParams := newParams.AsSlice() // MindsDB now supports MySQL prepared statements natively - results, err := t.Pool.QueryContext(ctx, newStatement, sliceParams...) + results, err := source.MindsDBPool().QueryContext(ctx, newStatement, sliceParams...) 
if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -203,14 +187,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/mongodb/mongodbaggregate/mongodbaggregate.go b/internal/tools/mongodb/mongodbaggregate/mongodbaggregate.go index c5ade9d127..ccf7655ca3 100644 --- a/internal/tools/mongodb/mongodbaggregate/mongodbaggregate.go +++ b/internal/tools/mongodb/mongodbaggregate/mongodbaggregate.go @@ -20,7 +20,6 @@ import ( "slices" "github.com/goccy/go-yaml" - mongosrc "github.com/googleapis/genai-toolbox/internal/sources/mongodb" "github.com/googleapis/genai-toolbox/internal/util/parameters" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" @@ -45,6 +44,10 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + MongoClient() *mongo.Client +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -67,18 +70,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*mongosrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `mongodb`", kind) - } - // Create a slice for all parameters allParameters := slices.Concat(cfg.PipelineParams) @@ -96,7 +87,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, AllParams: allParameters, - database: s.Client.Database(cfg.Database), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, }, nil @@ -107,14 +97,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters `yaml:"allParams"` - - database *mongo.Database + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() pipelineString, err := parameters.PopulateTemplateWithJSON("MongoDBAggregatePipeline", t.PipelinePayload, paramsMap) @@ -139,7 +132,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } } - cur, err := t.database.Collection(t.Collection).Aggregate(ctx, pipeline) + cur, err := source.MongoClient().Database(t.Database).Collection(t.Collection).Aggregate(ctx, 
pipeline) if err != nil { return nil, err } @@ -185,14 +178,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/mongodb/mongodbdeletemany/mongodbdeletemany.go b/internal/tools/mongodb/mongodbdeletemany/mongodbdeletemany.go index e089b35ee0..566113b34b 100644 --- a/internal/tools/mongodb/mongodbdeletemany/mongodbdeletemany.go +++ b/internal/tools/mongodb/mongodbdeletemany/mongodbdeletemany.go @@ -20,7 +20,6 @@ import ( "slices" "github.com/goccy/go-yaml" - mongosrc "github.com/googleapis/genai-toolbox/internal/sources/mongodb" "github.com/googleapis/genai-toolbox/internal/util/parameters" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" @@ -46,6 +45,10 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + MongoClient() *mongo.Client +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -66,18 +69,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*mongosrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `mongodb`", kind) - } - // Create a slice for all parameters allParameters := slices.Concat(cfg.FilterParams) @@ -101,7 +92,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, AllParams: allParameters, - database: s.Client.Database(cfg.Database), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, }, nil @@ -112,14 +102,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters `yaml:"allParams"` - - database *mongo.Database + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() filterString, err := parameters.PopulateTemplateWithJSON("MongoDBDeleteManyFilter", t.FilterPayload, paramsMap) @@ -135,7 +128,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, err } - res, err := t.database.Collection(t.Collection).DeleteMany(ctx, filter, opts) + res, err := source.MongoClient().Database(t.Database).Collection(t.Collection).DeleteMany(ctx, filter, 
opts) if err != nil { return nil, err } @@ -164,14 +157,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/mongodb/mongodbdeleteone/mongodbdeleteone.go b/internal/tools/mongodb/mongodbdeleteone/mongodbdeleteone.go index 7ac903ff2f..6d16e5df70 100644 --- a/internal/tools/mongodb/mongodbdeleteone/mongodbdeleteone.go +++ b/internal/tools/mongodb/mongodbdeleteone/mongodbdeleteone.go @@ -19,7 +19,6 @@ import ( "slices" "github.com/goccy/go-yaml" - mongosrc "github.com/googleapis/genai-toolbox/internal/sources/mongodb" "github.com/googleapis/genai-toolbox/internal/util/parameters" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" @@ -45,6 +44,10 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + MongoClient() *mongo.Client +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -65,18 +68,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*mongosrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `mongodb`", kind) - } - // Create a slice for all parameters allParameters := slices.Concat(cfg.FilterParams) @@ -100,7 +91,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, AllParams: allParameters, - database: s.Client.Database(cfg.Database), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, }, nil @@ -111,14 +101,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters `yaml:"allParams"` - - database *mongo.Database + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() filterString, err := parameters.PopulateTemplateWithJSON("MongoDBDeleteOneFilter", t.FilterPayload, paramsMap) @@ -134,7 +127,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, err } - res, err := t.database.Collection(t.Collection).DeleteOne(ctx, filter, opts) + res, err := source.MongoClient().Database(t.Database).Collection(t.Collection).DeleteOne(ctx, filter, opts) if err != nil 
{ return nil, err } @@ -159,14 +152,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/mongodb/mongodbfind/mongodbfind.go b/internal/tools/mongodb/mongodbfind/mongodbfind.go index 4aeebf1652..88f3b25488 100644 --- a/internal/tools/mongodb/mongodbfind/mongodbfind.go +++ b/internal/tools/mongodb/mongodbfind/mongodbfind.go @@ -20,7 +20,6 @@ import ( "slices" "github.com/goccy/go-yaml" - mongosrc "github.com/googleapis/genai-toolbox/internal/sources/mongodb" "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "go.mongodb.org/mongo-driver/bson" @@ -47,6 +46,10 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + MongoClient() *mongo.Client +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -72,18 +75,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*mongosrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `mongodb`", kind) - } - // Create a slice for all parameters allParameters := slices.Concat(cfg.FilterParams, cfg.ProjectParams, cfg.SortParams) @@ -111,7 +102,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, AllParams: allParameters, - database: s.Client.Database(cfg.Database), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, }, nil @@ -122,9 +112,7 @@ var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters `yaml:"allParams"` - - database *mongo.Database + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } @@ -168,7 +156,12 @@ func getOptions(ctx context.Context, sortParameters parameters.Parameters, proje return opts, nil } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() filterString, err := parameters.PopulateTemplateWithJSON("MongoDBFindFilterString", t.FilterPayload, paramsMap) @@ -188,7 +181,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, err } - cur, err := t.database.Collection(t.Collection).Find(ctx, filter, opts) + cur, err := 
source.MongoClient().Database(t.Database).Collection(t.Collection).Find(ctx, filter, opts) if err != nil { return nil, err } @@ -230,14 +223,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/mongodb/mongodbfindone/mongodbfindone.go b/internal/tools/mongodb/mongodbfindone/mongodbfindone.go index bba70fe976..2e01d8e644 100644 --- a/internal/tools/mongodb/mongodbfindone/mongodbfindone.go +++ b/internal/tools/mongodb/mongodbfindone/mongodbfindone.go @@ -20,7 +20,6 @@ import ( "slices" "github.com/goccy/go-yaml" - mongosrc "github.com/googleapis/genai-toolbox/internal/sources/mongodb" "github.com/googleapis/genai-toolbox/internal/util/parameters" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" @@ -46,6 +45,10 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + MongoClient() *mongo.Client +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -58,8 +61,6 @@ type Config struct { FilterParams parameters.Parameters `yaml:"filterParams"` ProjectPayload string `yaml:"projectPayload"` ProjectParams parameters.Parameters `yaml:"projectParams"` - SortPayload string `yaml:"sortPayload"` - SortParams parameters.Parameters `yaml:"sortParams"` } // validate interface @@ -70,20 +71,8 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*mongosrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `mongodb`", kind) - } - // Create a slice for all parameters - allParameters := slices.Concat(cfg.FilterParams, cfg.ProjectParams, cfg.SortParams) + allParameters := slices.Concat(cfg.FilterParams, cfg.ProjectParams) // Verify no duplicate parameter names err := parameters.CheckDuplicateParameters(allParameters) @@ -105,7 +94,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, AllParams: allParameters, - database: s.Client.Database(cfg.Database), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, }, nil @@ -116,42 +104,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters `yaml:"allParams"` - - database *mongo.Database + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } -func getOptions(sortParameters parameters.Parameters, projectPayload string, paramsMap map[string]any) (*options.FindOneOptions, error) { - opts := options.FindOne() - - sort := bson.M{} - for _, p := range sortParameters { - sort[p.GetName()] = paramsMap[p.GetName()] - } - opts 
= opts.SetSort(sort) - - if len(projectPayload) == 0 { - return opts, nil - } - - result, err := parameters.PopulateTemplateWithJSON("MongoDBFindOneProjectString", projectPayload, paramsMap) +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) if err != nil { - return nil, fmt.Errorf("error populating project payload: %s", err) + return nil, err } - var projection any - err = bson.UnmarshalExtJSON([]byte(result), false, &projection) - if err != nil { - return nil, fmt.Errorf("error unmarshalling projection: %s", err) - } - opts = opts.SetProjection(projection) - - return opts, nil -} - -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { paramsMap := params.AsMap() filterString, err := parameters.PopulateTemplateWithJSON("MongoDBFindOneFilterString", t.FilterPayload, paramsMap) @@ -160,9 +123,18 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("error populating filter: %s", err) } - opts, err := getOptions(t.SortParams, t.ProjectPayload, paramsMap) - if err != nil { - return nil, fmt.Errorf("error populating options: %s", err) + opts := options.FindOne() + if len(t.ProjectPayload) > 0 { + result, err := parameters.PopulateTemplateWithJSON("MongoDBFindOneProjectString", t.ProjectPayload, paramsMap) + if err != nil { + return nil, fmt.Errorf("error populating project payload: %s", err) + } + var projection any + err = bson.UnmarshalExtJSON([]byte(result), false, &projection) + if err != nil { + return nil, fmt.Errorf("error unmarshalling projection: %s", err) + } + opts = opts.SetProjection(projection) } var filter = bson.D{} @@ -171,7 +143,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, err } - res := t.database.Collection(t.Collection).FindOne(ctx, filter, opts) + res := source.MongoClient().Database(t.Database).Collection(t.Collection).FindOne(ctx, filter, opts) if res.Err() != nil { return nil, res.Err() } @@ -210,14 +182,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/mongodb/mongodbfindone/mongodbfindone_test.go b/internal/tools/mongodb/mongodbfindone/mongodbfindone_test.go index 2eccffd835..a8d5b9bfc0 100644 --- a/internal/tools/mongodb/mongodbfindone/mongodbfindone_test.go +++ b/internal/tools/mongodb/mongodbfindone/mongodbfindone_test.go @@ -56,9 +56,6 @@ func TestParseFromYamlMongoQuery(t *testing.T) { projectPayload: | { name: 1, age: 1 } projectParams: [] - sortPayload: | - { timestamp: -1 } - sortParams: [] `, want: server.ToolConfigs{ "example_tool": mongodbfindone.Config{ @@ -81,8 +78,6 @@ func TestParseFromYamlMongoQuery(t *testing.T) { }, ProjectPayload: "{ name: 1, age: 1 }\n", ProjectParams: parameters.Parameters{}, - SortPayload: "{ timestamp: -1 }\n", - 
SortParams: parameters.Parameters{}, }, }, }, diff --git a/internal/tools/mongodb/mongodbinsertmany/mongodbinsertmany.go b/internal/tools/mongodb/mongodbinsertmany/mongodbinsertmany.go index bf4c4512bd..f0cbf29d1d 100644 --- a/internal/tools/mongodb/mongodbinsertmany/mongodbinsertmany.go +++ b/internal/tools/mongodb/mongodbinsertmany/mongodbinsertmany.go @@ -20,7 +20,6 @@ import ( "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - mongosrc "github.com/googleapis/genai-toolbox/internal/sources/mongodb" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "go.mongodb.org/mongo-driver/bson" @@ -46,6 +45,10 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + MongoClient() *mongo.Client +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -54,7 +57,7 @@ type Config struct { Description string `yaml:"description" validate:"required"` Database string `yaml:"database" validate:"required"` Collection string `yaml:"collection" validate:"required"` - Canonical bool `yaml:"canonical" validate:"required"` //i want to force the user to choose + Canonical bool `yaml:"canonical"` } // validate interface @@ -65,18 +68,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*mongosrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `mongodb`", kind) - } - dataParam := parameters.NewStringParameterWithRequired(paramDataKey, "the JSON payload to insert, should be a JSON array of documents", true) allParameters := parameters.Parameters{dataParam} @@ -94,7 +85,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, PayloadParams: allParameters, - database: s.Client.Database(cfg.Database), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, }, nil @@ -106,31 +96,34 @@ var _ tools.Tool = Tool{} type Tool struct { Config PayloadParams parameters.Parameters - - database *mongo.Database - manifest tools.Manifest - mcpManifest tools.McpManifest + manifest tools.Manifest + mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + if len(params) == 0 { return nil, errors.New("no input found") } paramsMap := params.AsMap() - var jsonData, ok = paramsMap[paramDataKey].(string) + jsonData, ok := paramsMap[paramDataKey].(string) if !ok { return nil, errors.New("no input found") } var data = []any{} - err := bson.UnmarshalExtJSON([]byte(jsonData), t.Canonical, &data) + err = bson.UnmarshalExtJSON([]byte(jsonData), t.Canonical, &data) if err != nil { return nil, err } - res, err := t.database.Collection(t.Collection).InsertMany(ctx, 
data, options.InsertMany()) + res, err := source.MongoClient().Database(t.Database).Collection(t.Collection).InsertMany(ctx, data, options.InsertMany()) if err != nil { return nil, err } @@ -154,14 +147,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/mongodb/mongodbinsertmany/mongodbinsertmany_test.go b/internal/tools/mongodb/mongodbinsertmany/mongodbinsertmany_test.go index 2d8a1cdb30..19ac3ce0c1 100644 --- a/internal/tools/mongodb/mongodbinsertmany/mongodbinsertmany_test.go +++ b/internal/tools/mongodb/mongodbinsertmany/mongodbinsertmany_test.go @@ -39,6 +39,30 @@ func TestParseFromYamlMongoQuery(t *testing.T) { { desc: "basic example", in: ` + tools: + example_tool: + kind: mongodb-insert-many + source: my-instance + description: some description + database: test_db + collection: test_coll + `, + want: server.ToolConfigs{ + "example_tool": mongodbinsertmany.Config{ + Name: "example_tool", + Kind: "mongodb-insert-many", + Source: "my-instance", + AuthRequired: []string{}, + Database: "test_db", + Collection: "test_coll", + Description: "some description", + Canonical: false, + }, + }, + }, + { + desc: "true canonical", + in: ` tools: example_tool: kind: mongodb-insert-many @@ -61,6 +85,31 @@ func TestParseFromYamlMongoQuery(t *testing.T) { }, }, }, + { + desc: "false canonical", + in: ` + tools: + example_tool: + kind: mongodb-insert-many + source: my-instance + description: some description + database: test_db + collection: test_coll + canonical: false + `, + want: server.ToolConfigs{ + "example_tool": mongodbinsertmany.Config{ + Name: "example_tool", + Kind: "mongodb-insert-many", + Source: "my-instance", + AuthRequired: []string{}, + Database: "test_db", + Collection: "test_coll", + Description: "some description", + Canonical: false, + }, + }, + }, } for _, tc := range tcs { t.Run(tc.desc, func(t *testing.T) { diff --git a/internal/tools/mongodb/mongodbinsertone/mongodbinsertone.go b/internal/tools/mongodb/mongodbinsertone/mongodbinsertone.go index 8939921995..037a01dda7 100644 --- a/internal/tools/mongodb/mongodbinsertone/mongodbinsertone.go +++ b/internal/tools/mongodb/mongodbinsertone/mongodbinsertone.go @@ -20,7 +20,6 @@ import ( "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - mongosrc "github.com/googleapis/genai-toolbox/internal/sources/mongodb" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "go.mongodb.org/mongo-driver/bson" @@ -46,6 +45,10 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + MongoClient() *mongo.Client +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -54,7 +57,7 @@ type Config struct { Description string `yaml:"description" validate:"required"` Database string `yaml:"database" validate:"required"` Collection string `yaml:"collection" validate:"required"` 
- Canonical bool `yaml:"canonical" validate:"required"` //i want to force the user to choose + Canonical bool `yaml:"canonical"` } // validate interface @@ -65,18 +68,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*mongosrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `mongodb`", kind) - } - payloadParams := parameters.NewStringParameterWithRequired(dataParamsKey, "the JSON payload to insert, should be a JSON object", true) allParameters := parameters.Parameters{payloadParams} @@ -95,7 +86,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, PayloadParams: allParameters, - database: s.Client.Database(cfg.Database), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, }, nil @@ -107,29 +97,32 @@ var _ tools.Tool = Tool{} type Tool struct { Config PayloadParams parameters.Parameters `yaml:"payloadParams" validate:"required"` - - database *mongo.Database - manifest tools.Manifest - mcpManifest tools.McpManifest + manifest tools.Manifest + mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + if len(params) == 0 { return nil, errors.New("no input found") } // use the first, assume it's a string - var jsonData, ok = params[0].Value.(string) + jsonData, ok := params[0].Value.(string) if !ok { return nil, errors.New("no input found") } var data any - err := bson.UnmarshalExtJSON([]byte(jsonData), t.Canonical, &data) + err = bson.UnmarshalExtJSON([]byte(jsonData), t.Canonical, &data) if err != nil { return nil, err } - res, err := t.database.Collection(t.Collection).InsertOne(ctx, data, options.InsertOne()) + res, err := source.MongoClient().Database(t.Database).Collection(t.Collection).InsertOne(ctx, data, options.InsertOne()) if err != nil { return nil, err } @@ -153,14 +146,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/mongodb/mongodbinsertone/mongodbinsertone_test.go b/internal/tools/mongodb/mongodbinsertone/mongodbinsertone_test.go index 2e4563efd4..a61dde20b0 100644 --- a/internal/tools/mongodb/mongodbinsertone/mongodbinsertone_test.go +++ b/internal/tools/mongodb/mongodbinsertone/mongodbinsertone_test.go @@ -39,6 +39,30 @@ func TestParseFromYamlMongoQuery(t *testing.T) { { desc: "basic 
example", in: ` + tools: + example_tool: + kind: mongodb-insert-one + source: my-instance + description: some description + database: test_db + collection: test_coll + `, + want: server.ToolConfigs{ + "example_tool": mongodbinsertone.Config{ + Name: "example_tool", + Kind: "mongodb-insert-one", + Source: "my-instance", + AuthRequired: []string{}, + Database: "test_db", + Collection: "test_coll", + Canonical: false, + Description: "some description", + }, + }, + }, + { + desc: "true canonical", + in: ` tools: example_tool: kind: mongodb-insert-one @@ -61,6 +85,31 @@ func TestParseFromYamlMongoQuery(t *testing.T) { }, }, }, + { + desc: "false canonical", + in: ` + tools: + example_tool: + kind: mongodb-insert-one + source: my-instance + description: some description + database: test_db + collection: test_coll + canonical: false + `, + want: server.ToolConfigs{ + "example_tool": mongodbinsertone.Config{ + Name: "example_tool", + Kind: "mongodb-insert-one", + Source: "my-instance", + AuthRequired: []string{}, + Database: "test_db", + Collection: "test_coll", + Canonical: false, + Description: "some description", + }, + }, + }, } for _, tc := range tcs { t.Run(tc.desc, func(t *testing.T) { diff --git a/internal/tools/mongodb/mongodbupdatemany/mongodbupdatemany.go b/internal/tools/mongodb/mongodbupdatemany/mongodbupdatemany.go index 31c91368ca..1d38f1ff26 100644 --- a/internal/tools/mongodb/mongodbupdatemany/mongodbupdatemany.go +++ b/internal/tools/mongodb/mongodbupdatemany/mongodbupdatemany.go @@ -20,7 +20,6 @@ import ( "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - mongosrc "github.com/googleapis/genai-toolbox/internal/sources/mongodb" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "go.mongodb.org/mongo-driver/bson" @@ -44,6 +43,10 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + MongoClient() *mongo.Client +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -56,7 +59,7 @@ type Config struct { FilterParams parameters.Parameters `yaml:"filterParams"` UpdatePayload string `yaml:"updatePayload" validate:"required"` UpdateParams parameters.Parameters `yaml:"updateParams" validate:"required"` - Canonical bool `yaml:"canonical" validate:"required"` + Canonical bool `yaml:"canonical"` Upsert bool `yaml:"upsert"` } @@ -68,18 +71,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*mongosrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `mongodb`", kind) - } - // Create a slice for all parameters allParameters := slices.Concat(cfg.FilterParams, cfg.UpdateParams) @@ -103,7 +94,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, AllParams: allParameters, - database: s.Client.Database(cfg.Database), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, }, nil @@ -114,14 +104,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters 
`yaml:"allParams"` - - database *mongo.Database + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() filterString, err := parameters.PopulateTemplateWithJSON("MongoDBUpdateManyFilter", t.FilterPayload, paramsMap) @@ -146,7 +139,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("unable to unmarshal update string: %w", err) } - res, err := t.database.Collection(t.Collection).UpdateMany(ctx, filter, update, options.Update().SetUpsert(t.Upsert)) + res, err := source.MongoClient().Database(t.Database).Collection(t.Collection).UpdateMany(ctx, filter, update, options.Update().SetUpsert(t.Upsert)) if err != nil { return nil, fmt.Errorf("error updating collection: %w", err) } @@ -170,14 +163,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/mongodb/mongodbupdatemany/mongodbupdatemany_test.go b/internal/tools/mongodb/mongodbupdatemany/mongodbupdatemany_test.go index 209fffc565..b7353afe08 100644 --- a/internal/tools/mongodb/mongodbupdatemany/mongodbupdatemany_test.go +++ b/internal/tools/mongodb/mongodbupdatemany/mongodbupdatemany_test.go @@ -40,6 +40,62 @@ func TestParseFromYamlMongoQuery(t *testing.T) { { desc: "basic example", in: ` + tools: + example_tool: + kind: mongodb-update-many + source: my-instance + description: some description + database: test_db + collection: test_coll + filterPayload: | + { name: {{json .name}} } + filterParams: + - name: name + type: string + description: small description + updatePayload: | + { $set: { name: {{json .name}} } } + updateParams: + - name: name + type: string + description: small description + `, + want: server.ToolConfigs{ + "example_tool": mongodbupdatemany.Config{ + Name: "example_tool", + Kind: "mongodb-update-many", + Source: "my-instance", + AuthRequired: []string{}, + Database: "test_db", + Collection: "test_coll", + FilterPayload: "{ name: {{json .name}} }\n", + FilterParams: parameters.Parameters{ + ¶meters.StringParameter{ + CommonParameter: parameters.CommonParameter{ + Name: "name", + Type: "string", + Desc: "small description", + }, + }, + }, + UpdatePayload: "{ $set: { name: {{json .name}} } }\n", + UpdateParams: parameters.Parameters{ + ¶meters.StringParameter{ + CommonParameter: parameters.CommonParameter{ + Name: "name", + Type: "string", + Desc: "small description", + }, + }, + }, + Description: "some description", + Canonical: false, + }, + }, + }, + { + desc: "true canonical", + in: ` tools: example_tool: kind: mongodb-update-many @@ -94,6 
+150,63 @@ func TestParseFromYamlMongoQuery(t *testing.T) { }, }, }, + { + desc: "false canonical", + in: ` + tools: + example_tool: + kind: mongodb-update-many + source: my-instance + description: some description + database: test_db + collection: test_coll + filterPayload: | + { name: {{json .name}} } + filterParams: + - name: name + type: string + description: small description + canonical: false + updatePayload: | + { $set: { name: {{json .name}} } } + updateParams: + - name: name + type: string + description: small description + `, + want: server.ToolConfigs{ + "example_tool": mongodbupdatemany.Config{ + Name: "example_tool", + Kind: "mongodb-update-many", + Source: "my-instance", + AuthRequired: []string{}, + Database: "test_db", + Collection: "test_coll", + FilterPayload: "{ name: {{json .name}} }\n", + FilterParams: parameters.Parameters{ + ¶meters.StringParameter{ + CommonParameter: parameters.CommonParameter{ + Name: "name", + Type: "string", + Desc: "small description", + }, + }, + }, + UpdatePayload: "{ $set: { name: {{json .name}} } }\n", + UpdateParams: parameters.Parameters{ + ¶meters.StringParameter{ + CommonParameter: parameters.CommonParameter{ + Name: "name", + Type: "string", + Desc: "small description", + }, + }, + }, + Description: "some description", + Canonical: false, + }, + }, + }, } for _, tc := range tcs { t.Run(tc.desc, func(t *testing.T) { diff --git a/internal/tools/mongodb/mongodbupdateone/mongodbupdateone.go b/internal/tools/mongodb/mongodbupdateone/mongodbupdateone.go index 4615c71be6..397b521198 100644 --- a/internal/tools/mongodb/mongodbupdateone/mongodbupdateone.go +++ b/internal/tools/mongodb/mongodbupdateone/mongodbupdateone.go @@ -20,7 +20,6 @@ import ( "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - mongosrc "github.com/googleapis/genai-toolbox/internal/sources/mongodb" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "go.mongodb.org/mongo-driver/bson" @@ -44,6 +43,10 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + MongoClient() *mongo.Client +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -57,7 +60,7 @@ type Config struct { UpdatePayload string `yaml:"updatePayload" validate:"required"` UpdateParams parameters.Parameters `yaml:"updateParams" validate:"required"` - Canonical bool `yaml:"canonical" validate:"required"` + Canonical bool `yaml:"canonical"` Upsert bool `yaml:"upsert"` } @@ -69,18 +72,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(*mongosrc.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `mongodb`", kind) - } - // Create a slice for all parameters allParameters := slices.Concat(cfg.FilterParams, cfg.UpdateParams) @@ -104,7 +95,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, AllParams: allParameters, - database: s.Client.Database(cfg.Database), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, }, 
nil @@ -115,14 +105,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters - - database *mongo.Database + AllParams parameters.Parameters manifest tools.Manifest mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() filterString, err := parameters.PopulateTemplateWithJSON("MongoDBUpdateOneFilter", t.FilterPayload, paramsMap) @@ -147,7 +140,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("unable to unmarshal update string: %w", err) } - res, err := t.database.Collection(t.Collection).UpdateOne(ctx, filter, update, options.Update().SetUpsert(t.Upsert)) + res, err := source.MongoClient().Database(t.Database).Collection(t.Collection).UpdateOne(ctx, filter, update, options.Update().SetUpsert(t.Upsert)) if err != nil { return nil, fmt.Errorf("error updating collection: %w", err) } @@ -171,14 +164,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/mongodb/mongodbupdateone/mongodbupdateone_test.go b/internal/tools/mongodb/mongodbupdateone/mongodbupdateone_test.go index 4f0398bd1e..b450892fb2 100644 --- a/internal/tools/mongodb/mongodbupdateone/mongodbupdateone_test.go +++ b/internal/tools/mongodb/mongodbupdateone/mongodbupdateone_test.go @@ -40,6 +40,123 @@ func TestParseFromYamlMongoQuery(t *testing.T) { { desc: "basic example", in: ` + tools: + example_tool: + kind: mongodb-update-one + source: my-instance + description: some description + database: test_db + collection: test_coll + filterPayload: | + { name: {{json .name}} } + filterParams: + - name: name + type: string + description: small description + updatePayload: | + { $set : { item: {{json .item}} } } + updateParams: + - name: item + type: string + description: small description + upsert: true + `, + want: server.ToolConfigs{ + "example_tool": mongodbupdateone.Config{ + Name: "example_tool", + Kind: "mongodb-update-one", + Source: "my-instance", + AuthRequired: []string{}, + Database: "test_db", + Collection: "test_coll", + Canonical: false, + FilterPayload: "{ name: {{json .name}} }\n", + FilterParams: parameters.Parameters{ + ¶meters.StringParameter{ + CommonParameter: parameters.CommonParameter{ + Name: "name", + Type: "string", + Desc: "small description", + }, + }, + }, + UpdatePayload: "{ $set : { item: {{json .item}} } }\n", + UpdateParams: parameters.Parameters{ + ¶meters.StringParameter{ + CommonParameter: parameters.CommonParameter{ + Name: "item", + Type: "string", + Desc: "small description", + }, + }, + }, + Upsert: true, + Description: "some description", + }, + }, + }, + 
{ + desc: "false canonical", + in: ` + tools: + example_tool: + kind: mongodb-update-one + source: my-instance + description: some description + database: test_db + collection: test_coll + filterPayload: | + { name: {{json .name}} } + filterParams: + - name: name + type: string + description: small description + updatePayload: | + { $set : { item: {{json .item}} } } + updateParams: + - name: item + type: string + description: small description + canonical: false + upsert: true + `, + want: server.ToolConfigs{ + "example_tool": mongodbupdateone.Config{ + Name: "example_tool", + Kind: "mongodb-update-one", + Source: "my-instance", + AuthRequired: []string{}, + Database: "test_db", + Collection: "test_coll", + Canonical: false, + FilterPayload: "{ name: {{json .name}} }\n", + FilterParams: parameters.Parameters{ + ¶meters.StringParameter{ + CommonParameter: parameters.CommonParameter{ + Name: "name", + Type: "string", + Desc: "small description", + }, + }, + }, + UpdatePayload: "{ $set : { item: {{json .item}} } }\n", + UpdateParams: parameters.Parameters{ + ¶meters.StringParameter{ + CommonParameter: parameters.CommonParameter{ + Name: "item", + Type: "string", + Desc: "small description", + }, + }, + }, + Upsert: true, + Description: "some description", + }, + }, + }, + { + desc: "true canonical", + in: ` tools: example_tool: kind: mongodb-update-one diff --git a/internal/tools/mssql/mssqlexecutesql/mssqlexecutesql.go b/internal/tools/mssql/mssqlexecutesql/mssqlexecutesql.go index bd9bd65049..ddfbdb089e 100644 --- a/internal/tools/mssql/mssqlexecutesql/mssqlexecutesql.go +++ b/internal/tools/mssql/mssqlexecutesql/mssqlexecutesql.go @@ -21,8 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmssql" - "github.com/googleapis/genai-toolbox/internal/sources/mssql" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/orderedmap" @@ -49,12 +47,6 @@ type compatibleSource interface { MSSQLDB() *sql.DB } -// validate compatible sources are still compatible -var _ compatibleSource = &cloudsqlmssql.Source{} -var _ compatibleSource = &mssql.Source{} - -var compatibleSources = [...]string{cloudsqlmssql.SourceKind, mssql.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -71,18 +63,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - sqlParameter := parameters.NewStringParameter("sql", "The sql to execute.") params := parameters.Parameters{sqlParameter} @@ -92,7 +72,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, Parameters: params, - Pool: s.MSSQLDB(), manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -104,14 +83,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - Parameters parameters.Parameters `yaml:"parameters"` - - Pool 
*sql.DB + Parameters parameters.Parameters `yaml:"parameters"` manifest tools.Manifest mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() sql, ok := paramsMap["sql"].(string) if !ok { @@ -125,7 +107,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sql)) - results, err := t.Pool.QueryContext(ctx, sql) + results, err := source.MSSQLDB().QueryContext(ctx, sql) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -183,14 +165,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/mssql/mssqllisttables/mssqllisttables.go b/internal/tools/mssql/mssqllisttables/mssqllisttables.go index c838ea4fe9..29fbea4498 100644 --- a/internal/tools/mssql/mssqllisttables/mssqllisttables.go +++ b/internal/tools/mssql/mssqllisttables/mssqllisttables.go @@ -21,8 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmssql" - "github.com/googleapis/genai-toolbox/internal/sources/mssql" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -296,12 +294,6 @@ type compatibleSource interface { MSSQLDB() *sql.DB } -// validate compatible sources are still compatible -var _ compatibleSource = &cloudsqlmssql.Source{} -var _ compatibleSource = &mssql.Source{} - -var compatibleSources = [...]string{cloudsqlmssql.SourceKind, mssql.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -318,18 +310,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters := parameters.Parameters{ parameters.NewStringParameterWithDefault("table_names", "", "Optional: A comma-separated list of table names. 
If empty, details for all tables will be listed."), parameters.NewStringParameterWithDefault("output_format", "detailed", "Optional: Use 'simple' for names only or 'detailed' for full info."), @@ -341,7 +321,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, AllParams: allParameters, - Db: s.MSSQLDB(), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -353,14 +332,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters `yaml:"allParams"` - - Db *sql.DB + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() outputFormat, _ := paramsMap["output_format"].(string) @@ -373,7 +355,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT sql.Named("output_format", outputFormat), } - rows, err := t.Db.QueryContext(ctx, listTablesStatement, namedArgs...) + rows, err := source.MSSQLDB().QueryContext(ctx, listTablesStatement, namedArgs...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -391,7 +373,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT values[i] = &rawValues[i] } - var out []any + out := []any{} for rows.Next() { err = rows.Scan(values...) 
if err != nil { @@ -428,14 +410,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/mssql/mssqlsql/mssqlsql.go b/internal/tools/mssql/mssqlsql/mssqlsql.go index 1aff17e8f4..0e621b7417 100644 --- a/internal/tools/mssql/mssqlsql/mssqlsql.go +++ b/internal/tools/mssql/mssqlsql/mssqlsql.go @@ -22,8 +22,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmssql" - "github.com/googleapis/genai-toolbox/internal/sources/mssql" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -48,12 +46,6 @@ type compatibleSource interface { MSSQLDB() *sql.DB } -// validate compatible sources are still compatible -var _ compatibleSource = &cloudsqlmssql.Source{} -var _ compatibleSource = &mssql.Source{} - -var compatibleSources = [...]string{cloudsqlmssql.SourceKind, mssql.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -73,18 +65,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters) if err != nil { return nil, err @@ -96,7 +76,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, AllParams: allParameters, - Db: s.MSSQLDB(), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -108,14 +87,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters `yaml:"allParams"` - - Db *sql.DB + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { @@ -140,7 +122,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } } - rows, err := t.Db.QueryContext(ctx, newStatement, namedArgs...) 
+ rows, err := source.MSSQLDB().QueryContext(ctx, newStatement, namedArgs...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -198,14 +180,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/mysql/mysqlexecutesql/mysqlexecutesql.go b/internal/tools/mysql/mysqlexecutesql/mysqlexecutesql.go index 6d0ec4a72a..5198602d70 100644 --- a/internal/tools/mysql/mysqlexecutesql/mysqlexecutesql.go +++ b/internal/tools/mysql/mysqlexecutesql/mysqlexecutesql.go @@ -21,9 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmysql" - "github.com/googleapis/genai-toolbox/internal/sources/mindsdb" - "github.com/googleapis/genai-toolbox/internal/sources/mysql" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -51,13 +48,6 @@ type compatibleSource interface { MySQLPool() *sql.DB } -// validate compatible sources are still compatible -var _ compatibleSource = &cloudsqlmysql.Source{} -var _ compatibleSource = &mysql.Source{} -var _ compatibleSource = &mindsdb.Source{} - -var compatibleSources = [...]string{cloudsqlmysql.SourceKind, mysql.SourceKind, mindsdb.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -74,18 +64,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - sqlParameter := parameters.NewStringParameter("sql", "The sql to execute.") params := parameters.Parameters{sqlParameter} @@ -95,7 +73,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, Parameters: params, - Pool: s.MySQLPool(), manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -107,14 +84,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - Parameters parameters.Parameters `yaml:"parameters"` - - Pool *sql.DB + Parameters parameters.Parameters `yaml:"parameters"` manifest tools.Manifest mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, 
t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() sql, ok := paramsMap["sql"].(string) if !ok { @@ -128,7 +108,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sql)) - results, err := t.Pool.QueryContext(ctx, sql) + results, err := source.MySQLPool().QueryContext(ctx, sql) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -197,14 +177,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(_ tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(_ tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/mysql/mysqlgetqueryplan/mysqlgetqueryplan.go b/internal/tools/mysql/mysqlgetqueryplan/mysqlgetqueryplan.go new file mode 100644 index 0000000000..3458a6ed83 --- /dev/null +++ b/internal/tools/mysql/mysqlgetqueryplan/mysqlgetqueryplan.go @@ -0,0 +1,164 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
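[Illustrative sketch, not part of this diff.] The new mysql-get-query-plan tool added below follows the same pattern applied across this change: tools no longer capture a *sql.DB (or *mongo.Database) at Initialize time; instead, Invoke receives a tools.SourceProvider and resolves a compatible source on each call. A minimal sketch of that pattern, using only names that appear in the surrounding hunks (imports omitted, query is a placeholder):

    type compatibleSource interface {
        MySQLPool() *sql.DB
    }

    func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
        // Look up the configured source lazily; this fails if the source is missing or not MySQL-compatible.
        source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
        if err != nil {
            return nil, err
        }
        // The pool is used only for the duration of this call; the Tool itself stays stateless.
        var one int
        if err := source.MySQLPool().QueryRowContext(ctx, "SELECT 1").Scan(&one); err != nil {
            return nil, err
        }
        return one, nil
    }
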
+ +package mysqlgetqueryplan + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + + yaml "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" + "github.com/googleapis/genai-toolbox/internal/util/parameters" +) + +const kind string = "mysql-get-query-plan" + +func init() { + if !tools.Register(kind, newConfig) { + panic(fmt.Sprintf("tool kind %q already registered", kind)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { + actual := Config{Name: name} + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + return actual, nil +} + +type compatibleSource interface { + MySQLPool() *sql.DB +} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description" validate:"required"` + AuthRequired []string `yaml:"authRequired"` +} + +// validate interface +var _ tools.ToolConfig = Config{} + +func (cfg Config) ToolConfigKind() string { + return kind +} + +func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { + sqlParameter := parameters.NewStringParameter("sql_statement", "The sql statement to explain.") + params := parameters.Parameters{sqlParameter} + + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil) + + // finish tool setup + t := Tool{ + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, + } + return t, nil +} + +// validate interface +var _ tools.Tool = Tool{} + +type Tool struct { + Config + Parameters parameters.Parameters `yaml:"parameters"` + manifest tools.Manifest + mcpManifest tools.McpManifest +} + +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + + paramsMap := params.AsMap() + sql, ok := paramsMap["sql_statement"].(string) + if !ok { + return nil, fmt.Errorf("unable to get cast %s", paramsMap["sql_statement"]) + } + + // Log the query executed for debugging. 
+ logger, err := util.LoggerFromContext(ctx) + if err != nil { + return nil, fmt.Errorf("error getting logger: %s", err) + } + logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sql)) + + query := fmt.Sprintf("EXPLAIN FORMAT=JSON %s", sql) + results, err := source.MySQLPool().QueryContext(ctx, query) + if err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + defer results.Close() + + var plan string + if results.Next() { + if err := results.Scan(&plan); err != nil { + return nil, fmt.Errorf("unable to parse row: %w", err) + } + } else { + return nil, fmt.Errorf("no query plan returned") + } + + if err := results.Err(); err != nil { + return nil, fmt.Errorf("errors encountered during row iteration: %w", err) + } + + var out any + if err := json.Unmarshal([]byte(plan), &out); err != nil { + return nil, fmt.Errorf("failed to unmarshal query plan json: %w", err) + } + + return out, nil +} + +func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { + return parameters.ParseParams(t.Parameters, data, claims) +} + +func (t Tool) Manifest() tools.Manifest { + return t.manifest +} + +func (t Tool) McpManifest() tools.McpManifest { + return t.mcpManifest +} + +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +} + +func (t Tool) RequiresClientAuthorization(_ tools.SourceProvider) (bool, error) { + return false, nil +} + +func (t Tool) ToConfig() tools.ToolConfig { + return t.Config +} + +func (t Tool) GetAuthTokenHeaderName(_ tools.SourceProvider) (string, error) { + return "Authorization", nil +} diff --git a/internal/tools/mysql/mysqlgetqueryplan/mysqlgetqueryplan_test.go b/internal/tools/mysql/mysqlgetqueryplan/mysqlgetqueryplan_test.go new file mode 100644 index 0000000000..b06248dbaf --- /dev/null +++ b/internal/tools/mysql/mysqlgetqueryplan/mysqlgetqueryplan_test.go @@ -0,0 +1,76 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
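[Usage sketch, not part of this diff; the tool name and description below are hypothetical.] Assuming a MySQL source named my-instance, the new tool could be declared in tools.yaml along the lines of the parser test that follows; the required sql_statement parameter is supplied by the caller at invocation time and is wrapped in EXPLAIN FORMAT=JSON before execution:

    tools:
      get_query_plan:
        kind: mysql-get-query-plan
        source: my-instance
        description: Returns the JSON query plan for the given SQL statement.
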
+ +package mysqlgetqueryplan_test + +import ( + "testing" + + yaml "github.com/goccy/go-yaml" + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlgetqueryplan" +) + +func TestParseFromYamlGetQueryPlan(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example", + in: ` + tools: + example_tool: + kind: mysql-get-query-plan + source: my-instance + description: some description + authRequired: + - my-google-auth-service + - other-auth-service + `, + want: server.ToolConfigs{ + "example_tool": mysqlgetqueryplan.Config{ + Name: "example_tool", + Kind: "mysql-get-query-plan", + Source: "my-instance", + Description: "some description", + AuthRequired: []string{"my-google-auth-service", "other-auth-service"}, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Tools server.ToolConfigs `yaml:"tools"` + }{} + // Parse contents + err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if diff := cmp.Diff(tc.want, got.Tools); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } + +} diff --git a/internal/tools/mysql/mysqllistactivequeries/mysqllistactivequeries.go b/internal/tools/mysql/mysqllistactivequeries/mysqllistactivequeries.go index 2a7373f231..323d582d32 100644 --- a/internal/tools/mysql/mysqllistactivequeries/mysqllistactivequeries.go +++ b/internal/tools/mysql/mysqllistactivequeries/mysqllistactivequeries.go @@ -111,12 +111,6 @@ type compatibleSource interface { MySQLPool() *sql.DB } -// validate compatible sources are still compatible -var _ compatibleSource = &mysql.Source{} -var _ compatibleSource = &cloudsqlmysql.Source{} - -var compatibleSources = [...]string{mysql.SourceKind, cloudsqlmysql.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -138,11 +132,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) if !ok { return nil, fmt.Errorf("no source named %q configured", cfg.Source) } - // verify the source is compatible - s, ok := rawS.(compatibleSource) + _, ok = rawS.(compatibleSource) if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source) } allParameters := parameters.Parameters{ @@ -165,7 +158,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ Config: cfg, - Pool: s.MySQLPool(), allParams: allParameters, manifest: tools.Manifest{Description: cfg.Description, Parameters: allParameters.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -180,13 +172,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config allParams parameters.Parameters `yaml:"parameters"` - Pool *sql.DB manifest tools.Manifest mcpManifest tools.McpManifest statement string } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params 
parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() duration, ok := paramsMap["min_duration_secs"].(int) @@ -205,7 +201,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, t.statement)) - results, err := t.Pool.QueryContext(ctx, t.statement, duration, duration, limit) + results, err := source.MySQLPool().QueryContext(ctx, t.statement, duration, duration, limit) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -273,14 +269,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/mysql/mysqllisttablefragmentation/mysqllisttablefragmentation.go b/internal/tools/mysql/mysqllisttablefragmentation/mysqllisttablefragmentation.go index e058669777..a0bc1b8f66 100644 --- a/internal/tools/mysql/mysqllisttablefragmentation/mysqllisttablefragmentation.go +++ b/internal/tools/mysql/mysqllisttablefragmentation/mysqllisttablefragmentation.go @@ -21,8 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmysql" - "github.com/googleapis/genai-toolbox/internal/sources/mysql" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -71,12 +69,6 @@ type compatibleSource interface { MySQLPool() *sql.DB } -// validate compatible sources are still compatible -var _ compatibleSource = &mysql.Source{} -var _ compatibleSource = &cloudsqlmysql.Source{} - -var compatibleSources = [...]string{mysql.SourceKind, cloudsqlmysql.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -93,18 +85,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters := parameters.Parameters{ parameters.NewStringParameterWithDefault("table_schema", "", "(Optional) The database where fragmentation check is to be executed. Check all tables visible to the current user if not specified"), parameters.NewStringParameterWithDefault("table_name", "", "(Optional) Name of the table to be checked. 
Check all tables visible to the current user if not specified."), @@ -116,7 +96,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ Config: cfg, - Pool: s.MySQLPool(), allParams: allParameters, manifest: tools.Manifest{Description: cfg.Description, Parameters: allParameters.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -130,12 +109,16 @@ var _ tools.Tool = Tool{} type Tool struct { Config allParams parameters.Parameters `yaml:"parameters"` - Pool *sql.DB manifest tools.Manifest mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() table_schema, ok := paramsMap["table_schema"].(string) @@ -162,7 +145,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, listTableFragmentationStatement)) - results, err := t.Pool.QueryContext(ctx, listTableFragmentationStatement, table_schema, table_schema, table_name, table_name, data_free_threshold_bytes, limit) + results, err := source.MySQLPool().QueryContext(ctx, listTableFragmentationStatement, table_schema, table_schema, table_name, table_name, data_free_threshold_bytes, limit) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -230,14 +213,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/mysql/mysqllisttables/mysqllisttables.go b/internal/tools/mysql/mysqllisttables/mysqllisttables.go index b2ffc60714..66928b75fa 100644 --- a/internal/tools/mysql/mysqllisttables/mysqllisttables.go +++ b/internal/tools/mysql/mysqllisttables/mysqllisttables.go @@ -21,8 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmysql" - "github.com/googleapis/genai-toolbox/internal/sources/mysql" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -201,12 +199,6 @@ type compatibleSource interface { MySQLPool() *sql.DB } -// validate compatible sources are still compatible -var _ compatibleSource = &cloudsqlmysql.Source{} -var _ compatibleSource = &mysql.Source{} - -var compatibleSources = [...]string{cloudsqlmysql.SourceKind, mysql.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -223,18 +215,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) 
Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters := parameters.Parameters{ parameters.NewStringParameterWithDefault("table_names", "", "Optional: A comma-separated list of table names. If empty, details for all tables will be listed."), parameters.NewStringParameterWithDefault("output_format", "detailed", "Optional: Use 'simple' for names only or 'detailed' for full info."), @@ -246,7 +226,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, AllParams: allParameters, - Pool: s.MySQLPool(), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -260,12 +239,16 @@ type Tool struct { Config AllParams parameters.Parameters `yaml:"allParams"` - Pool *sql.DB manifest tools.Manifest mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() tableNames, ok := paramsMap["table_names"].(string) @@ -277,7 +260,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("invalid value for output_format: must be 'simple' or 'detailed', but got %q", outputFormat) } - results, err := t.Pool.QueryContext(ctx, listTablesStatement, tableNames, outputFormat) + results, err := source.MySQLPool().QueryContext(ctx, listTablesStatement, tableNames, outputFormat) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -300,7 +283,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("unable to get column types: %w", err) } - var out []any + out := []any{} for results.Next() { err := results.Scan(values...) 
if err != nil { @@ -345,14 +328,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/mysql/mysqllisttablesmissinguniqueindexes/mysqllisttablesmissinguniqueindexes.go b/internal/tools/mysql/mysqllisttablesmissinguniqueindexes/mysqllisttablesmissinguniqueindexes.go index 4878f20028..522b180acd 100644 --- a/internal/tools/mysql/mysqllisttablesmissinguniqueindexes/mysqllisttablesmissinguniqueindexes.go +++ b/internal/tools/mysql/mysqllisttablesmissinguniqueindexes/mysqllisttablesmissinguniqueindexes.go @@ -21,8 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmysql" - "github.com/googleapis/genai-toolbox/internal/sources/mysql" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -72,12 +70,6 @@ type compatibleSource interface { MySQLPool() *sql.DB } -// validate compatible sources are still compatible -var _ compatibleSource = &mysql.Source{} -var _ compatibleSource = &cloudsqlmysql.Source{} - -var compatibleSources = [...]string{mysql.SourceKind, cloudsqlmysql.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -94,18 +86,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters := parameters.Parameters{ parameters.NewStringParameterWithDefault("table_schema", "", "(Optional) The database where the check is to be performed. 
Check all tables visible to the current user if not specified"), parameters.NewIntParameterWithDefault("limit", 50, "(Optional) Max rows to return, default is 50"), @@ -115,7 +95,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup t := Tool{ Config: cfg, - Pool: s.MySQLPool(), allParams: allParameters, manifest: tools.Manifest{Description: cfg.Description, Parameters: allParameters.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -129,12 +108,16 @@ var _ tools.Tool = Tool{} type Tool struct { Config allParams parameters.Parameters `yaml:"parameters"` - Pool *sql.DB manifest tools.Manifest mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() table_schema, ok := paramsMap["table_schema"].(string) @@ -153,7 +136,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, listTablesMissingUniqueIndexesStatement)) - results, err := t.Pool.QueryContext(ctx, listTablesMissingUniqueIndexesStatement, table_schema, table_schema, limit) + results, err := source.MySQLPool().QueryContext(ctx, listTablesMissingUniqueIndexesStatement, table_schema, table_schema, limit) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -221,14 +204,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/mysql/mysqlsql/mysqlsql.go b/internal/tools/mysql/mysqlsql/mysqlsql.go index a84239c738..edf5f65db1 100644 --- a/internal/tools/mysql/mysqlsql/mysqlsql.go +++ b/internal/tools/mysql/mysqlsql/mysqlsql.go @@ -21,9 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmysql" - "github.com/googleapis/genai-toolbox/internal/sources/mindsdb" - "github.com/googleapis/genai-toolbox/internal/sources/mysql" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -49,13 +46,6 @@ type compatibleSource interface { MySQLPool() *sql.DB } -// validate compatible sources are still compatible -var _ compatibleSource = &cloudsqlmysql.Source{} -var _ compatibleSource = &mysql.Source{} -var _ compatibleSource = &mindsdb.Source{} - -var compatibleSources = [...]string{cloudsqlmysql.SourceKind, mysql.SourceKind, mindsdb.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" 
validate:"required"` @@ -75,18 +65,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters) if err != nil { return nil, err @@ -98,7 +76,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, AllParams: allParameters, - Pool: s.MySQLPool(), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -110,14 +87,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters `yaml:"allParams"` - - Pool *sql.DB + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { @@ -130,7 +110,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } sliceParams := newParams.AsSlice() - results, err := t.Pool.QueryContext(ctx, newStatement, sliceParams...) + results, err := source.MySQLPool().QueryContext(ctx, newStatement, sliceParams...) 
if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -198,14 +178,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/neo4j/neo4jcypher/neo4jcypher.go b/internal/tools/neo4j/neo4jcypher/neo4jcypher.go index a38f7b86ce..5f5c4ce05b 100644 --- a/internal/tools/neo4j/neo4jcypher/neo4jcypher.go +++ b/internal/tools/neo4j/neo4jcypher/neo4jcypher.go @@ -19,7 +19,6 @@ import ( "fmt" "github.com/goccy/go-yaml" - neo4jsc "github.com/googleapis/genai-toolbox/internal/sources/neo4j" "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jschema/helpers" "github.com/neo4j/neo4j-go-driver/v5/neo4j" @@ -49,11 +48,6 @@ type compatibleSource interface { Neo4jDatabase() string } -// validate compatible sources are still compatible -var _ compatibleSource = &neo4jsc.Source{} - -var compatibleSources = [...]string{neo4jsc.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -72,25 +66,11 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, cfg.Parameters, nil) // finish tool setup t := Tool{ Config: cfg, - Driver: s.Neo4jDriver(), - Database: s.Neo4jDatabase(), manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -102,18 +82,20 @@ var _ tools.Tool = Tool{} type Tool struct { Config - - Driver neo4j.DriverWithContext - Database string manifest tools.Manifest mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() - config := neo4j.ExecuteQueryWithDatabase(t.Database) - results, err := neo4j.ExecuteQuery[*neo4j.EagerResult](ctx, t.Driver, t.Statement, paramsMap, + config := neo4j.ExecuteQueryWithDatabase(source.Neo4jDatabase()) + results, err := neo4j.ExecuteQuery[*neo4j.EagerResult](ctx, source.Neo4jDriver(), t.Statement, paramsMap, neo4j.EagerResultTransformer, config) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) @@ -149,14 +131,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return 
tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/neo4j/neo4jexecutecypher/neo4jexecutecypher.go b/internal/tools/neo4j/neo4jexecutecypher/neo4jexecutecypher.go index 7d809a9135..0bf2b8f34e 100644 --- a/internal/tools/neo4j/neo4jexecutecypher/neo4jexecutecypher.go +++ b/internal/tools/neo4j/neo4jexecutecypher/neo4jexecutecypher.go @@ -20,7 +20,6 @@ import ( "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - neo4jsc "github.com/googleapis/genai-toolbox/internal/sources/neo4j" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jexecutecypher/classifier" "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jschema/helpers" @@ -49,11 +48,6 @@ type compatibleSource interface { Neo4jDatabase() string } -// validate compatible sources are still compatible -var _ compatibleSource = &neo4jsc.Source{} - -var compatibleSources = [...]string{neo4jsc.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -71,19 +65,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - var s compatibleSource - s, ok = rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - cypherParameter := parameters.NewStringParameter("cypher", "The cypher to execute.") dryRunParameter := parameters.NewBooleanParameterWithDefault( "dry_run", @@ -99,8 +80,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, Parameters: params, - Driver: s.Neo4jDriver(), - Database: s.Neo4jDatabase(), classifier: classifier.NewQueryClassifier(), manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -114,14 +93,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config Parameters parameters.Parameters `yaml:"parameters"` - Database string - Driver neo4j.DriverWithContext classifier *classifier.QueryClassifier manifest tools.Manifest mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() cypherStr, ok := paramsMap["cypher"].(string) if !ok { @@ -152,8 +134,8 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT cypherStr = "EXPLAIN " + cypherStr } - config := 
neo4j.ExecuteQueryWithDatabase(t.Database) - results, err := neo4j.ExecuteQuery(ctx, t.Driver, cypherStr, nil, + config := neo4j.ExecuteQueryWithDatabase(source.Neo4jDatabase()) + results, err := neo4j.ExecuteQuery(ctx, source.Neo4jDriver(), cypherStr, nil, neo4j.EagerResultTransformer, config) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) @@ -208,8 +190,8 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } // Recursive function to add plan children @@ -234,6 +216,6 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/neo4j/neo4jschema/neo4jschema.go b/internal/tools/neo4j/neo4jschema/neo4jschema.go index b813ba5bcf..24b97cefb2 100644 --- a/internal/tools/neo4j/neo4jschema/neo4jschema.go +++ b/internal/tools/neo4j/neo4jschema/neo4jschema.go @@ -22,7 +22,6 @@ import ( "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - neo4jsc "github.com/googleapis/genai-toolbox/internal/sources/neo4j" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jschema/cache" "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jschema/helpers" @@ -58,12 +57,6 @@ type compatibleSource interface { Neo4jDatabase() string } -// Statically verify that our compatible source implementation is valid. -var _ compatibleSource = &neo4jsc.Source{} - -// compatibleSources lists the kinds of sources that are compatible with this tool. -var compatibleSources = [...]string{neo4jsc.SourceKind} - // Config holds the configuration settings for the Neo4j schema tool. // These settings are typically read from a YAML file. type Config struct { @@ -85,17 +78,6 @@ func (cfg Config) ToolConfigKind() string { // Initialize sets up the tool with its dependencies and returns a ready-to-use Tool instance. func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // Verify that the specified source exists. - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // Verify the source is of a compatible kind. - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } params := parameters.Parameters{} mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil) @@ -109,8 +91,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // Finish tool setup by creating the Tool instance. t := Tool{ Config: cfg, - Driver: s.Neo4jDriver(), - Database: s.Neo4jDatabase(), cache: cache.NewCache(), manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, @@ -125,17 +105,19 @@ var _ tools.Tool = Tool{} // It holds the Neo4j driver, database information, and a cache for the schema. 
type Tool struct { Config - Driver neo4j.DriverWithContext - Database string - cache *cache.Cache - + cache *cache.Cache manifest tools.Manifest mcpManifest tools.McpManifest } // Invoke executes the tool's main logic: fetching the Neo4j schema. // It first checks the cache for a valid schema before extracting it from the database. -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + // Check if a valid schema is already in the cache. if cachedSchema, ok := t.cache.Get("schema"); ok { if schema, ok := cachedSchema.(*types.SchemaInfo); ok { @@ -144,7 +126,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } // If not cached, extract the schema from the database. - schema, err := t.extractSchema(ctx) + schema, err := t.extractSchema(ctx, source) if err != nil { return nil, fmt.Errorf("failed to extract database schema: %w", err) } @@ -176,16 +158,16 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } // checkAPOCProcedures verifies if essential APOC procedures are available in the database. // It returns true only if all required procedures are found. -func (t Tool) checkAPOCProcedures(ctx context.Context) (bool, error) { +func (t Tool) checkAPOCProcedures(ctx context.Context, source compatibleSource) (bool, error) { proceduresToCheck := []string{"apoc.meta.schema", "apoc.meta.cypher.types"} - session := t.Driver.NewSession(ctx, neo4j.SessionConfig{DatabaseName: t.Database}) + session := source.Neo4jDriver().NewSession(ctx, neo4j.SessionConfig{DatabaseName: source.Neo4jDatabase()}) defer session.Close(ctx) // This query efficiently counts how many of the specified procedures exist. @@ -218,7 +200,7 @@ func (t Tool) checkAPOCProcedures(ctx context.Context) (bool, error) { // extractSchema orchestrates the concurrent extraction of different parts of the database schema. // It runs several extraction tasks in parallel for efficiency. -func (t Tool) extractSchema(ctx context.Context) (*types.SchemaInfo, error) { +func (t Tool) extractSchema(ctx context.Context, source compatibleSource) (*types.SchemaInfo, error) { schema := &types.SchemaInfo{} var mu sync.Mutex @@ -230,7 +212,7 @@ func (t Tool) extractSchema(ctx context.Context) (*types.SchemaInfo, error) { { name: "database-info", fn: func() error { - dbInfo, err := t.extractDatabaseInfo(ctx) + dbInfo, err := t.extractDatabaseInfo(ctx, source) if err != nil { return fmt.Errorf("failed to extract database info: %w", err) } @@ -244,7 +226,7 @@ func (t Tool) extractSchema(ctx context.Context) (*types.SchemaInfo, error) { name: "schema-extraction", fn: func() error { // Check if APOC procedures are available. 
- hasAPOC, err := t.checkAPOCProcedures(ctx) + hasAPOC, err := t.checkAPOCProcedures(ctx, source) if err != nil { return fmt.Errorf("failed to check APOC procedures: %w", err) } @@ -255,9 +237,9 @@ func (t Tool) extractSchema(ctx context.Context) (*types.SchemaInfo, error) { // Use APOC if available for a more detailed schema; otherwise, use native queries. if hasAPOC { - nodeLabels, relationships, stats, err = t.GetAPOCSchema(ctx) + nodeLabels, relationships, stats, err = t.GetAPOCSchema(ctx, source) } else { - nodeLabels, relationships, stats, err = t.GetSchemaWithoutAPOC(ctx, 100) + nodeLabels, relationships, stats, err = t.GetSchemaWithoutAPOC(ctx, source, 100) } if err != nil { return fmt.Errorf("failed to get schema: %w", err) @@ -274,7 +256,7 @@ func (t Tool) extractSchema(ctx context.Context) (*types.SchemaInfo, error) { { name: "constraints", fn: func() error { - constraints, err := t.extractConstraints(ctx) + constraints, err := t.extractConstraints(ctx, source) if err != nil { return fmt.Errorf("failed to extract constraints: %w", err) } @@ -287,7 +269,7 @@ func (t Tool) extractSchema(ctx context.Context) (*types.SchemaInfo, error) { { name: "indexes", fn: func() error { - indexes, err := t.extractIndexes(ctx) + indexes, err := t.extractIndexes(ctx, source) if err != nil { return fmt.Errorf("failed to extract indexes: %w", err) } @@ -329,7 +311,7 @@ func (t Tool) extractSchema(ctx context.Context) (*types.SchemaInfo, error) { } // GetAPOCSchema extracts schema information using the APOC library, which provides detailed metadata. -func (t Tool) GetAPOCSchema(ctx context.Context) ([]types.NodeLabel, []types.Relationship, *types.Statistics, error) { +func (t Tool) GetAPOCSchema(ctx context.Context, source compatibleSource) ([]types.NodeLabel, []types.Relationship, *types.Statistics, error) { var nodeLabels []types.NodeLabel var relationships []types.Relationship stats := &types.Statistics{ @@ -444,7 +426,7 @@ func (t Tool) GetAPOCSchema(ctx context.Context) ([]types.NodeLabel, []types.Rel fn func(session neo4j.SessionWithContext) error }) { defer wg.Done() - session := t.Driver.NewSession(ctx, neo4j.SessionConfig{DatabaseName: t.Database}) + session := source.Neo4jDriver().NewSession(ctx, neo4j.SessionConfig{DatabaseName: source.Neo4jDatabase()}) defer session.Close(ctx) if err := task.fn(session); err != nil { handleError(fmt.Errorf("task %s failed: %w", task.name, err)) @@ -461,7 +443,7 @@ func (t Tool) GetAPOCSchema(ctx context.Context) ([]types.NodeLabel, []types.Rel // GetSchemaWithoutAPOC extracts schema information using native Cypher queries. // This serves as a fallback for databases without APOC installed. 
-func (t Tool) GetSchemaWithoutAPOC(ctx context.Context, sampleSize int) ([]types.NodeLabel, []types.Relationship, *types.Statistics, error) { +func (t Tool) GetSchemaWithoutAPOC(ctx context.Context, source compatibleSource, sampleSize int) ([]types.NodeLabel, []types.Relationship, *types.Statistics, error) { nodePropsMap := make(map[string]map[string]map[string]bool) relPropsMap := make(map[string]map[string]map[string]bool) nodeCounts := make(map[string]int64) @@ -609,7 +591,7 @@ func (t Tool) GetSchemaWithoutAPOC(ctx context.Context, sampleSize int) ([]types fn func(session neo4j.SessionWithContext) error }) { defer wg.Done() - session := t.Driver.NewSession(ctx, neo4j.SessionConfig{DatabaseName: t.Database}) + session := source.Neo4jDriver().NewSession(ctx, neo4j.SessionConfig{DatabaseName: source.Neo4jDatabase()}) defer session.Close(ctx) if err := task.fn(session); err != nil { handleError(fmt.Errorf("task %s failed: %w", task.name, err)) @@ -627,8 +609,8 @@ func (t Tool) GetSchemaWithoutAPOC(ctx context.Context, sampleSize int) ([]types } // extractDatabaseInfo retrieves general information about the Neo4j database instance. -func (t Tool) extractDatabaseInfo(ctx context.Context) (*types.DatabaseInfo, error) { - session := t.Driver.NewSession(ctx, neo4j.SessionConfig{DatabaseName: t.Database}) +func (t Tool) extractDatabaseInfo(ctx context.Context, source compatibleSource) (*types.DatabaseInfo, error) { + session := source.Neo4jDriver().NewSession(ctx, neo4j.SessionConfig{DatabaseName: source.Neo4jDatabase()}) defer session.Close(ctx) result, err := session.Run(ctx, "CALL dbms.components() YIELD name, versions, edition", nil) @@ -649,8 +631,8 @@ func (t Tool) extractDatabaseInfo(ctx context.Context) (*types.DatabaseInfo, err } // extractConstraints fetches all schema constraints from the database. -func (t Tool) extractConstraints(ctx context.Context) ([]types.Constraint, error) { - session := t.Driver.NewSession(ctx, neo4j.SessionConfig{DatabaseName: t.Database}) +func (t Tool) extractConstraints(ctx context.Context, source compatibleSource) ([]types.Constraint, error) { + session := source.Neo4jDriver().NewSession(ctx, neo4j.SessionConfig{DatabaseName: source.Neo4jDatabase()}) defer session.Close(ctx) result, err := session.Run(ctx, "SHOW CONSTRAINTS", nil) @@ -678,8 +660,8 @@ func (t Tool) extractConstraints(ctx context.Context) ([]types.Constraint, error } // extractIndexes fetches all schema indexes from the database. 
-func (t Tool) extractIndexes(ctx context.Context) ([]types.Index, error) { - session := t.Driver.NewSession(ctx, neo4j.SessionConfig{DatabaseName: t.Database}) +func (t Tool) extractIndexes(ctx context.Context, source compatibleSource) ([]types.Index, error) { + session := source.Neo4jDriver().NewSession(ctx, neo4j.SessionConfig{DatabaseName: source.Neo4jDatabase()}) defer session.Close(ctx) result, err := session.Run(ctx, "SHOW INDEXES", nil) @@ -711,6 +693,6 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/oceanbase/oceanbaseexecutesql/oceanbaseexecutesql.go b/internal/tools/oceanbase/oceanbaseexecutesql/oceanbaseexecutesql.go index c493240ecf..fa8d7a96a9 100644 --- a/internal/tools/oceanbase/oceanbaseexecutesql/oceanbaseexecutesql.go +++ b/internal/tools/oceanbase/oceanbaseexecutesql/oceanbaseexecutesql.go @@ -21,7 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/oceanbase" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -39,11 +38,6 @@ type compatibleSource interface { OceanBasePool() *sql.DB } -// validate compatible sources are still compatible -var _ compatibleSource = &oceanbase.Source{} - -var compatibleSources = [...]string{oceanbase.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -68,18 +62,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - sqlParameter := parameters.NewStringParameter("sql", "The sql to execute.") params := parameters.Parameters{sqlParameter} @@ -89,7 +71,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, Parameters: params, - Pool: s.OceanBasePool(), manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -101,22 +82,25 @@ var _ tools.Tool = Tool{} type Tool struct { Config - Parameters parameters.Parameters `yaml:"parameters"` - - Pool *sql.DB + Parameters parameters.Parameters `yaml:"parameters"` manifest tools.Manifest mcpManifest tools.McpManifest } // Invoke executes the SQL statement provided in the parameters. 
-func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + sliceParams := params.AsSlice() sqlStr, ok := sliceParams[0].(string) if !ok { return nil, fmt.Errorf("unable to get cast %s", sliceParams[0]) } - results, err := t.Pool.QueryContext(ctx, sqlStr) + results, err := source.OceanBasePool().QueryContext(ctx, sqlStr) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -189,14 +173,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/oceanbase/oceanbasesql/oceanbasesql.go b/internal/tools/oceanbase/oceanbasesql/oceanbasesql.go index 00793824e2..10a4dc17de 100644 --- a/internal/tools/oceanbase/oceanbasesql/oceanbasesql.go +++ b/internal/tools/oceanbase/oceanbasesql/oceanbasesql.go @@ -21,7 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/oceanbase" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -39,11 +38,6 @@ type compatibleSource interface { OceanBasePool() *sql.DB } -// validate compatible sources are still compatible -var _ compatibleSource = &oceanbase.Source{} - -var compatibleSources = [...]string{oceanbase.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -71,18 +65,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters) if err != nil { return nil, fmt.Errorf("unable to process parameters: %w", err) @@ -94,7 +76,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, AllParams: allParameters, - Pool: s.OceanBasePool(), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -106,15 +87,18 @@ var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters `yaml:"allParams"` - - Pool *sql.DB + AllParams parameters.Parameters `yaml:"allParams"` manifest 
tools.Manifest mcpManifest tools.McpManifest } // Invoke executes the SQL statement with the provided parameters. -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { @@ -127,7 +111,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } sliceParams := newParams.AsSlice() - results, err := t.Pool.QueryContext(ctx, newStatement, sliceParams...) + results, err := source.OceanBasePool().QueryContext(ctx, newStatement, sliceParams...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -200,14 +184,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/oracle/oracleexecutesql/oracleexecutesql.go b/internal/tools/oracle/oracleexecutesql/oracleexecutesql.go index 2a40112c76..447f9362e9 100644 --- a/internal/tools/oracle/oracleexecutesql/oracleexecutesql.go +++ b/internal/tools/oracle/oracleexecutesql/oracleexecutesql.go @@ -11,7 +11,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/oracle" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -37,11 +36,6 @@ type compatibleSource interface { OracleDB() *sql.DB } -// validate compatible sources are still compatible -var _ compatibleSource = &oracle.Source{} - -var compatibleSources = [...]string{oracle.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -58,18 +52,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - sqlParameter := parameters.NewStringParameter("sql", "The SQL to execute.") params := parameters.Parameters{sqlParameter} @@ -79,7 +61,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, Parameters: params, - Pool: s.OracleDB(), manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, 
} @@ -91,14 +72,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - Parameters parameters.Parameters `yaml:"parameters"` - - Pool *sql.DB + Parameters parameters.Parameters `yaml:"parameters"` manifest tools.Manifest mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() sqlParam, ok := paramsMap["sql"].(string) if !ok { @@ -110,9 +94,9 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT if err != nil { return nil, fmt.Errorf("error getting logger: %s", err) } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sqlParam)) - results, err := t.Pool.QueryContext(ctx, sqlParam) + results, err := source.OracleDB().QueryContext(ctx, sqlParam) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -230,14 +214,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/oracle/oracleexecutesql/oracleexecutesql_test.go b/internal/tools/oracle/oracleexecutesql/oracleexecutesql_test.go new file mode 100644 index 0000000000..834d3d6981 --- /dev/null +++ b/internal/tools/oracle/oracleexecutesql/oracleexecutesql_test.go @@ -0,0 +1,82 @@ +// Copyright © 2025, Oracle and/or its affiliates. + +package oracleexecutesql_test + +import ( + "testing" + + yaml "github.com/goccy/go-yaml" + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools/oracle/oracleexecutesql" +) + +func TestParseFromYamlOracleExecuteSql(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example with auth", + in: ` + tools: + run_adhoc_query: + kind: oracle-execute-sql + source: my-oracle-instance + description: Executes arbitrary SQL statements like INSERT or UPDATE. + authRequired: + - my-google-auth-service + `, + want: server.ToolConfigs{ + "run_adhoc_query": oracleexecutesql.Config{ + Name: "run_adhoc_query", + Kind: "oracle-execute-sql", + Source: "my-oracle-instance", + Description: "Executes arbitrary SQL statements like INSERT or UPDATE.", + AuthRequired: []string{"my-google-auth-service"}, + }, + }, + }, + { + desc: "example without authRequired", + in: ` + tools: + run_simple_update: + kind: oracle-execute-sql + source: db-dev + description: Runs a simple update operation.
+ `, + want: server.ToolConfigs{ + "run_simple_update": oracleexecutesql.Config{ + Name: "run_simple_update", + Kind: "oracle-execute-sql", + Source: "db-dev", + Description: "Runs a simple update operation.", + AuthRequired: []string{}, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Tools server.ToolConfigs `yaml:"tools"` + }{} + + err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if diff := cmp.Diff(tc.want, got.Tools); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } + +} diff --git a/internal/tools/oracle/oraclesql/oraclesql.go b/internal/tools/oracle/oraclesql/oraclesql.go index d41f62c742..1ba87b47bd 100644 --- a/internal/tools/oracle/oraclesql/oraclesql.go +++ b/internal/tools/oracle/oraclesql/oraclesql.go @@ -11,7 +11,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/oracle" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -36,11 +35,6 @@ type compatibleSource interface { OracleDB() *sql.DB } -// validate compatible sources are still compatible -var _ compatibleSource = &oracle.Source{} - -var compatibleSources = [...]string{oracle.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -60,18 +54,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters) if err != nil { return nil, fmt.Errorf("error processing parameters: %w", err) @@ -83,7 +65,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, AllParams: allParameters, - DB: s.OracleDB(), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -95,14 +76,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters `yaml:"allParams"` - - DB *sql.DB + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { @@ -120,7 +104,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } fmt.Printf("\n") - rows, err := t.DB.QueryContext(ctx, newStatement, sliceParams...) 
+ rows, err := source.OracleDB().QueryContext(ctx, newStatement, sliceParams...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -230,14 +214,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/oracle/oraclesql/oraclesql_test.go b/internal/tools/oracle/oraclesql/oraclesql_test.go new file mode 100644 index 0000000000..2ba0a7321c --- /dev/null +++ b/internal/tools/oracle/oraclesql/oraclesql_test.go @@ -0,0 +1,85 @@ +// Copyright © 2025, Oracle and/or its affiliates. +package oraclesql_test + +import ( + "testing" + + yaml "github.com/goccy/go-yaml" + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools/oracle/oraclesql" +) + +func TestParseFromYamlOracleSql(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example with statement and auth", + in: ` + tools: + get_user_by_id: + kind: oracle-sql + source: my-oracle-instance + description: Retrieves user details by ID. + statement: "SELECT id, name, email FROM users WHERE id = :1" + authRequired: + - my-google-auth-service + `, + want: server.ToolConfigs{ + "get_user_by_id": oraclesql.Config{ + Name: "get_user_by_id", + Kind: "oracle-sql", + Source: "my-oracle-instance", + Description: "Retrieves user details by ID.", + Statement: "SELECT id, name, email FROM users WHERE id = :1", + AuthRequired: []string{"my-google-auth-service"}, + }, + }, + }, + { + desc: "example with parameters and template parameters", + in: ` + tools: + get_orders: + kind: oracle-sql + source: db-prod + description: Gets orders for a customer with optional filtering. 
+ statement: "SELECT * FROM ${SCHEMA}.ORDERS WHERE customer_id = :customer_id AND status = :status" + `, + want: server.ToolConfigs{ + "get_orders": oraclesql.Config{ + Name: "get_orders", + Kind: "oracle-sql", + Source: "db-prod", + Description: "Gets orders for a customer with optional filtering.", + Statement: "SELECT * FROM ${SCHEMA}.ORDERS WHERE customer_id = :customer_id AND status = :status", + AuthRequired: []string{}, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Tools server.ToolConfigs `yaml:"tools"` + }{} + + err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if diff := cmp.Diff(tc.want, got.Tools); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } + +} diff --git a/internal/tools/postgres/postgresdatabaseoverview/postgresdatabaseoverview.go b/internal/tools/postgres/postgresdatabaseoverview/postgresdatabaseoverview.go index b3472b3178..4e8a0a29ce 100644 --- a/internal/tools/postgres/postgresdatabaseoverview/postgresdatabaseoverview.go +++ b/internal/tools/postgres/postgresdatabaseoverview/postgresdatabaseoverview.go @@ -20,9 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" - "github.com/googleapis/genai-toolbox/internal/sources/postgres" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" @@ -62,13 +59,6 @@ type compatibleSource interface { PostgresPool() *pgxpool.Pool } -// validate compatible sources are still compatible -var _ compatibleSource = &alloydbpg.Source{} -var _ compatibleSource = &cloudsqlpg.Source{} -var _ compatibleSource = &postgres.Source{} - -var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -85,30 +75,16 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters := parameters.Parameters{} - description := cfg.Description - if description == "" { - description = "Fetches the current state of the PostgreSQL server, returning the version, whether it's a replica, uptime duration, maximum connection limit, number of current connections, number of active connections, and the percentage of connections in use." + if cfg.Description == "" { + cfg.Description = "Fetches the current state of the PostgreSQL server, returning the version, whether it's a replica, uptime duration, maximum connection limit, number of current connections, number of active connections, and the percentage of connections in use." 
} - mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, allParameters, nil) + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) // finish tool setup return Tool{ Config: cfg, allParams: allParameters, - pool: s.PostgresPool(), manifest: tools.Manifest{ Description: cfg.Description, Parameters: allParameters.Manifest(), @@ -124,7 +100,6 @@ var _ tools.Tool = Tool{} type Tool struct { Config allParams parameters.Parameters `yaml:"allParams"` - pool *pgxpool.Pool manifest tools.Manifest mcpManifest tools.McpManifest } @@ -133,10 +108,21 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - sliceParams := params.AsSlice() +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } - results, err := t.pool.Query(ctx, databaseOverviewStatement, sliceParams...) + paramsMap := params.AsMap() + + newParams, err := parameters.GetParams(t.allParams, paramsMap) + if err != nil { + return nil, fmt.Errorf("unable to extract standard params %w", err) + } + sliceParams := newParams.AsSlice() + + results, err := source.PostgresPool().Query(ctx, databaseOverviewStatement, sliceParams...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -181,10 +167,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/postgres/postgresexecutesql/postgresexecutesql.go b/internal/tools/postgres/postgresexecutesql/postgresexecutesql.go index 92fe4ea7aa..73afd2a6ee 100644 --- a/internal/tools/postgres/postgresexecutesql/postgresexecutesql.go +++ b/internal/tools/postgres/postgresexecutesql/postgresexecutesql.go @@ -20,9 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" - "github.com/googleapis/genai-toolbox/internal/sources/postgres" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/orderedmap" @@ -50,13 +47,6 @@ type compatibleSource interface { PostgresPool() *pgxpool.Pool } -// validate compatible sources are still compatible -var _ compatibleSource = &alloydbpg.Source{} -var _ compatibleSource = &cloudsqlpg.Source{} -var _ compatibleSource = &postgres.Source{} - -var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -73,18 +63,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs 
map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - sqlParameter := parameters.NewStringParameter("sql", "The sql to execute.") params := parameters.Parameters{sqlParameter} @@ -94,7 +72,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, Parameters: params, - Pool: s.PostgresPool(), manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -106,14 +83,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - Parameters parameters.Parameters `yaml:"parameters"` - - Pool *pgxpool.Pool + Parameters parameters.Parameters `yaml:"parameters"` manifest tools.Manifest mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() sql, ok := paramsMap["sql"].(string) if !ok { @@ -126,7 +106,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sql)) - results, err := t.Pool.Query(ctx, sql) + results, err := source.PostgresPool().Query(ctx, sql) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -170,14 +150,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/postgres/postgresgetcolumncardinality/postgresgetcolumncardinality.go b/internal/tools/postgres/postgresgetcolumncardinality/postgresgetcolumncardinality.go index c0448fc17b..f96654fbc6 100644 --- a/internal/tools/postgres/postgresgetcolumncardinality/postgresgetcolumncardinality.go +++ b/internal/tools/postgres/postgresgetcolumncardinality/postgresgetcolumncardinality.go @@ -20,9 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" - "github.com/googleapis/genai-toolbox/internal/sources/postgres" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" @@ -67,13 +64,6 @@ type compatibleSource interface { PostgresPool() *pgxpool.Pool } -// validate compatible sources are still compatible -var _ 
compatibleSource = &alloydbpg.Source{} -var _ compatibleSource = &cloudsqlpg.Source{} -var _ compatibleSource = &postgres.Source{} - -var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -90,18 +80,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters := parameters.Parameters{ parameters.NewStringParameterWithDefault("schema_name", "public", "Optional: The schema name in which the table is present."), parameters.NewStringParameterWithRequired("table_name", "Required: The table name in which the column is present.", true), @@ -117,11 +95,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup return Tool{ - name: cfg.Name, - kind: cfg.Kind, - authRequired: cfg.AuthRequired, - allParams: allParameters, - pool: s.PostgresPool(), + Config: cfg, + allParams: allParameters, manifest: tools.Manifest{ Description: cfg.Description, Parameters: paramManifest, @@ -136,20 +111,21 @@ var _ tools.Tool = Tool{} type Tool struct { Config - name string `yaml:"name"` - kind string `yaml:"kind"` - authRequired []string `yaml:"authRequired"` - allParams parameters.Parameters `yaml:"allParams"` - pool *pgxpool.Pool - manifest tools.Manifest - mcpManifest tools.McpManifest + allParams parameters.Parameters `yaml:"allParams"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) @@ -158,7 +134,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } sliceParams := newParams.AsSlice() - results, err := t.pool.Query(ctx, getColumnCardinality, sliceParams...) + results, err := source.PostgresPool().Query(ctx, getColumnCardinality, sliceParams...) 
if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -199,13 +175,13 @@ func (t Tool) McpManifest() tools.McpManifest { } func (t Tool) Authorized(verifiedAuthServices []string) bool { - return tools.IsAuthorized(t.authRequired, verifiedAuthServices) + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/postgres/postgreslistactivequeries/postgreslistactivequeries.go b/internal/tools/postgres/postgreslistactivequeries/postgreslistactivequeries.go index a4a104189a..6ad5bff569 100644 --- a/internal/tools/postgres/postgreslistactivequeries/postgreslistactivequeries.go +++ b/internal/tools/postgres/postgreslistactivequeries/postgreslistactivequeries.go @@ -20,9 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" - "github.com/googleapis/genai-toolbox/internal/sources/postgres" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" @@ -71,13 +68,6 @@ type compatibleSource interface { PostgresPool() *pgxpool.Pool } -// validate compatible sources are still compatible -var _ compatibleSource = &alloydbpg.Source{} -var _ compatibleSource = &cloudsqlpg.Source{} -var _ compatibleSource = &postgres.Source{} - -var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -94,18 +84,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters := parameters.Parameters{ parameters.NewStringParameterWithDefault("min_duration", "1 minute", "Optional: Only show queries running at least this long (e.g., '1 minute', '1 second', '2 seconds')."), parameters.NewStringParameterWithDefault("exclude_application_names", "", "Optional: A comma-separated list of application names to exclude from the query results. This is useful for filtering out queries from specific applications (e.g., 'psql', 'pgAdmin', 'DBeaver'). The match is case-sensitive. Whitespace around commas and names is automatically handled. 
If this parameter is omitted, no applications are excluded."), @@ -118,7 +96,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, allParams: allParameters, - pool: s.PostgresPool(), manifest: tools.Manifest{ Description: cfg.Description, Parameters: paramManifest, @@ -135,12 +112,16 @@ var _ tools.Tool = Tool{} type Tool struct { Config allParams parameters.Parameters `yaml:"allParams"` - pool *pgxpool.Pool manifest tools.Manifest mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) @@ -149,7 +130,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } sliceParams := newParams.AsSlice() - results, err := t.pool.Query(ctx, listActiveQueriesStatement, sliceParams...) + results, err := source.PostgresPool().Query(ctx, listActiveQueriesStatement, sliceParams...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -194,14 +175,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/postgres/postgreslistavailableextensions/postgreslistavailableextensions.go b/internal/tools/postgres/postgreslistavailableextensions/postgreslistavailableextensions.go index 23309e7374..1440509cbb 100644 --- a/internal/tools/postgres/postgreslistavailableextensions/postgreslistavailableextensions.go +++ b/internal/tools/postgres/postgreslistavailableextensions/postgreslistavailableextensions.go @@ -20,9 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" - "github.com/googleapis/genai-toolbox/internal/sources/postgres" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" @@ -58,13 +55,6 @@ type compatibleSource interface { PostgresPool() *pgxpool.Pool } -// validate compatible sources are still compatible -var _ compatibleSource = &alloydbpg.Source{} -var _ compatibleSource = &cloudsqlpg.Source{} -var _ compatibleSource = &postgres.Source{} - -var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -81,25 +71,12 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // 
verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - params := parameters.Parameters{} mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil) // finish tool setup t := Tool{ Config: cfg, - Pool: s.PostgresPool(), manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -115,13 +92,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - Pool *pgxpool.Pool manifest tools.Manifest mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - results, err := t.Pool.Query(ctx, listAvailableExtensionsQuery) +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + + results, err := source.PostgresPool().Query(ctx, listAvailableExtensionsQuery) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -165,14 +146,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/postgres/postgreslistdatabasestats/postgreslistdatabasestats.go b/internal/tools/postgres/postgreslistdatabasestats/postgreslistdatabasestats.go new file mode 100644 index 0000000000..27cc16c1ed --- /dev/null +++ b/internal/tools/postgres/postgreslistdatabasestats/postgreslistdatabasestats.go @@ -0,0 +1,257 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package postgreslistdatabasestats + +import ( + "context" + "fmt" + + yaml "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util/parameters" + "github.com/jackc/pgx/v5/pgxpool" +) + +const kind string = "postgres-list-database-stats" + +// SQL query to list database statistics +const listDatabaseStats = ` + WITH database_stats AS ( + SELECT + s.datname AS database_name, + -- Database Metadata + d.datallowconn AS is_connectable, + pg_get_userbyid(d.datdba) AS database_owner, + ts.spcname AS default_tablespace, + + -- Cache Performance + CASE + WHEN (s.blks_hit + s.blks_read) = 0 THEN 0 + ELSE round((s.blks_hit * 100.0) / (s.blks_hit + s.blks_read), 2) + END AS cache_hit_ratio_percent, + s.blks_read AS blocks_read_from_disk, + s.blks_hit AS blocks_hit_in_cache, + + -- Transaction Throughput + s.xact_commit, + s.xact_rollback, + round(s.xact_rollback * 100.0 / (s.xact_commit + s.xact_rollback + 1), 2) AS rollback_ratio_percent, + + -- Tuple Activity + s.tup_returned AS rows_returned_by_queries, + s.tup_fetched AS rows_fetched_by_scans, + s.tup_inserted, + s.tup_updated, + s.tup_deleted, + + -- Temporary File Usage + s.temp_files, + s.temp_bytes AS temp_size_bytes, + + -- Conflicts & Deadlocks + s.conflicts, + s.deadlocks, + + -- General Info + s.numbackends AS active_connections, + s.stats_reset AS statistics_last_reset, + pg_database_size(s.datid) AS database_size_bytes + FROM + pg_stat_database s + JOIN + pg_database d ON d.oid = s.datid + JOIN + pg_tablespace ts ON ts.oid = d.dattablespace + WHERE + -- Exclude cloudsql internal databases + s.datname NOT IN ('cloudsqladmin') + -- Exclude template databases if not requested + AND ( $2::boolean IS TRUE OR d.datistemplate IS FALSE ) + ) + SELECT * + FROM database_stats + WHERE + ($1::text IS NULL OR database_name LIKE '%' || $1::text || '%') + AND ($3::text IS NULL OR database_owner LIKE '%' || $3::text || '%') + AND ($4::text IS NULL OR default_tablespace LIKE '%' || $4::text || '%') + ORDER BY + CASE WHEN $5::text = 'size' THEN database_size_bytes END DESC, + CASE WHEN $5::text = 'commit' THEN xact_commit END DESC, + database_name + LIMIT COALESCE($6::int, 10); +` + +func init() { + if !tools.Register(kind, newConfig) { + panic(fmt.Sprintf("tool kind %q already registered", kind)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { + actual := Config{Name: name} + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + return actual, nil +} + +type compatibleSource interface { + PostgresPool() *pgxpool.Pool +} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description"` + AuthRequired []string `yaml:"authRequired"` +} + +// validate interface +var _ tools.ToolConfig = Config{} + +func (cfg Config) ToolConfigKind() string { + return kind +} + +func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { + allParameters := parameters.Parameters{ + parameters.NewStringParameterWithDefault("database_name", "", "Optional: A specific database name pattern to search for."), + parameters.NewBooleanParameterWithDefault("include_templates", false, "Optional: Whether to include template databases in the results."), + 
parameters.NewStringParameterWithDefault("database_owner", "", "Optional: A specific database owner name pattern to search for."), + parameters.NewStringParameterWithDefault("default_tablespace", "", "Optional: A specific default tablespace name pattern to search for."), + parameters.NewStringParameterWithDefault("order_by", "", "Optional: The field to order the results by. Valid values are 'size' and 'commit'."), + parameters.NewIntParameterWithDefault("limit", 10, "Optional: The maximum number of rows to return."), + } + description := cfg.Description + if description == "" { + description = + "Lists the key performance and activity statistics for each PostgreSQL database " + + "in the instance, offering insights into cache efficiency, transaction throughput, " + + "row-level activity, temporary file " + + "usage, and contention. " + + "It returns: the database name, whether the database is connectable, " + + "database owner, default tablespace name, the percentage of data blocks " + + "found in the buffer cache rather than being read from disk (a higher " + + "value indicates better cache performance), the total number of disk " + + "blocks read from disk, the total number of times disk blocks were found " + + "already in the cache; the total number of committed transactions, the " + + "total number of rolled back transactions, the percentage of rolled back " + + "transactions compared to the total number of completed transactions, the " + + "total number of rows returned by queries, the total number of live rows " + + "fetched by scans, the total number of rows inserted, the total number " + + "of rows updated, the total number of rows deleted, the number of " + + "temporary files created by queries, the total size of all temporary " + + "files created by queries in bytes, the number of query cancellations due " + + "to conflicts with recovery, the number of deadlocks detected, the current " + + "number of active connections to the database, the timestamp of the " + + "last statistics reset, and total database size in bytes." + } + mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, allParameters, nil) + + // finish tool setup + return Tool{ + Config: cfg, + allParams: allParameters, + manifest: tools.Manifest{ + Description: cfg.Description, + Parameters: allParameters.Manifest(), + AuthRequired: cfg.AuthRequired, + }, + mcpManifest: mcpManifest, + }, nil +} + +// validate interface +var _ tools.Tool = Tool{} + +type Tool struct { + Config + allParams parameters.Parameters `yaml:"allParams"` + manifest tools.Manifest + mcpManifest tools.McpManifest +} + +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + + paramsMap := params.AsMap() + + newParams, err := parameters.GetParams(t.allParams, paramsMap) + if err != nil { + return nil, fmt.Errorf("unable to extract standard params %w", err) + } + sliceParams := newParams.AsSlice() + + results, err := source.PostgresPool().Query(ctx, listDatabaseStats, sliceParams...)
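+ // Note: sliceParams is bound positionally, so the parameter declaration order above must match the query placeholders in listDatabaseStats: $1=database_name, $2=include_templates, $3=database_owner, $4=default_tablespace, $5=order_by, $6=limit.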
+ if err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + defer results.Close() + + fields := results.FieldDescriptions() + var out []map[string]any + + for results.Next() { + values, err := results.Values() + if err != nil { + return nil, fmt.Errorf("unable to parse row: %w", err) + } + rowMap := make(map[string]any) + for i, field := range fields { + rowMap[string(field.Name)] = values[i] + } + out = append(out, rowMap) + } + + // this will catch actual query execution errors + if err := results.Err(); err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + + return out, nil +} + +func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { + return parameters.ParseParams(t.allParams, data, claims) +} + +func (t Tool) Manifest() tools.Manifest { + return t.manifest +} + +func (t Tool) McpManifest() tools.McpManifest { + return t.mcpManifest +} + +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +} + +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil +} + +func (t Tool) ToConfig() tools.ToolConfig { + return t.Config +} + +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil +} diff --git a/internal/tools/postgres/postgreslistdatabasestats/postgreslistdatabasestats_test.go b/internal/tools/postgres/postgreslistdatabasestats/postgreslistdatabasestats_test.go new file mode 100644 index 0000000000..760370f630 --- /dev/null +++ b/internal/tools/postgres/postgreslistdatabasestats/postgreslistdatabasestats_test.go @@ -0,0 +1,95 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package postgreslistdatabasestats_test + +import ( + "testing" + + yaml "github.com/goccy/go-yaml" + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistdatabasestats" +) + +func TestParseFromYamlPostgresListDatabaseStats(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example", + in: ` + tools: + example_tool: + kind: postgres-list-database-stats + source: my-postgres-instance + description: some description + authRequired: + - my-google-auth-service + - other-auth-service + `, + want: server.ToolConfigs{ + "example_tool": postgreslistdatabasestats.Config{ + Name: "example_tool", + Kind: "postgres-list-database-stats", + Source: "my-postgres-instance", + Description: "some description", + AuthRequired: []string{"my-google-auth-service", "other-auth-service"}, + }, + }, + }, + { + desc: "basic example", + in: ` + tools: + example_tool: + kind: postgres-list-database-stats + source: my-postgres-instance + description: some description + `, + want: server.ToolConfigs{ + "example_tool": postgreslistdatabasestats.Config{ + Name: "example_tool", + Kind: "postgres-list-database-stats", + Source: "my-postgres-instance", + Description: "some description", + AuthRequired: []string{}, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Tools server.ToolConfigs `yaml:"tools"` + }{} + // Parse contents + err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if diff := cmp.Diff(tc.want, got.Tools); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } + +} diff --git a/internal/tools/postgres/postgreslistindexes/postgreslistindexes.go b/internal/tools/postgres/postgreslistindexes/postgreslistindexes.go index aba26fd339..0f85a0e46c 100644 --- a/internal/tools/postgres/postgreslistindexes/postgreslistindexes.go +++ b/internal/tools/postgres/postgreslistindexes/postgreslistindexes.go @@ -20,9 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" - "github.com/googleapis/genai-toolbox/internal/sources/postgres" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" @@ -59,7 +56,8 @@ const listIndexesStatement = ` ON i.oid = s.indexrelid WHERE t.relkind = 'r' - AND s.schemaname NOT IN ('pg_catalog', 'information_schema') + AND s.schemaname NOT IN ('pg_catalog', 'information_schema', 'pg_toast') + AND s.schemaname NOT LIKE 'pg_temp_%' ) SELECT * FROM IndexDetails @@ -67,11 +65,12 @@ const listIndexesStatement = ` ($1::text IS NULL OR schema_name LIKE '%' || $1 || '%') AND ($2::text IS NULL OR table_name LIKE '%' || $2 || '%') AND ($3::text IS NULL OR index_name LIKE '%' || $3 || '%') + AND ($4::boolean IS NOT TRUE OR is_used IS FALSE) ORDER BY schema_name, table_name, index_name - LIMIT COALESCE($4::int, 50); + LIMIT COALESCE($5::int, 50); ` func init() { @@ -92,13 +91,6 @@ type compatibleSource interface { PostgresPool() *pgxpool.Pool } -// 
validate compatible sources are still compatible -var _ compatibleSource = &alloydbpg.Source{} -var _ compatibleSource = &cloudsqlpg.Source{} -var _ compatibleSource = &postgres.Source{} - -var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -115,35 +107,23 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters := parameters.Parameters{ parameters.NewStringParameterWithDefault("schema_name", "", "Optional: a text to filter results by schema name. The input is used within a LIKE clause."), parameters.NewStringParameterWithDefault("table_name", "", "Optional: a text to filter results by table name. The input is used within a LIKE clause."), parameters.NewStringParameterWithDefault("index_name", "", "Optional: a text to filter results by index name. The input is used within a LIKE clause."), + parameters.NewBooleanParameterWithDefault("only_unused", false, "Optional: If true, only returns indexes that have never been used."), parameters.NewIntParameterWithDefault("limit", 50, "Optional: The maximum number of rows to return. Default is 50"), } - description := cfg.Description - if description == "" { - description = "Lists available user indexes in the database, excluding system schemas (pg_catalog, information_schema). For each index, the following properties are returned: schema name, table name, index name, index type (access method), a boolean indicating if it's a unique index, a boolean indicating if it's for a primary key, the index definition, index size in bytes, the number of index scans, the number of index tuples read, the number of table tuples fetched via index scans, and a boolean indicating if the index has been used at least once." + + if cfg.Description == "" { + cfg.Description = "Lists available user indexes in the database, excluding system schemas (pg_catalog, information_schema). For each index, the following properties are returned: schema name, table name, index name, index type (access method), a boolean indicating if it's a unique index, a boolean indicating if it's for a primary key, the index definition, index size in bytes, the number of index scans, the number of index tuples read, the number of table tuples fetched via index scans, and a boolean indicating if the index has been used at least once." 
} - mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, allParameters, nil) + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) // finish tool setup return Tool{ Config: cfg, allParams: allParameters, - pool: s.PostgresPool(), manifest: tools.Manifest{ Description: cfg.Description, Parameters: allParameters.Manifest(), @@ -159,7 +139,6 @@ var _ tools.Tool = Tool{} type Tool struct { Config allParams parameters.Parameters `yaml:"allParams"` - pool *pgxpool.Pool manifest tools.Manifest mcpManifest tools.McpManifest } @@ -168,10 +147,21 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - sliceParams := params.AsSlice() +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } - results, err := t.pool.Query(ctx, listIndexesStatement, sliceParams...) + paramsMap := params.AsMap() + + newParams, err := parameters.GetParams(t.allParams, paramsMap) + if err != nil { + return nil, fmt.Errorf("unable to extract standard params %w", err) + } + sliceParams := newParams.AsSlice() + + results, err := source.PostgresPool().Query(ctx, listIndexesStatement, sliceParams...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -216,10 +206,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/postgres/postgreslistinstalledextensions/postgreslistinstalledextensions.go b/internal/tools/postgres/postgreslistinstalledextensions/postgreslistinstalledextensions.go index 2a30799569..effa306f46 100644 --- a/internal/tools/postgres/postgreslistinstalledextensions/postgreslistinstalledextensions.go +++ b/internal/tools/postgres/postgreslistinstalledextensions/postgreslistinstalledextensions.go @@ -20,9 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" - "github.com/googleapis/genai-toolbox/internal/sources/postgres" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" @@ -69,13 +66,6 @@ type compatibleSource interface { PostgresPool() *pgxpool.Pool } -// validate compatible sources are still compatible -var _ compatibleSource = &alloydbpg.Source{} -var _ compatibleSource = &cloudsqlpg.Source{} -var _ compatibleSource = &postgres.Source{} - -var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -92,25 +82,12 @@ func (cfg Config) 
ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - params := parameters.Parameters{} mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil) // finish tool setup t := Tool{ Config: cfg, - Pool: s.PostgresPool(), manifest: tools.Manifest{ Description: cfg.Description, Parameters: params.Manifest(), @@ -126,13 +103,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - Pool *pgxpool.Pool manifest tools.Manifest mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - results, err := t.Pool.Query(ctx, listAvailableExtensionsQuery) +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + + results, err := source.PostgresPool().Query(ctx, listAvailableExtensionsQuery) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -176,14 +157,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/postgres/postgreslistlocks/postgreslistlocks.go b/internal/tools/postgres/postgreslistlocks/postgreslistlocks.go index dc3c26f3a6..881962e2be 100644 --- a/internal/tools/postgres/postgreslistlocks/postgreslistlocks.go +++ b/internal/tools/postgres/postgreslistlocks/postgreslistlocks.go @@ -20,9 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" - "github.com/googleapis/genai-toolbox/internal/sources/postgres" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" @@ -69,13 +66,6 @@ type compatibleSource interface { PostgresPool() *pgxpool.Pool } -// validate compatible sources are still compatible -var _ compatibleSource = &alloydbpg.Source{} -var _ compatibleSource = &cloudsqlpg.Source{} -var _ compatibleSource = &postgres.Source{} - -var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -92,18 +82,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // 
verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters := parameters.Parameters{} paramManifest := allParameters.Manifest() @@ -115,11 +93,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup return Tool{ - name: cfg.Name, - kind: cfg.Kind, - authRequired: cfg.AuthRequired, - allParams: allParameters, - pool: s.PostgresPool(), + Config: cfg, + allParams: allParameters, manifest: tools.Manifest{ Description: cfg.Description, Parameters: paramManifest, @@ -134,20 +109,21 @@ var _ tools.Tool = Tool{} type Tool struct { Config - name string `yaml:"name"` - kind string `yaml:"kind"` - authRequired []string `yaml:"authRequired"` - allParams parameters.Parameters `yaml:"allParams"` - pool *pgxpool.Pool - manifest tools.Manifest - mcpManifest tools.McpManifest + allParams parameters.Parameters `yaml:"allParams"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) @@ -156,7 +132,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } sliceParams := newParams.AsSlice() - results, err := t.pool.Query(ctx, listLocks, sliceParams...) + results, err := source.PostgresPool().Query(ctx, listLocks, sliceParams...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -193,13 +169,13 @@ func (t Tool) McpManifest() tools.McpManifest { } func (t Tool) Authorized(verifiedAuthServices []string) bool { - return tools.IsAuthorized(t.authRequired, verifiedAuthServices) + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/postgres/postgreslistpgsettings/postgreslistpgsettings.go b/internal/tools/postgres/postgreslistpgsettings/postgreslistpgsettings.go new file mode 100644 index 0000000000..05fccc3d6e --- /dev/null +++ b/internal/tools/postgres/postgreslistpgsettings/postgreslistpgsettings.go @@ -0,0 +1,184 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgreslistpgsettings + +import ( + "context" + "fmt" + + yaml "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util/parameters" + "github.com/jackc/pgx/v5/pgxpool" +) + +const kind string = "postgres-list-pg-settings" + +const listPgSettingsStatement = ` + SELECT + name, + setting AS current_value, + unit, + short_desc, + source, + CASE context + WHEN 'postmaster' THEN 'Yes' + WHEN 'sighup' THEN 'No (Reload sufficient)' + ELSE 'No' + END + AS requires_restart + FROM pg_settings + WHERE ($1::text IS NULL OR name LIKE '%' || $1::text || '%') + ORDER BY name + LIMIT COALESCE($2::int, 50); +` + +func init() { + if !tools.Register(kind, newConfig) { + panic(fmt.Sprintf("tool kind %q already registered", kind)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { + actual := Config{Name: name} + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + return actual, nil +} + +type compatibleSource interface { + PostgresPool() *pgxpool.Pool +} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description"` + AuthRequired []string `yaml:"authRequired"` +} + +// validate interface +var _ tools.ToolConfig = Config{} + +func (cfg Config) ToolConfigKind() string { + return kind +} + +func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { + allParameters := parameters.Parameters{ + parameters.NewStringParameterWithDefault("setting_name", "", "Optional: A specific configuration parameter name pattern to search for."), + parameters.NewIntParameterWithDefault("limit", 50, "Optional: The maximum number of rows to return."), + } + description := cfg.Description + if description == "" { + description = "Lists configuration parameters for the postgres server ordered lexicographically, with a default limit of 50 rows. It returns the parameter name, its current setting, unit of measurement, a short description, the source of the current setting (e.g., default, configuration file, session), and whether a restart is required when the parameter value is changed." 
+ } + mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, allParameters, nil) + + // finish tool setup + return Tool{ + Config: cfg, + allParams: allParameters, + manifest: tools.Manifest{ + Description: cfg.Description, + Parameters: allParameters.Manifest(), + AuthRequired: cfg.AuthRequired, + }, + mcpManifest: mcpManifest, + }, nil +} + +// validate interface +var _ tools.Tool = Tool{} + +type Tool struct { + Config + allParams parameters.Parameters `yaml:"allParams"` + manifest tools.Manifest + mcpManifest tools.McpManifest +} + +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() + + newParams, err := parameters.GetParams(t.allParams, paramsMap) + if err != nil { + return nil, fmt.Errorf("unable to extract standard params %w", err) + } + sliceParams := newParams.AsSlice() + + results, err := source.PostgresPool().Query(ctx, listPgSettingsStatement, sliceParams...) + if err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + defer results.Close() + + fields := results.FieldDescriptions() + var out []map[string]any + + for results.Next() { + values, err := results.Values() + if err != nil { + return nil, fmt.Errorf("unable to parse row: %w", err) + } + rowMap := make(map[string]any) + for i, field := range fields { + rowMap[string(field.Name)] = values[i] + } + out = append(out, rowMap) + } + + // this will catch actual query execution errors + if err := results.Err(); err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + + return out, nil +} + +func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { + return parameters.ParseParams(t.allParams, data, claims) +} + +func (t Tool) Manifest() tools.Manifest { + return t.manifest +} + +func (t Tool) McpManifest() tools.McpManifest { + return t.mcpManifest +} + +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +} + +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil +} + +func (t Tool) ToConfig() tools.ToolConfig { + return t.Config +} + +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil +} diff --git a/internal/tools/postgres/postgreslistpgsettings/postgreslistpgsettings_test.go b/internal/tools/postgres/postgreslistpgsettings/postgreslistpgsettings_test.go new file mode 100644 index 0000000000..a2aa9fe78b --- /dev/null +++ b/internal/tools/postgres/postgreslistpgsettings/postgreslistpgsettings_test.go @@ -0,0 +1,95 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package postgreslistpgsettings_test + +import ( + "testing" + + yaml "github.com/goccy/go-yaml" + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistpgsettings" +) + +func TestParseFromYamlPostgreslistPgSettings(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example", + in: ` + tools: + example_tool: + kind: postgres-list-pg-settings + source: my-postgres-instance + description: some description + authRequired: + - my-google-auth-service + - other-auth-service + `, + want: server.ToolConfigs{ + "example_tool": postgreslistpgsettings.Config{ + Name: "example_tool", + Kind: "postgres-list-pg-settings", + Source: "my-postgres-instance", + Description: "some description", + AuthRequired: []string{"my-google-auth-service", "other-auth-service"}, + }, + }, + }, + { + desc: "basic example", + in: ` + tools: + example_tool: + kind: postgres-list-pg-settings + source: my-postgres-instance + description: some description + `, + want: server.ToolConfigs{ + "example_tool": postgreslistpgsettings.Config{ + Name: "example_tool", + Kind: "postgres-list-pg-settings", + Source: "my-postgres-instance", + Description: "some description", + AuthRequired: []string{}, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Tools server.ToolConfigs `yaml:"tools"` + }{} + // Parse contents + err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if diff := cmp.Diff(tc.want, got.Tools); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } + +} diff --git a/internal/tools/postgres/postgreslistpublicationtables/postgreslistpublicationtables.go b/internal/tools/postgres/postgreslistpublicationtables/postgreslistpublicationtables.go new file mode 100644 index 0000000000..9b1d48fdea --- /dev/null +++ b/internal/tools/postgres/postgreslistpublicationtables/postgreslistpublicationtables.go @@ -0,0 +1,197 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package postgreslistpublicationtables + +import ( + "context" + "fmt" + + yaml "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util/parameters" + "github.com/jackc/pgx/v5/pgxpool" +) + +const kind string = "postgres-list-publication-tables" + +const listPublicationTablesStatement = ` + WITH + publication_details AS ( + SELECT + pt.pubname AS publication_name, + pt.schemaname AS schema_name, + pt.tablename AS table_name, + -- Definition details + p.puballtables AS publishes_all_tables, + p.pubinsert AS publishes_inserts, + p.pubupdate AS publishes_updates, + p.pubdelete AS publishes_deletes, + p.pubtruncate AS publishes_truncates, + -- Owner information + pg_catalog.pg_get_userbyid(p.pubowner) AS publication_owner + FROM pg_catalog.pg_publication_tables pt + JOIN pg_catalog.pg_publication p + ON pt.pubname = p.pubname + ) + SELECT * + FROM publication_details + WHERE + (NULLIF(TRIM($1::text), '') IS NULL OR table_name = ANY(regexp_split_to_array(TRIM($1::text), '\s*,\s*'))) + AND (NULLIF(TRIM($2::text), '') IS NULL OR publication_name = ANY(regexp_split_to_array(TRIM($2::text), '\s*,\s*'))) + AND (NULLIF(TRIM($3::text), '') IS NULL OR schema_name = ANY(regexp_split_to_array(TRIM($3::text), '\s*,\s*'))) + ORDER BY + publication_name, schema_name, table_name + LIMIT COALESCE($4::int, 50); +` + +func init() { + if !tools.Register(kind, newConfig) { + panic(fmt.Sprintf("tool kind %q already registered", kind)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { + actual := Config{Name: name} + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + return actual, nil +} + +type compatibleSource interface { + PostgresPool() *pgxpool.Pool +} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description"` + AuthRequired []string `yaml:"authRequired"` +} + +// validate interface +var _ tools.ToolConfig = Config{} + +func (cfg Config) ToolConfigKind() string { + return kind +} + +func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { + allParameters := parameters.Parameters{ + parameters.NewStringParameterWithDefault("table_names", "", "Optional: Filters by a comma-separated list of table names."), + parameters.NewStringParameterWithDefault("publication_names", "", "Optional: Filters by a comma-separated list of publication names."), + parameters.NewStringParameterWithDefault("schema_names", "", "Optional: Filters by a comma-separated list of schema names."), + parameters.NewIntParameterWithDefault("limit", 50, "Optional: The maximum number of rows to return."), + } + description := cfg.Description + if description == "" { + description = "Lists all publication tables in the database. Returns the publication name, schema name, and table name, along with definition details indicating if it publishes all tables, whether it replicates inserts, updates, deletes, or truncates, and the publication owner." 
+ } + mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, allParameters, nil) + + // finish tool setup + return Tool{ + Config: cfg, + allParams: allParameters, + manifest: tools.Manifest{ + Description: cfg.Description, + Parameters: allParameters.Manifest(), + AuthRequired: cfg.AuthRequired, + }, + mcpManifest: mcpManifest, + }, nil +} + +// validate interface +var _ tools.Tool = Tool{} + +type Tool struct { + Config + allParams parameters.Parameters `yaml:"allParams"` + manifest tools.Manifest + mcpManifest tools.McpManifest +} + +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + + paramsMap := params.AsMap() + + newParams, err := parameters.GetParams(t.allParams, paramsMap) + if err != nil { + return nil, fmt.Errorf("unable to extract standard params %w", err) + } + sliceParams := newParams.AsSlice() + + results, err := source.PostgresPool().Query(ctx, listPublicationTablesStatement, sliceParams...) + if err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + defer results.Close() + + fields := results.FieldDescriptions() + var out []map[string]any + + for results.Next() { + values, err := results.Values() + if err != nil { + return nil, fmt.Errorf("unable to parse row: %w", err) + } + rowMap := make(map[string]any) + for i, field := range fields { + rowMap[string(field.Name)] = values[i] + } + out = append(out, rowMap) + } + + if err := results.Err(); err != nil { + return err.Error(), fmt.Errorf("unable to execute query: %w", err) + } + + return out, nil +} + +func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { + return parameters.ParseParams(t.allParams, data, claims) +} + +func (t Tool) Manifest() tools.Manifest { + return t.manifest +} + +func (t Tool) McpManifest() tools.McpManifest { + return t.mcpManifest +} + +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +} + +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil +} + +func (t Tool) ToConfig() tools.ToolConfig { + return t.Config +} + +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil +} diff --git a/internal/tools/postgres/postgreslistpublicationtables/postgreslistpublicationtables_test.go b/internal/tools/postgres/postgreslistpublicationtables/postgreslistpublicationtables_test.go new file mode 100644 index 0000000000..211567c4a8 --- /dev/null +++ b/internal/tools/postgres/postgreslistpublicationtables/postgreslistpublicationtables_test.go @@ -0,0 +1,95 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package postgreslistpublicationtables_test + +import ( + "testing" + + yaml "github.com/goccy/go-yaml" + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistpublicationtables" +) + +func TestParseFromYamlPostgresListPublicationTables(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example", + in: ` + tools: + example_tool: + kind: postgres-list-publication-tables + source: my-postgres-instance + description: some description + authRequired: + - my-google-auth-service + - other-auth-service + `, + want: server.ToolConfigs{ + "example_tool": postgreslistpublicationtables.Config{ + Name: "example_tool", + Kind: "postgres-list-publication-tables", + Source: "my-postgres-instance", + Description: "some description", + AuthRequired: []string{"my-google-auth-service", "other-auth-service"}, + }, + }, + }, + { + desc: "basic example", + in: ` + tools: + example_tool: + kind: postgres-list-publication-tables + source: my-postgres-instance + description: some description + `, + want: server.ToolConfigs{ + "example_tool": postgreslistpublicationtables.Config{ + Name: "example_tool", + Kind: "postgres-list-publication-tables", + Source: "my-postgres-instance", + Description: "some description", + AuthRequired: []string{}, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Tools server.ToolConfigs `yaml:"tools"` + }{} + // Parse contents + err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if diff := cmp.Diff(tc.want, got.Tools); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } + +} diff --git a/internal/tools/postgres/postgreslistquerystats/postgreslistquerystats.go b/internal/tools/postgres/postgreslistquerystats/postgreslistquerystats.go index e7652b0f67..e2a26e496b 100644 --- a/internal/tools/postgres/postgreslistquerystats/postgreslistquerystats.go +++ b/internal/tools/postgres/postgreslistquerystats/postgreslistquerystats.go @@ -20,9 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" - "github.com/googleapis/genai-toolbox/internal/sources/postgres" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" @@ -68,13 +65,6 @@ type compatibleSource interface { PostgresPool() *pgxpool.Pool } -// validate compatible sources are still compatible -var _ compatibleSource = &alloydbpg.Source{} -var _ compatibleSource = &cloudsqlpg.Source{} -var _ compatibleSource = &postgres.Source{} - -var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -91,18 +81,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, 
fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters := parameters.Parameters{ parameters.NewStringParameterWithDefault("database_name", "", "Optional: The database name to list query stats for."), parameters.NewIntParameterWithDefault("limit", 50, "Optional: The maximum number of results to return. Defaults to 50."), @@ -117,11 +95,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup return Tool{ - name: cfg.Name, - kind: cfg.Kind, - authRequired: cfg.AuthRequired, - allParams: allParameters, - pool: s.PostgresPool(), + Config: cfg, + allParams: allParameters, manifest: tools.Manifest{ Description: cfg.Description, Parameters: paramManifest, @@ -136,20 +111,20 @@ var _ tools.Tool = Tool{} type Tool struct { Config - name string `yaml:"name"` - kind string `yaml:"kind"` - authRequired []string `yaml:"authRequired"` - allParams parameters.Parameters `yaml:"allParams"` - pool *pgxpool.Pool - manifest tools.Manifest - mcpManifest tools.McpManifest + allParams parameters.Parameters `yaml:"allParams"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) @@ -158,7 +133,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } sliceParams := newParams.AsSlice() - results, err := t.pool.Query(ctx, listQueryStats, sliceParams...) + results, err := source.PostgresPool().Query(ctx, listQueryStats, sliceParams...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -199,13 +174,13 @@ func (t Tool) McpManifest() tools.McpManifest { } func (t Tool) Authorized(verifiedAuthServices []string) bool { - return tools.IsAuthorized(t.authRequired, verifiedAuthServices) + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/postgres/postgreslistroles/postgreslistroles.go b/internal/tools/postgres/postgreslistroles/postgreslistroles.go new file mode 100644 index 0000000000..160aebb31a --- /dev/null +++ b/internal/tools/postgres/postgreslistroles/postgreslistroles.go @@ -0,0 +1,209 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgreslistroles + +import ( + "context" + "fmt" + + yaml "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util/parameters" + "github.com/jackc/pgx/v5/pgxpool" +) + +const kind string = "postgres-list-roles" + +const listRolesStatement = ` + WITH RoleDetails AS ( + SELECT + r.rolname AS role_name, + r.oid AS oid, + r.rolconnlimit AS connection_limit, + r.rolsuper AS is_superuser, + r.rolinherit AS inherits_privileges, + r.rolcreaterole AS can_create_roles, + r.rolcreatedb AS can_create_db, + r.rolcanlogin AS can_login, + r.rolreplication AS is_replication_role, + r.rolbypassrls AS bypass_rls, + r.rolvaliduntil AS valid_until, + -- List of roles that belong to this role (Direct Members) + ARRAY( + SELECT m_r.rolname + FROM pg_auth_members pam + JOIN pg_roles m_r ON pam.member = m_r.oid + WHERE pam.roleid = r.oid + ) AS direct_members, + -- List of roles that this role belongs to (Member Of) + ARRAY( + SELECT g_r.rolname + FROM pg_auth_members pam + JOIN pg_roles g_r ON pam.roleid = g_r.oid + WHERE pam.member = r.oid + ) AS member_of + FROM pg_roles r + -- Exclude system and internal roles + WHERE r.rolname NOT LIKE 'cloudsql%' + AND r.rolname NOT LIKE 'alloydb_%' + AND r.rolname NOT LIKE 'pg_%' + ) + SELECT * + FROM RoleDetails + WHERE + ($1::text IS NULL OR role_name LIKE '%' || $1 || '%') + ORDER BY role_name + LIMIT COALESCE($2::int, 50); +` + +func init() { + if !tools.Register(kind, newConfig) { + panic(fmt.Sprintf("tool kind %q already registered", kind)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { + actual := Config{Name: name} + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + return actual, nil +} + +type compatibleSource interface { + PostgresPool() *pgxpool.Pool +} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description"` + AuthRequired []string `yaml:"authRequired"` +} + +// validate interface +var _ tools.ToolConfig = Config{} + +func (cfg Config) ToolConfigKind() string { + return kind +} + +func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { + allParameters := parameters.Parameters{ + parameters.NewStringParameterWithDefault("role_name", "", "Optional: a text to filter results by role name. The input is used within a LIKE clause."), + parameters.NewIntParameterWithDefault("limit", 50, "Optional: The maximum number of rows to return. Default is 10"), + } + + description := cfg.Description + if description == "" { + description = "Lists all the user-created roles in the instance . 
It returns the role name, Object ID, the maximum number of concurrent connections the role can make, along with boolean indicators for: superuser status, privilege inheritance from member roles, ability to create roles, ability to create databases, ability to log in, replication privilege, and the ability to bypass row-level security, the password expiration timestamp, a list of direct members belonging to this role, and a list of other roles/groups that this role is a member of." + } + mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, allParameters, nil) + + // finish tool setup + return Tool{ + Config: cfg, + allParams: allParameters, + manifest: tools.Manifest{ + Description: description, + Parameters: allParameters.Manifest(), + AuthRequired: cfg.AuthRequired, + }, + mcpManifest: mcpManifest, + }, nil +} + +// validate interface +var _ tools.Tool = Tool{} + +type Tool struct { + Config + allParams parameters.Parameters `yaml:"allParams"` + manifest tools.Manifest + mcpManifest tools.McpManifest +} + +func (t Tool) ToConfig() tools.ToolConfig { + return t.Config +} + +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + + paramsMap := params.AsMap() + + newParams, err := parameters.GetParams(t.allParams, paramsMap) + if err != nil { + return nil, fmt.Errorf("unable to extract standard params %w", err) + } + sliceParams := newParams.AsSlice() + + results, err := source.PostgresPool().Query(ctx, listRolesStatement, sliceParams...) + if err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + defer results.Close() + + fields := results.FieldDescriptions() + var out []map[string]any + + for results.Next() { + values, err := results.Values() + if err != nil { + return nil, fmt.Errorf("unable to parse row: %w", err) + } + rowMap := make(map[string]any) + for i, field := range fields { + rowMap[string(field.Name)] = values[i] + } + out = append(out, rowMap) + } + + // this will catch actual query execution errors + if err := results.Err(); err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + + return out, nil +} + +func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { + return parameters.ParseParams(t.allParams, data, claims) +} + +func (t Tool) Manifest() tools.Manifest { + return t.manifest +} + +func (t Tool) McpManifest() tools.McpManifest { + return t.mcpManifest +} + +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +} + +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil +} + +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil +} diff --git a/internal/tools/postgres/postgreslistroles/postgreslistroles_test.go b/internal/tools/postgres/postgreslistroles/postgreslistroles_test.go new file mode 100644 index 0000000000..cf4249f6ff --- /dev/null +++ b/internal/tools/postgres/postgreslistroles/postgreslistroles_test.go @@ -0,0 +1,95 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgreslistroles_test + +import ( + "testing" + + yaml "github.com/goccy/go-yaml" + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistroles" +) + +func TestParseFromYamlPostgresListRoles(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example", + in: ` + tools: + example_tool: + kind: postgres-list-roles + source: my-postgres-instance + description: some description + authRequired: + - my-google-auth-service + - other-auth-service + `, + want: server.ToolConfigs{ + "example_tool": postgreslistroles.Config{ + Name: "example_tool", + Kind: "postgres-list-roles", + Source: "my-postgres-instance", + Description: "some description", + AuthRequired: []string{"my-google-auth-service", "other-auth-service"}, + }, + }, + }, + { + desc: "basic example", + in: ` + tools: + example_tool: + kind: postgres-list-roles + source: my-postgres-instance + description: some description + `, + want: server.ToolConfigs{ + "example_tool": postgreslistroles.Config{ + Name: "example_tool", + Kind: "postgres-list-roles", + Source: "my-postgres-instance", + Description: "some description", + AuthRequired: []string{}, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Tools server.ToolConfigs `yaml:"tools"` + }{} + // Parse contents + err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if diff := cmp.Diff(tc.want, got.Tools); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } + +} diff --git a/internal/tools/postgres/postgreslistschemas/postgreslistschemas.go b/internal/tools/postgres/postgreslistschemas/postgreslistschemas.go index 5058a0122e..729a4af1b4 100644 --- a/internal/tools/postgres/postgreslistschemas/postgreslistschemas.go +++ b/internal/tools/postgres/postgreslistschemas/postgreslistschemas.go @@ -20,9 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" - "github.com/googleapis/genai-toolbox/internal/sources/postgres" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" @@ -32,28 +29,28 @@ const kind string = "postgres-list-schemas" const listSchemasStatement = ` WITH - schema_grants AS ( - SELECT schema_oid, jsonb_object_agg(grantee, privileges) AS grants - FROM - ( - SELECT - n.oid AS schema_oid, - CASE - WHEN p.grantee = 0 THEN 'PUBLIC' - ELSE pg_catalog.pg_get_userbyid(p.grantee) - END - AS grantee, - jsonb_agg(p.privilege_type ORDER BY p.privilege_type) AS privileges - FROM pg_catalog.pg_namespace n, 
aclexplode(n.nspacl) p - WHERE n.nspacl IS NOT NULL - GROUP BY n.oid, grantee - ) permissions_by_grantee - GROUP BY schema_oid - ), - all_schemas AS ( - SELECT - n.nspname AS schema_name, - pg_catalog.pg_get_userbyid(n.nspowner) AS owner, + schema_grants AS ( + SELECT schema_oid, jsonb_object_agg(grantee, privileges) AS grants + FROM + ( + SELECT + n.oid AS schema_oid, + CASE + WHEN p.grantee = 0 THEN 'PUBLIC' + ELSE pg_catalog.pg_get_userbyid(p.grantee) + END + AS grantee, + jsonb_agg(p.privilege_type ORDER BY p.privilege_type) AS privileges + FROM pg_catalog.pg_namespace n, aclexplode(n.nspacl) p + WHERE n.nspacl IS NOT NULL + GROUP BY n.oid, grantee + ) permissions_by_grantee + GROUP BY schema_oid + ), + all_schemas AS ( + SELECT + n.nspname AS schema_name, + pg_catalog.pg_get_userbyid(n.nspowner) AS owner, COALESCE(sg.grants, '{}'::jsonb) AS grants, ( SELECT COUNT(*) @@ -67,18 +64,21 @@ const listSchemasStatement = ` ) AS views, (SELECT COUNT(*) FROM pg_catalog.pg_proc p WHERE p.pronamespace = n.oid) AS functions - FROM pg_catalog.pg_namespace n - LEFT JOIN schema_grants sg - ON n.oid = sg.schema_oid - ) + FROM pg_catalog.pg_namespace n + LEFT JOIN schema_grants sg + ON n.oid = sg.schema_oid + ) SELECT * FROM all_schemas - -- Exclude system schemas and temporary schemas created per session. + -- Exclude system schemas and temporary schemas created per session. WHERE - schema_name NOT IN ('pg_catalog', 'information_schema', 'pg_toast') - AND schema_name NOT LIKE 'pg_temp_%' - AND ($1::text IS NULL OR schema_name LIKE '%' || $1::text || '%') - ORDER BY schema_name; + schema_name NOT IN ('pg_catalog', 'information_schema', 'pg_toast') + AND schema_name NOT LIKE 'pg_temp_%' + AND schema_name NOT LIKE 'pg_toast_temp_%' + AND ($1::text IS NULL OR schema_name ILIKE '%' || $1::text || '%') + AND ($2::text IS NULL OR owner ILIKE '%' || $2::text || '%') + ORDER BY schema_name + LIMIT COALESCE($3::int, NULL); ` func init() { @@ -99,13 +99,6 @@ type compatibleSource interface { PostgresPool() *pgxpool.Pool } -// validate compatible sources are still compatible -var _ compatibleSource = &alloydbpg.Source{} -var _ compatibleSource = &cloudsqlpg.Source{} -var _ compatibleSource = &postgres.Source{} - -var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -122,32 +115,21 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters := parameters.Parameters{ parameters.NewStringParameterWithDefault("schema_name", "", "Optional: A specific schema name pattern to search for."), + parameters.NewStringParameterWithDefault("owner", "", "Optional: A specific schema owner name pattern to search for."), + parameters.NewIntParameterWithDefault("limit", 10, "Optional: The maximum number of schemas to return."), } - description := cfg.Description - if description == "" { - description = "Lists all schemas in the database ordered by schema name and excluding system and temporary schemas. 
It returns the schema name, schema owner, grants, number of functions, number of tables and number of views within each schema." + + if cfg.Description == "" { + cfg.Description = "Lists all schemas in the database ordered by schema name and excluding system and temporary schemas. It returns the schema name, schema owner, grants, number of functions, number of tables and number of views within each schema." } - mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, allParameters, nil) + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) // finish tool setup return Tool{ Config: cfg, allParams: allParameters, - pool: s.PostgresPool(), manifest: tools.Manifest{ Description: cfg.Description, Parameters: allParameters.Manifest(), @@ -163,15 +145,25 @@ var _ tools.Tool = Tool{} type Tool struct { Config allParams parameters.Parameters `yaml:"allParams"` - pool *pgxpool.Pool manifest tools.Manifest mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - sliceParams := params.AsSlice() +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } - results, err := t.pool.Query(ctx, listSchemasStatement, sliceParams...) + paramsMap := params.AsMap() + + newParams, err := parameters.GetParams(t.allParams, paramsMap) + if err != nil { + return nil, fmt.Errorf("unable to extract standard params %w", err) + } + sliceParams := newParams.AsSlice() + + results, err := source.PostgresPool().Query(ctx, listSchemasStatement, sliceParams...) 
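+ // The positional arguments map to the declared parameters: $1 = schema_name
+ // pattern, $2 = owner pattern (both matched case-insensitively via ILIKE),
+ // and $3 = the maximum number of schemas to return.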
if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -216,14 +208,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/postgres/postgreslistsequences/postgreslistsequences.go b/internal/tools/postgres/postgreslistsequences/postgreslistsequences.go index 9b42e4f527..a8877ab6f7 100644 --- a/internal/tools/postgres/postgreslistsequences/postgreslistsequences.go +++ b/internal/tools/postgres/postgreslistsequences/postgreslistsequences.go @@ -20,9 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" - "github.com/googleapis/genai-toolbox/internal/sources/postgres" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" @@ -32,9 +29,9 @@ const kind string = "postgres-list-sequences" const listSequencesStatement = ` SELECT - sequencename, - schemaname, - sequenceowner, + sequencename as sequence_name, + schemaname as schema_name, + sequenceowner as sequence_owner, data_type, start_value, min_value, @@ -45,7 +42,7 @@ const listSequencesStatement = ` WHERE ($1::text IS NULL OR schemaname LIKE '%' || $1 || '%') AND ($2::text IS NULL OR sequencename LIKE '%' || $2 || '%') - ORDER BY schemaname, sequencename + ORDER BY schema_name, sequence_name LIMIT COALESCE($3::int, 50); ` @@ -68,13 +65,6 @@ type compatibleSource interface { PostgresPool() *pgxpool.Pool } -// validate compatible sources are still compatible -var _ compatibleSource = &alloydbpg.Source{} -var _ compatibleSource = &cloudsqlpg.Source{} -var _ compatibleSource = &postgres.Source{} - -var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -91,34 +81,21 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters := parameters.Parameters{ - parameters.NewStringParameterWithDefault("schemaname", "", "Optional: A specific schema name pattern to search for."), - parameters.NewStringParameterWithDefault("sequencename", "", "Optional: A specific sequence name pattern to search for."), + parameters.NewStringParameterWithDefault("schema_name", "", "Optional: A specific schema name pattern to search for."), + parameters.NewStringParameterWithDefault("sequence_name", "", "Optional: A specific sequence name pattern to search 
for."), parameters.NewIntParameterWithDefault("limit", 50, "Optional: The maximum number of rows to return. Default is 50"), } - description := cfg.Description - if description == "" { - description = "Lists sequences in the database. Returns sequence name, schema name, sequence owner, data type of the sequence, starting value, minimum value, maximum value of the sequence, the value by which the sequence is incremented, and the last value generated by the sequence in the current session" + + if cfg.Description == "" { + cfg.Description = "Lists sequences in the database. Returns sequence name, schema name, sequence owner, data type of the sequence, starting value, minimum value, maximum value of the sequence, the value by which the sequence is incremented, and the last value generated by the sequence in the current session" } - mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, allParameters, nil) + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) // finish tool setup return Tool{ Config: cfg, allParams: allParameters, - pool: s.PostgresPool(), manifest: tools.Manifest{ Description: cfg.Description, Parameters: allParameters.Manifest(), @@ -134,7 +111,6 @@ var _ tools.Tool = Tool{} type Tool struct { Config allParams parameters.Parameters `yaml:"allParams"` - pool *pgxpool.Pool manifest tools.Manifest mcpManifest tools.McpManifest } @@ -143,10 +119,21 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - sliceParams := params.AsSlice() +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } - results, err := t.pool.Query(ctx, listSequencesStatement, sliceParams...) + paramsMap := params.AsMap() + + newParams, err := parameters.GetParams(t.allParams, paramsMap) + if err != nil { + return nil, fmt.Errorf("unable to extract standard params %w", err) + } + sliceParams := newParams.AsSlice() + + results, err := source.PostgresPool().Query(ctx, listSequencesStatement, sliceParams...) 
if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -191,10 +178,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/postgres/postgreslisttables/postgreslisttables.go b/internal/tools/postgres/postgreslisttables/postgreslisttables.go index 8266d10b7c..264983edb6 100644 --- a/internal/tools/postgres/postgreslisttables/postgreslisttables.go +++ b/internal/tools/postgres/postgreslisttables/postgreslisttables.go @@ -20,9 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" - "github.com/googleapis/genai-toolbox/internal/sources/postgres" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" @@ -50,7 +47,7 @@ const listTablesStatement = ` WHERE t.relkind = ANY(dk.kinds) -- Filter by selected table relkinds ('r', 'p') AND (NULLIF(TRIM($1), '') IS NULL OR t.relname = ANY(string_to_array($1,','))) -- $1 is object_names - AND ns.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast') + AND ns.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast','google_ml') AND ns.nspname NOT LIKE 'pg_temp_%' AND ns.nspname NOT LIKE 'pg_toast_temp_%' ), columns_info AS ( @@ -126,13 +123,6 @@ type compatibleSource interface { PostgresPool() *pgxpool.Pool } -// validate compatible sources are still compatible -var _ compatibleSource = &alloydbpg.Source{} -var _ compatibleSource = &cloudsqlpg.Source{} -var _ compatibleSource = &postgres.Source{} - -var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -149,18 +139,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters := parameters.Parameters{ parameters.NewStringParameterWithDefault("table_names", "", "Optional: A comma-separated list of table names. 
If empty, details for all tables will be listed."), parameters.NewStringParameterWithDefault("output_format", "detailed", "Optional: Use 'simple' for names only or 'detailed' for full info."), @@ -171,7 +149,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, AllParams: allParameters, - Pool: s.PostgresPool(), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -184,14 +161,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters `yaml:"allParams"` - - Pool *pgxpool.Pool + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() tableNames, ok := paramsMap["table_names"].(string) @@ -203,14 +183,14 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("invalid value for output_format: must be 'simple' or 'detailed', but got %q", outputFormat) } - results, err := t.Pool.Query(ctx, listTablesStatement, tableNames, outputFormat) + results, err := source.PostgresPool().Query(ctx, listTablesStatement, tableNames, outputFormat) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } defer results.Close() fields := results.FieldDescriptions() - var out []map[string]any + out := []map[string]any{} for results.Next() { values, err := results.Values() @@ -247,14 +227,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/postgres/postgreslisttablespaces/postgreslisttablespaces.go b/internal/tools/postgres/postgreslisttablespaces/postgreslisttablespaces.go new file mode 100644 index 0000000000..8e2d0e700d --- /dev/null +++ b/internal/tools/postgres/postgreslisttablespaces/postgreslisttablespaces.go @@ -0,0 +1,194 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
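+//
+// A minimal tools.yaml entry for this kind might look like:
+//
+//   tools:
+//     example_tool:
+//       kind: postgres-list-tablespaces
+//       source: my-postgres-instance
+//       description: some description
+//
+// Both parameters are optional: tablespace_name narrows the results with a
+// LIKE pattern and limit caps the number of rows returned (50 by default).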
+ +package postgreslisttablespaces + +import ( + "context" + "fmt" + + yaml "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util/parameters" + "github.com/jackc/pgx/v5/pgxpool" +) + +const kind string = "postgres-list-tablespaces" + +const listTableSpacesStatement = ` + WITH + tablespace_info AS ( + SELECT + spcname AS tablespace_name, + pg_catalog.pg_get_userbyid(spcowner) AS owner_name, + CASE + WHEN pg_catalog.has_tablespace_privilege(oid, 'CREATE') THEN pg_tablespace_size(oid) + ELSE NULL + END AS size_in_bytes, + oid, + spcacl, + spcoptions + FROM + pg_tablespace + ) + SELECT * + FROM + tablespace_info + WHERE + ($1::text IS NULL OR tablespace_name LIKE '%' || $1::text || '%') + ORDER BY + tablespace_name + LIMIT COALESCE($2::int, 50); +` + +func init() { + if !tools.Register(kind, newConfig) { + panic(fmt.Sprintf("tool kind %q already registered", kind)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { + actual := Config{Name: name} + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + return actual, nil +} + +type compatibleSource interface { + PostgresPool() *pgxpool.Pool +} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description"` + AuthRequired []string `yaml:"authRequired"` +} + +// validate interface +var _ tools.ToolConfig = Config{} + +func (cfg Config) ToolConfigKind() string { + return kind +} + +func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { + allParameters := parameters.Parameters{ + parameters.NewStringParameterWithDefault("tablespace_name", "", "Optional: a text to filter results by tablespace name. The input is used within a LIKE clause."), + parameters.NewIntParameterWithDefault("limit", 50, "Optional: The maximum number of rows to return."), + } + description := cfg.Description + if description == "" { + description = "Lists all tablespaces in the database. Returns the tablespace name, owner name, size in bytes(if the current user has CREATE privileges on the tablespace, otherwise NULL), internal object ID, the access control list regarding permissions, and any specific tablespace options." 
+ } + mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, allParameters, nil) + + // finish tool setup + return Tool{ + Config: cfg, + allParams: allParameters, + manifest: tools.Manifest{ + Description: cfg.Description, + Parameters: allParameters.Manifest(), + AuthRequired: cfg.AuthRequired, + }, + mcpManifest: mcpManifest, + }, nil +} + +// validate interface +var _ tools.Tool = Tool{} + +type Tool struct { + Config + allParams parameters.Parameters `yaml:"allParams"` + manifest tools.Manifest + mcpManifest tools.McpManifest +} + +func (t Tool) ToConfig() tools.ToolConfig { + return t.Config +} + +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + + paramsMap := params.AsMap() + + tablespaceName, ok := paramsMap["tablespace_name"].(string) + if !ok { + return nil, fmt.Errorf("invalid 'tablespace_name' parameter; expected a string") + } + limit, ok := paramsMap["limit"].(int) + if !ok { + return nil, fmt.Errorf("invalid 'limit' parameter; expected an integer") + } + + results, err := source.PostgresPool().Query(ctx, listTableSpacesStatement, tablespaceName, limit) + if err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + defer results.Close() + + fields := results.FieldDescriptions() + var out []map[string]any + + for results.Next() { + values, err := results.Values() + if err != nil { + return nil, fmt.Errorf("unable to parse row: %w", err) + } + rowMap := make(map[string]any) + for i, field := range fields { + rowMap[string(field.Name)] = values[i] + } + out = append(out, rowMap) + } + // this will catch actual query execution errors + if err := results.Err(); err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + + return out, nil +} + +func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { + return parameters.ParseParams(t.allParams, data, claims) +} + +func (t Tool) Manifest() tools.Manifest { + return t.manifest +} + +func (t Tool) McpManifest() tools.McpManifest { + return t.mcpManifest +} + +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +} + +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil +} + +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil +} diff --git a/internal/tools/postgres/postgreslisttablespaces/postgreslisttablespaces_test.go b/internal/tools/postgres/postgreslisttablespaces/postgreslisttablespaces_test.go new file mode 100644 index 0000000000..0d28d5abf3 --- /dev/null +++ b/internal/tools/postgres/postgreslisttablespaces/postgreslisttablespaces_test.go @@ -0,0 +1,95 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package postgreslisttablespaces_test + +import ( + "testing" + + yaml "github.com/goccy/go-yaml" + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslisttablespaces" +) + +func TestParseFromYamlPostgresListTablespaces(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example", + in: ` + tools: + example_tool: + kind: postgres-list-tablespaces + source: my-postgres-instance + description: some description + authRequired: + - my-google-auth-service + - other-auth-service + `, + want: server.ToolConfigs{ + "example_tool": postgreslisttablespaces.Config{ + Name: "example_tool", + Kind: "postgres-list-tablespaces", + Source: "my-postgres-instance", + Description: "some description", + AuthRequired: []string{"my-google-auth-service", "other-auth-service"}, + }, + }, + }, + { + desc: "basic example", + in: ` + tools: + example_tool: + kind: postgres-list-tablespaces + source: my-postgres-instance + description: some description + `, + want: server.ToolConfigs{ + "example_tool": postgreslisttablespaces.Config{ + Name: "example_tool", + Kind: "postgres-list-tablespaces", + Source: "my-postgres-instance", + Description: "some description", + AuthRequired: []string{}, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Tools server.ToolConfigs `yaml:"tools"` + }{} + // Parse contents + err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if diff := cmp.Diff(tc.want, got.Tools); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } + +} diff --git a/internal/tools/postgres/postgreslisttablestats/postgreslisttablestats.go b/internal/tools/postgres/postgreslisttablestats/postgreslisttablestats.go new file mode 100644 index 0000000000..69a953e654 --- /dev/null +++ b/internal/tools/postgres/postgreslisttablestats/postgreslisttablestats.go @@ -0,0 +1,221 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
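+//
+// The statement below accepts optional schema, table and owner filters plus a
+// sort_by parameter; supported sort_by values are 'size', 'dead_rows',
+// 'seq_scan' and 'idx_scan', and any other value falls back to ordering by
+// sequential scans. Results are capped by limit (50 by default).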
+ +package postgreslisttablestats + +import ( + "context" + "fmt" + + yaml "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util/parameters" + "github.com/jackc/pgx/v5/pgxpool" +) + +const kind string = "postgres-list-table-stats" + +const listTableStats = ` + WITH table_stats AS ( + SELECT + s.schemaname AS schema_name, + s.relname AS table_name, + pg_catalog.pg_get_userbyid(c.relowner) AS owner, + pg_total_relation_size(s.relid) AS total_size_bytes, + s.seq_scan, + s.idx_scan, + -- Ratio of index scans to total scans + CASE + WHEN (s.seq_scan + s.idx_scan) = 0 THEN 0 + ELSE round((s.idx_scan * 100.0) / (s.seq_scan + s.idx_scan), 2) + END AS idx_scan_ratio_percent, + s.n_live_tup AS live_rows, + s.n_dead_tup AS dead_rows, + -- Percentage of rows that are "dead" (bloat) + CASE + WHEN (s.n_live_tup + s.n_dead_tup) = 0 THEN 0 + ELSE round((s.n_dead_tup * 100.0) / (s.n_live_tup + s.n_dead_tup), 2) + END AS dead_row_ratio_percent, + s.n_tup_ins, + s.n_tup_upd, + s.n_tup_del, + s.last_vacuum, + s.last_autovacuum, + s.last_autoanalyze + FROM pg_stat_all_tables s + JOIN pg_catalog.pg_class c ON s.relid = c.oid + ) + SELECT * + FROM table_stats + WHERE + ($1::text IS NULL OR schema_name LIKE '%' || $1::text || '%') + AND ($2::text IS NULL OR table_name LIKE '%' || $2::text || '%') + AND ($3::text IS NULL OR owner LIKE '%' || $3::text || '%') + ORDER BY + CASE + WHEN $4::text = 'size' THEN total_size_bytes + WHEN $4::text = 'dead_rows' THEN dead_rows + WHEN $4::text = 'seq_scan' THEN seq_scan + WHEN $4::text = 'idx_scan' THEN idx_scan + ELSE seq_scan + END DESC + LIMIT COALESCE($5::int, 50); +` + +func init() { + if !tools.Register(kind, newConfig) { + panic(fmt.Sprintf("tool kind %q already registered", kind)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { + actual := Config{Name: name} + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + return actual, nil +} + +type compatibleSource interface { + PostgresPool() *pgxpool.Pool +} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description"` + AuthRequired []string `yaml:"authRequired"` +} + +// validate interface +var _ tools.ToolConfig = Config{} + +func (cfg Config) ToolConfigKind() string { + return kind +} + +func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { + allParameters := parameters.Parameters{ + parameters.NewStringParameterWithDefault("schema_name", "public", "Optional: A specific schema name to filter by"), + parameters.NewStringParameterWithRequired("table_name", "Optional: A specific table name to filter by", false), + parameters.NewStringParameterWithRequired("owner", "Optional: A specific owner to filter by", false), + parameters.NewStringParameterWithRequired("sort_by", "Optional: The column to sort by", false), + parameters.NewIntParameterWithDefault("limit", 50, "Optional: The maximum number of results to return"), + } + paramManifest := allParameters.Manifest() + + if cfg.Description == "" { + cfg.Description = `Lists the user table statistics in the database ordered by number of + sequential scans with a default limit of 50 rows. 
Returns the following + columns: schema name, table name, table size in bytes, number of + sequential scans, number of index scans, idx_scan_ratio_percent (showing + the percentage of total scans that utilized an index, where a low ratio + indicates missing or ineffective indexes), number of live rows, number + of dead rows, dead_row_ratio_percent (indicating potential table bloat), + total number of rows inserted, updated, and deleted, the timestamps + for the last_vacuum, last_autovacuum, and last_autoanalyze operations.` + } + + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) + + // finish tool setup + return Tool{ + Config: cfg, + allParams: allParameters, + manifest: tools.Manifest{ + Description: cfg.Description, + Parameters: paramManifest, + AuthRequired: cfg.AuthRequired, + }, + mcpManifest: mcpManifest, + }, nil +} + +// validate interface +var _ tools.Tool = Tool{} + +type Tool struct { + Config + allParams parameters.Parameters `yaml:"allParams"` + manifest tools.Manifest + mcpManifest tools.McpManifest +} + +func (t Tool) ToConfig() tools.ToolConfig { + return t.Config +} + +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + + paramsMap := params.AsMap() + + newParams, err := parameters.GetParams(t.allParams, paramsMap) + if err != nil { + return nil, fmt.Errorf("unable to extract standard params %w", err) + } + sliceParams := newParams.AsSlice() + + results, err := source.PostgresPool().Query(ctx, listTableStats, sliceParams...) + if err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + defer results.Close() + + fields := results.FieldDescriptions() + var out []map[string]any + + for results.Next() { + values, err := results.Values() + if err != nil { + return nil, fmt.Errorf("unable to parse row: %w", err) + } + rowMap := make(map[string]any) + for i, field := range fields { + rowMap[string(field.Name)] = values[i] + } + out = append(out, rowMap) + } + + return out, nil +} + +func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { + return parameters.ParseParams(t.allParams, data, claims) +} + +func (t Tool) Manifest() tools.Manifest { + return t.manifest +} + +func (t Tool) McpManifest() tools.McpManifest { + return t.mcpManifest +} + +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +} + +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil +} + +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil +} diff --git a/internal/tools/postgres/postgreslisttablestats/postgreslisttablestats_test.go b/internal/tools/postgres/postgreslisttablestats/postgreslisttablestats_test.go new file mode 100644 index 0000000000..cfaac3eda5 --- /dev/null +++ b/internal/tools/postgres/postgreslisttablestats/postgreslisttablestats_test.go @@ -0,0 +1,95 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgreslisttablestats_test + +import ( + "testing" + + yaml "github.com/goccy/go-yaml" + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslisttablestats" +) + +func TestParseFromYamlPostgresListTableStats(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example", + in: ` + tools: + example_tool: + kind: postgres-list-table-stats + source: my-postgres-instance + description: some description + authRequired: + - my-google-auth-service + - other-auth-service + `, + want: server.ToolConfigs{ + "example_tool": postgreslisttablestats.Config{ + Name: "example_tool", + Kind: "postgres-list-table-stats", + Source: "my-postgres-instance", + Description: "some description", + AuthRequired: []string{"my-google-auth-service", "other-auth-service"}, + }, + }, + }, + { + desc: "basic example", + in: ` + tools: + example_tool: + kind: postgres-list-table-stats + source: my-postgres-instance + description: some description + `, + want: server.ToolConfigs{ + "example_tool": postgreslisttablestats.Config{ + Name: "example_tool", + Kind: "postgres-list-table-stats", + Source: "my-postgres-instance", + Description: "some description", + AuthRequired: []string{}, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Tools server.ToolConfigs `yaml:"tools"` + }{} + // Parse contents + err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if diff := cmp.Diff(tc.want, got.Tools); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } + +} diff --git a/internal/tools/postgres/postgreslisttriggers/postgreslisttriggers.go b/internal/tools/postgres/postgreslisttriggers/postgreslisttriggers.go index f1fe102911..8fc4944f73 100644 --- a/internal/tools/postgres/postgreslisttriggers/postgreslisttriggers.go +++ b/internal/tools/postgres/postgreslisttriggers/postgreslisttriggers.go @@ -20,9 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" - "github.com/googleapis/genai-toolbox/internal/sources/postgres" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" @@ -94,13 +91,6 @@ type compatibleSource interface { PostgresPool() *pgxpool.Pool } -// validate compatible sources are still compatible -var _ compatibleSource = &alloydbpg.Source{} -var _ compatibleSource = &cloudsqlpg.Source{} -var _ compatibleSource = &postgres.Source{} - -var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind} - type 
Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -117,35 +107,22 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters := parameters.Parameters{ parameters.NewStringParameterWithDefault("trigger_name", "", "Optional: A specific trigger name pattern to search for."), parameters.NewStringParameterWithDefault("schema_name", "", "Optional: A specific schema name pattern to search for."), parameters.NewStringParameterWithDefault("table_name", "", "Optional: A specific table name pattern to search for."), parameters.NewIntParameterWithDefault("limit", 50, "Optional: The maximum number of rows to return."), } - description := cfg.Description - if description == "" { - description = "Lists all non-internal triggers in a database. Returns trigger name, schema name, table name, whether its enabled or disabled, timing (e.g BEFORE/AFTER of the event), the events that cause the trigger to fire such as INSERT, UPDATE, or DELETE, whether the trigger activates per ROW or per STATEMENT, the handler function executed by the trigger and full definition." + + if cfg.Description == "" { + cfg.Description = "Lists all non-internal triggers in a database. Returns trigger name, schema name, table name, whether its enabled or disabled, timing (e.g BEFORE/AFTER of the event), the events that cause the trigger to fire such as INSERT, UPDATE, or DELETE, whether the trigger activates per ROW or per STATEMENT, the handler function executed by the trigger and full definition." } - mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, allParameters, nil) + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) // finish tool setup return Tool{ Config: cfg, allParams: allParameters, - pool: s.PostgresPool(), manifest: tools.Manifest{ Description: cfg.Description, Parameters: allParameters.Manifest(), @@ -161,7 +138,6 @@ var _ tools.Tool = Tool{} type Tool struct { Config allParams parameters.Parameters `yaml:"allParams"` - pool *pgxpool.Pool manifest tools.Manifest mcpManifest tools.McpManifest } @@ -170,10 +146,21 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - sliceParams := params.AsSlice() +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } - results, err := t.pool.Query(ctx, listTriggersStatement, sliceParams...) + paramsMap := params.AsMap() + + newParams, err := parameters.GetParams(t.allParams, paramsMap) + if err != nil { + return nil, fmt.Errorf("unable to extract standard params %w", err) + } + sliceParams := newParams.AsSlice() + + results, err := source.PostgresPool().Query(ctx, listTriggersStatement, sliceParams...) 
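+ // The declared parameters (trigger_name, schema_name, table_name, limit) are
+ // passed to listTriggersStatement as positional arguments in that order.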
if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -218,10 +205,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/postgres/postgreslistviews/postgreslistviews.go b/internal/tools/postgres/postgreslistviews/postgreslistviews.go index e4d046589d..d0aa2438d1 100644 --- a/internal/tools/postgres/postgreslistviews/postgreslistviews.go +++ b/internal/tools/postgres/postgreslistviews/postgreslistviews.go @@ -20,9 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" - "github.com/googleapis/genai-toolbox/internal/sources/postgres" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" @@ -31,13 +28,24 @@ import ( const kind string = "postgres-list-views" const listViewsStatement = ` - SELECT schemaname, viewname, viewowner - FROM pg_views - WHERE - schemaname NOT IN ('pg_catalog', 'information_schema') - AND ($1::text IS NULL OR viewname LIKE '%' || $1::text || '%') - ORDER BY viewname - LIMIT COALESCE($2::int, 50); + WITH list_views AS ( + SELECT + schemaname AS schema_name, + viewname AS view_name, + viewowner AS owner_name, + definition + FROM pg_views + ) + SELECT * + FROM list_views + WHERE + schema_name NOT IN ('pg_catalog', 'information_schema', 'pg_toast') + AND schema_name NOT LIKE 'pg_temp_%' + AND ($1::text IS NULL OR view_name ILIKE '%' || $1::text || '%') + AND ($2::text IS NULL OR schema_name ILIKE '%' || $2::text || '%') + ORDER BY + schema_name, view_name + LIMIT COALESCE($3::int, 50); ` func init() { @@ -58,13 +66,6 @@ type compatibleSource interface { PostgresPool() *pgxpool.Pool } -// validate compatible sources are still compatible -var _ compatibleSource = &alloydbpg.Source{} -var _ compatibleSource = &cloudsqlpg.Source{} -var _ compatibleSource = &postgres.Source{} - -var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -81,34 +82,21 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters := parameters.Parameters{ - parameters.NewStringParameterWithDefault("viewname", "", "Optional: A specific view name to search for."), + parameters.NewStringParameterWithDefault("view_name", "", "Optional: A specific view name to search for."), + parameters.NewStringParameterWithDefault("schema_name", "", "Optional: 
A specific schema name to search for."), parameters.NewIntParameterWithDefault("limit", 50, "Optional: The maximum number of rows to return."), } paramManifest := allParameters.Manifest() - description := cfg.Description - if description == "" { - description = "Lists views in the database from pg_views with a default limit of 50 rows. Returns schemaname, viewname and the ownername." + if cfg.Description == "" { + cfg.Description = "Lists views in the database from pg_views with a default limit of 50 rows. Returns schemaname, viewname, ownername and the definition." } - mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, allParameters, nil) + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) // finish tool setup return Tool{ Config: cfg, allParams: allParameters, - pool: s.PostgresPool(), manifest: tools.Manifest{ Description: cfg.Description, Parameters: paramManifest, @@ -124,12 +112,16 @@ var _ tools.Tool = Tool{} type Tool struct { Config allParams parameters.Parameters `yaml:"allParams"` - pool *pgxpool.Pool manifest tools.Manifest mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) @@ -138,7 +130,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } sliceParams := newParams.AsSlice() - results, err := t.pool.Query(ctx, listViewsStatement, sliceParams...) + results, err := source.PostgresPool().Query(ctx, listViewsStatement, sliceParams...) 
if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -183,14 +175,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/postgres/postgreslongrunningtransactions/postgreslongrunningtransactions.go b/internal/tools/postgres/postgreslongrunningtransactions/postgreslongrunningtransactions.go index bba3ecea28..1b2434679d 100644 --- a/internal/tools/postgres/postgreslongrunningtransactions/postgreslongrunningtransactions.go +++ b/internal/tools/postgres/postgreslongrunningtransactions/postgreslongrunningtransactions.go @@ -20,9 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" - "github.com/googleapis/genai-toolbox/internal/sources/postgres" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" @@ -76,13 +73,6 @@ type compatibleSource interface { PostgresPool() *pgxpool.Pool } -// validate compatible sources are still compatible -var _ compatibleSource = &alloydbpg.Source{} -var _ compatibleSource = &cloudsqlpg.Source{} -var _ compatibleSource = &postgres.Source{} - -var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -99,18 +89,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters := parameters.Parameters{ parameters.NewStringParameterWithDefault("min_duration", "5 minutes", "Optional: Only show transactions running at least this long (e.g., '1 minute', '15 minutes', '30 seconds')."), parameters.NewIntParameterWithDefault("limit", 20, "Optional: The maximum number of long-running transactions to return. 
Defaults to 20."), @@ -125,11 +103,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup return Tool{ - name: cfg.Name, - kind: cfg.Kind, - authRequired: cfg.AuthRequired, - allParams: allParameters, - pool: s.PostgresPool(), + Config: cfg, + allParams: allParameters, manifest: tools.Manifest{ Description: cfg.Description, Parameters: paramManifest, @@ -144,20 +119,21 @@ var _ tools.Tool = Tool{} type Tool struct { Config - name string `yaml:"name"` - kind string `yaml:"kind"` - authRequired []string `yaml:"authRequired"` - allParams parameters.Parameters `yaml:"allParams"` - pool *pgxpool.Pool - manifest tools.Manifest - mcpManifest tools.McpManifest + allParams parameters.Parameters `yaml:"allParams"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) @@ -166,7 +142,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } sliceParams := newParams.AsSlice() - results, err := t.pool.Query(ctx, longRunningTransactions, sliceParams...) + results, err := source.PostgresPool().Query(ctx, longRunningTransactions, sliceParams...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -203,13 +179,13 @@ func (t Tool) McpManifest() tools.McpManifest { } func (t Tool) Authorized(verifiedAuthServices []string) bool { - return tools.IsAuthorized(t.authRequired, verifiedAuthServices) + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/postgres/postgresreplicationstats/postgresreplicationstats.go b/internal/tools/postgres/postgresreplicationstats/postgresreplicationstats.go index 569cd339f1..4280f1a0a3 100644 --- a/internal/tools/postgres/postgresreplicationstats/postgresreplicationstats.go +++ b/internal/tools/postgres/postgresreplicationstats/postgresreplicationstats.go @@ -20,9 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" - "github.com/googleapis/genai-toolbox/internal/sources/postgres" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" @@ -66,13 +63,6 @@ type compatibleSource interface { PostgresPool() *pgxpool.Pool } -// validate compatible sources are still compatible -var _ compatibleSource = &alloydbpg.Source{} -var _ compatibleSource = &cloudsqlpg.Source{} -var _ compatibleSource = &postgres.Source{} - -var 
compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -89,18 +79,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters := parameters.Parameters{} paramManifest := allParameters.Manifest() @@ -112,11 +90,8 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // finish tool setup return Tool{ - name: cfg.Name, - kind: cfg.Kind, - authRequired: cfg.AuthRequired, - allParams: allParameters, - pool: s.PostgresPool(), + Config: cfg, + allParams: allParameters, manifest: tools.Manifest{ Description: cfg.Description, Parameters: paramManifest, @@ -131,20 +106,21 @@ var _ tools.Tool = Tool{} type Tool struct { Config - name string `yaml:"name"` - kind string `yaml:"kind"` - authRequired []string `yaml:"authRequired"` - allParams parameters.Parameters `yaml:"allParams"` - pool *pgxpool.Pool - manifest tools.Manifest - mcpManifest tools.McpManifest + allParams parameters.Parameters `yaml:"allParams"` + manifest tools.Manifest + mcpManifest tools.McpManifest } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newParams, err := parameters.GetParams(t.allParams, paramsMap) @@ -153,7 +129,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } sliceParams := newParams.AsSlice() - results, err := t.pool.Query(ctx, replicationStats, sliceParams...) + results, err := source.PostgresPool().Query(ctx, replicationStats, sliceParams...) 
if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -190,13 +166,13 @@ func (t Tool) McpManifest() tools.McpManifest { } func (t Tool) Authorized(verifiedAuthServices []string) bool { - return tools.IsAuthorized(t.authRequired, verifiedAuthServices) + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/postgres/postgressql/postgressql.go b/internal/tools/postgres/postgressql/postgressql.go index 597790c6d1..1de22a5a82 100644 --- a/internal/tools/postgres/postgressql/postgressql.go +++ b/internal/tools/postgres/postgressql/postgressql.go @@ -20,9 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" - "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" - "github.com/googleapis/genai-toolbox/internal/sources/postgres" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" @@ -48,13 +45,6 @@ type compatibleSource interface { PostgresPool() *pgxpool.Pool } -// validate compatible sources are still compatible -var _ compatibleSource = &alloydbpg.Source{} -var _ compatibleSource = &cloudsqlpg.Source{} -var _ compatibleSource = &postgres.Source{} - -var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -74,18 +64,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters) if err != nil { return nil, err @@ -97,7 +75,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, AllParams: allParameters, - Pool: s.PostgresPool(), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -109,14 +86,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters `yaml:"allParams"` - - Pool *pgxpool.Pool + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + 
return nil, err + } + paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { @@ -128,7 +108,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("unable to extract standard params %w", err) } sliceParams := newParams.AsSlice() - results, err := t.Pool.Query(ctx, newStatement, sliceParams...) + results, err := source.PostgresPool().Query(ctx, newStatement, sliceParams...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -172,14 +152,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/redis/redis.go b/internal/tools/redis/redis.go index c8cc21f2a2..6995163a6a 100644 --- a/internal/tools/redis/redis.go +++ b/internal/tools/redis/redis.go @@ -46,11 +46,6 @@ type compatibleSource interface { RedisClient() redissrc.RedisClient } -// validate compatible sources are still compatible -var _ compatibleSource = &redissrc.Source{} - -var compatibleSources = [...]string{redissrc.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -69,24 +64,11 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, cfg.Parameters, nil) // finish tool setup t := Tool{ Config: cfg, - Client: s.RedisClient(), manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -98,13 +80,16 @@ var _ tools.Tool = Tool{} type Tool struct { Config - - Client redissrc.RedisClient manifest tools.Manifest mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + cmds, err := replaceCommandsParams(t.Commands, t.Parameters, params) if err != nil { return nil, fmt.Errorf("error replacing commands' parameters: %s", err) @@ -113,7 +98,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT // Execute commands responses := make([]*redis.Cmd, len(cmds)) for i, cmd := range cmds { - responses[i] = t.Client.Do(ctx, cmd...) 
+		responses[i] = source.RedisClient().Do(ctx, cmd...)
 	}
 	// Parse responses
 	out := make([]any, len(t.Commands))
@@ -165,8 +150,8 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
 	return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
 }
 
-func (t Tool) RequiresClientAuthorization() bool {
-	return false
+func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
+	return false, nil
 }
 
 // replaceCommandsParams is a helper function to replace parameters in the commands
@@ -207,6 +192,6 @@ func (t Tool) ToConfig() tools.ToolConfig {
 	return t.Config
 }
 
-func (t Tool) GetAuthTokenHeaderName() string {
-	return "Authorization"
+func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
+	return "Authorization", nil
 }
diff --git a/internal/tools/serverlessspark/common/urls.go b/internal/tools/serverlessspark/common/urls.go
new file mode 100644
index 0000000000..3b52235992
--- /dev/null
+++ b/internal/tools/serverlessspark/common/urls.go
@@ -0,0 +1,91 @@
+// Copyright 2025 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package common
+
+import (
+	"fmt"
+	"net/url"
+	"regexp"
+	"time"
+
+	"cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
+)
+
+const (
+	logTimeBufferBefore = 1 * time.Minute
+	logTimeBufferAfter  = 10 * time.Minute
+)
+
+var batchFullNameRegex = regexp.MustCompile(`projects/(?P<project>[^/]+)/locations/(?P<location>[^/]+)/batches/(?P<batch>[^/]+)`)
+
+// ExtractBatchDetails extracts the project ID, location, and batch ID from a fully qualified batch name.
+func ExtractBatchDetails(batchName string) (projectID, location, batchID string, err error) {
+	matches := batchFullNameRegex.FindStringSubmatch(batchName)
+	if len(matches) < 4 {
+		return "", "", "", fmt.Errorf("failed to parse batch name: %s", batchName)
+	}
+	return matches[1], matches[2], matches[3], nil
+}
+
+// BatchConsoleURLFromProto builds a URL to the Google Cloud Console linking to the batch summary page.
+func BatchConsoleURLFromProto(batchPb *dataprocpb.Batch) (string, error) {
+	projectID, location, batchID, err := ExtractBatchDetails(batchPb.GetName())
+	if err != nil {
+		return "", err
+	}
+	return BatchConsoleURL(projectID, location, batchID), nil
+}
+
+// BatchLogsURLFromProto builds a URL to the Google Cloud Console showing Cloud Logging for the given batch and time range.
+func BatchLogsURLFromProto(batchPb *dataprocpb.Batch) (string, error) {
+	projectID, location, batchID, err := ExtractBatchDetails(batchPb.GetName())
+	if err != nil {
+		return "", err
+	}
+	createTime := batchPb.GetCreateTime().AsTime()
+	stateTime := batchPb.GetStateTime().AsTime()
+	return BatchLogsURL(projectID, location, batchID, createTime, stateTime), nil
+}
+
+// BatchConsoleURL builds a URL to the Google Cloud Console linking to the batch summary page.
+func BatchConsoleURL(projectID, location, batchID string) string { + return fmt.Sprintf("https://console.cloud.google.com/dataproc/batches/%s/%s/summary?project=%s", location, batchID, projectID) +} + +// BatchLogsURL builds a URL to the Google Cloud Console showing Cloud Logging for the given batch and time range. +// +// The implementation adds some buffer before and after the provided times. +func BatchLogsURL(projectID, location, batchID string, startTime, endTime time.Time) string { + advancedFilterTemplate := `resource.type="cloud_dataproc_batch" +resource.labels.project_id="%s" +resource.labels.location="%s" +resource.labels.batch_id="%s"` + advancedFilter := fmt.Sprintf(advancedFilterTemplate, projectID, location, batchID) + if !startTime.IsZero() { + actualStart := startTime.Add(-1 * logTimeBufferBefore) + advancedFilter += fmt.Sprintf("\ntimestamp>=\"%s\"", actualStart.Format(time.RFC3339Nano)) + } + if !endTime.IsZero() { + actualEnd := endTime.Add(logTimeBufferAfter) + advancedFilter += fmt.Sprintf("\ntimestamp<=\"%s\"", actualEnd.Format(time.RFC3339Nano)) + } + + v := url.Values{} + v.Add("resource", "cloud_dataproc_batch/batch_id/"+batchID) + v.Add("advancedFilter", advancedFilter) + v.Add("project", projectID) + + return "https://console.cloud.google.com/logs/viewer?" + v.Encode() +} diff --git a/internal/tools/serverlessspark/common/urls_test.go b/internal/tools/serverlessspark/common/urls_test.go new file mode 100644 index 0000000000..c8d9e07200 --- /dev/null +++ b/internal/tools/serverlessspark/common/urls_test.go @@ -0,0 +1,119 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package common + +import ( + "testing" + "time" + + "cloud.google.com/go/dataproc/v2/apiv1/dataprocpb" + "google.golang.org/protobuf/types/known/timestamppb" +) + +func TestExtractBatchDetails_Success(t *testing.T) { + batchName := "projects/my-project/locations/us-central1/batches/my-batch" + projectID, location, batchID, err := ExtractBatchDetails(batchName) + if err != nil { + t.Errorf("ExtractBatchDetails() error = %v, want no error", err) + return + } + wantProject := "my-project" + wantLocation := "us-central1" + wantBatchID := "my-batch" + if projectID != wantProject { + t.Errorf("ExtractBatchDetails() projectID = %v, want %v", projectID, wantProject) + } + if location != wantLocation { + t.Errorf("ExtractBatchDetails() location = %v, want %v", location, wantLocation) + } + if batchID != wantBatchID { + t.Errorf("ExtractBatchDetails() batchID = %v, want %v", batchID, wantBatchID) + } +} + +func TestExtractBatchDetails_Failure(t *testing.T) { + batchName := "invalid-name" + _, _, _, err := ExtractBatchDetails(batchName) + wantErr := "failed to parse batch name: invalid-name" + if err == nil || err.Error() != wantErr { + t.Errorf("ExtractBatchDetails() error = %v, want %v", err, wantErr) + } +} + +func TestBatchConsoleURL(t *testing.T) { + got := BatchConsoleURL("my-project", "us-central1", "my-batch") + want := "https://console.cloud.google.com/dataproc/batches/us-central1/my-batch/summary?project=my-project" + if got != want { + t.Errorf("BatchConsoleURL() = %v, want %v", got, want) + } +} + +func TestBatchLogsURL(t *testing.T) { + startTime := time.Date(2025, 10, 1, 5, 0, 0, 0, time.UTC) + endTime := time.Date(2025, 10, 1, 6, 0, 0, 0, time.UTC) + got := BatchLogsURL("my-project", "us-central1", "my-batch", startTime, endTime) + want := "https://console.cloud.google.com/logs/viewer?advancedFilter=" + + "resource.type%3D%22cloud_dataproc_batch%22" + + "%0Aresource.labels.project_id%3D%22my-project%22" + + "%0Aresource.labels.location%3D%22us-central1%22" + + "%0Aresource.labels.batch_id%3D%22my-batch%22" + + "%0Atimestamp%3E%3D%222025-10-01T04%3A59%3A00Z%22" + // Minus 1 minute + "%0Atimestamp%3C%3D%222025-10-01T06%3A10%3A00Z%22" + // Plus 10 minutes + "&project=my-project" + + "&resource=cloud_dataproc_batch%2Fbatch_id%2Fmy-batch" + if got != want { + t.Errorf("BatchLogsURL() = %v, want %v", got, want) + } +} + +func TestBatchConsoleURLFromProto(t *testing.T) { + batchPb := &dataprocpb.Batch{ + Name: "projects/my-project/locations/us-central1/batches/my-batch", + } + got, err := BatchConsoleURLFromProto(batchPb) + if err != nil { + t.Fatalf("BatchConsoleURLFromProto() error = %v", err) + } + want := "https://console.cloud.google.com/dataproc/batches/us-central1/my-batch/summary?project=my-project" + if got != want { + t.Errorf("BatchConsoleURLFromProto() = %v, want %v", got, want) + } +} + +func TestBatchLogsURLFromProto(t *testing.T) { + createTime := time.Date(2025, 10, 1, 5, 0, 0, 0, time.UTC) + stateTime := time.Date(2025, 10, 1, 6, 0, 0, 0, time.UTC) + batchPb := &dataprocpb.Batch{ + Name: "projects/my-project/locations/us-central1/batches/my-batch", + CreateTime: timestamppb.New(createTime), + StateTime: timestamppb.New(stateTime), + } + got, err := BatchLogsURLFromProto(batchPb) + if err != nil { + t.Fatalf("BatchLogsURLFromProto() error = %v", err) + } + want := "https://console.cloud.google.com/logs/viewer?advancedFilter=" + + "resource.type%3D%22cloud_dataproc_batch%22" + + "%0Aresource.labels.project_id%3D%22my-project%22" + + 
"%0Aresource.labels.location%3D%22us-central1%22" + + "%0Aresource.labels.batch_id%3D%22my-batch%22" + + "%0Atimestamp%3E%3D%222025-10-01T04%3A59%3A00Z%22" + // Minus 1 minute + "%0Atimestamp%3C%3D%222025-10-01T06%3A10%3A00Z%22" + // Plus 10 minutes + "&project=my-project" + + "&resource=cloud_dataproc_batch%2Fbatch_id%2Fmy-batch" + if got != want { + t.Errorf("BatchLogsURLFromProto() = %v, want %v", got, want) + } +} diff --git a/internal/tools/serverlessspark/createbatch/config.go b/internal/tools/serverlessspark/createbatch/config.go new file mode 100644 index 0000000000..0bb3575a39 --- /dev/null +++ b/internal/tools/serverlessspark/createbatch/config.go @@ -0,0 +1,99 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package createbatch + +import ( + "context" + "encoding/json" + "fmt" + + dataproc "cloud.google.com/go/dataproc/v2/apiv1" + dataprocpb "cloud.google.com/go/dataproc/v2/apiv1/dataprocpb" + "github.com/goccy/go-yaml" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" +) + +// unmarshalProto is a helper function to unmarshal a generic interface{} into a proto.Message. +func unmarshalProto(data any, m proto.Message) error { + jsonData, err := json.Marshal(data) + if err != nil { + return fmt.Errorf("failed to marshal to JSON: %w", err) + } + return protojson.Unmarshal(jsonData, m) +} + +type compatibleSource interface { + GetBatchControllerClient() *dataproc.BatchControllerClient + GetProject() string + GetLocation() string +} + +// Config is a common config that can be used with any type of create batch tool. However, each tool +// will still need its own config type, embedding this Config, so it can provide a type-specific +// Initialize implementation. +type Config struct { + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description"` + RuntimeConfig *dataprocpb.RuntimeConfig `yaml:"runtimeConfig"` + EnvironmentConfig *dataprocpb.EnvironmentConfig `yaml:"environmentConfig"` + AuthRequired []string `yaml:"authRequired"` +} + +func NewConfig(ctx context.Context, name string, decoder *yaml.Decoder) (Config, error) { + // Use a temporary struct to decode the YAML, so that we can handle the proto + // conversion for RuntimeConfig and EnvironmentConfig. 
+ var ymlCfg struct { + Name string `yaml:"name"` + Kind string `yaml:"kind"` + Source string `yaml:"source"` + Description string `yaml:"description"` + RuntimeConfig any `yaml:"runtimeConfig"` + EnvironmentConfig any `yaml:"environmentConfig"` + AuthRequired []string `yaml:"authRequired"` + } + + if err := decoder.DecodeContext(ctx, &ymlCfg); err != nil { + return Config{}, err + } + + cfg := Config{ + Name: name, + Kind: ymlCfg.Kind, + Source: ymlCfg.Source, + Description: ymlCfg.Description, + AuthRequired: ymlCfg.AuthRequired, + } + + if ymlCfg.RuntimeConfig != nil { + rc := &dataprocpb.RuntimeConfig{} + if err := unmarshalProto(ymlCfg.RuntimeConfig, rc); err != nil { + return Config{}, fmt.Errorf("failed to unmarshal runtimeConfig: %w", err) + } + cfg.RuntimeConfig = rc + } + + if ymlCfg.EnvironmentConfig != nil { + ec := &dataprocpb.EnvironmentConfig{} + if err := unmarshalProto(ymlCfg.EnvironmentConfig, ec); err != nil { + return Config{}, fmt.Errorf("failed to unmarshal environmentConfig: %w", err) + } + cfg.EnvironmentConfig = ec + } + + return cfg, nil +} diff --git a/internal/tools/serverlessspark/createbatch/tool.go b/internal/tools/serverlessspark/createbatch/tool.go new file mode 100644 index 0000000000..66702533da --- /dev/null +++ b/internal/tools/serverlessspark/createbatch/tool.go @@ -0,0 +1,167 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package createbatch + +import ( + "context" + "encoding/json" + "fmt" + "time" + + dataprocpb "cloud.google.com/go/dataproc/v2/apiv1/dataprocpb" + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/common" + "github.com/googleapis/genai-toolbox/internal/util/parameters" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" +) + +type BatchBuilder interface { + Parameters() parameters.Parameters + BuildBatch(params parameters.ParamValues) (*dataprocpb.Batch, error) +} + +func NewTool(cfg Config, originalCfg tools.ToolConfig, srcs map[string]sources.Source, builder BatchBuilder) (*Tool, error) { + desc := cfg.Description + if desc == "" { + desc = fmt.Sprintf("Creates a Serverless Spark (aka Dataproc Serverless) %s operation.", cfg.Kind) + } + + allParameters := builder.Parameters() + inputSchema, _ := allParameters.McpManifest() + + mcpManifest := tools.McpManifest{ + Name: cfg.Name, + Description: desc, + InputSchema: inputSchema, + } + + return &Tool{ + Config: cfg, + originalConfig: originalCfg, + Builder: builder, + manifest: tools.Manifest{Description: desc, Parameters: allParameters.Manifest()}, + mcpManifest: mcpManifest, + Parameters: allParameters, + }, nil +} + +type Tool struct { + Config + originalConfig tools.ToolConfig + Builder BatchBuilder + manifest tools.Manifest + mcpManifest tools.McpManifest + Parameters parameters.Parameters +} + +func (t *Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + client := source.GetBatchControllerClient() + + batch, err := t.Builder.BuildBatch(params) + if err != nil { + return nil, fmt.Errorf("failed to build batch: %w", err) + } + + if t.RuntimeConfig != nil { + batch.RuntimeConfig = proto.Clone(t.RuntimeConfig).(*dataprocpb.RuntimeConfig) + } + + if t.EnvironmentConfig != nil { + batch.EnvironmentConfig = proto.Clone(t.EnvironmentConfig).(*dataprocpb.EnvironmentConfig) + } + + // Common override for version if present in params + paramMap := params.AsMap() + if version, ok := paramMap["version"].(string); ok && version != "" { + if batch.RuntimeConfig == nil { + batch.RuntimeConfig = &dataprocpb.RuntimeConfig{} + } + batch.RuntimeConfig.Version = version + } + + req := &dataprocpb.CreateBatchRequest{ + Parent: fmt.Sprintf("projects/%s/locations/%s", source.GetProject(), source.GetLocation()), + Batch: batch, + } + + op, err := client.CreateBatch(ctx, req) + if err != nil { + return nil, fmt.Errorf("failed to create batch: %w", err) + } + + meta, err := op.Metadata() + if err != nil { + return nil, fmt.Errorf("failed to get create batch op metadata: %w", err) + } + + jsonBytes, err := protojson.Marshal(meta) + if err != nil { + return nil, fmt.Errorf("failed to marshal create batch op metadata to JSON: %w", err) + } + + var result map[string]any + if err := json.Unmarshal(jsonBytes, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal create batch op metadata JSON: %w", err) + } + + projectID, location, batchID, err := common.ExtractBatchDetails(meta.GetBatch()) + if err != nil { + return nil, fmt.Errorf("error extracting batch details from name %q: %v", meta.GetBatch(), err) + } + consoleUrl := common.BatchConsoleURL(projectID, 
location, batchID) + logsUrl := common.BatchLogsURL(projectID, location, batchID, meta.GetCreateTime().AsTime(), time.Time{}) + + wrappedResult := map[string]any{ + "opMetadata": meta, + "consoleUrl": consoleUrl, + "logsUrl": logsUrl, + } + + return wrappedResult, nil +} + +func (t *Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { + return parameters.ParseParams(t.Parameters, data, claims) +} + +func (t *Tool) Manifest() tools.Manifest { + return t.manifest +} + +func (t *Tool) McpManifest() tools.McpManifest { + return t.mcpManifest +} + +func (t *Tool) Authorized(services []string) bool { + return tools.IsAuthorized(t.AuthRequired, services) +} + +func (t *Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil +} + +func (t *Tool) ToConfig() tools.ToolConfig { + return t.originalConfig +} + +func (t *Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil +} diff --git a/internal/tools/serverlessspark/serverlesssparkcancelbatch/serverlesssparkcancelbatch.go b/internal/tools/serverlessspark/serverlesssparkcancelbatch/serverlesssparkcancelbatch.go index a0e36574ae..913a8151e6 100644 --- a/internal/tools/serverlessspark/serverlesssparkcancelbatch/serverlesssparkcancelbatch.go +++ b/internal/tools/serverlessspark/serverlesssparkcancelbatch/serverlesssparkcancelbatch.go @@ -19,10 +19,10 @@ import ( "fmt" "strings" + longrunning "cloud.google.com/go/longrunning/autogen" "cloud.google.com/go/longrunning/autogen/longrunningpb" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/serverlessspark" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -43,6 +43,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + GetOperationsClient(context.Context) (*longrunning.OperationsClient, error) + GetProject() string + GetLocation() string +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -61,16 +67,6 @@ func (cfg Config) ToolConfigKind() string { // Initialize creates a new Tool instance. func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("source %q not found", cfg.Source) - } - - ds, ok := rawS.(*serverlessspark.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, serverlessspark.SourceKind) - } - desc := cfg.Description if desc == "" { desc = "Cancels a running Serverless Spark (aka Dataproc Serverless) batch operation. Note that the batch state will not change immediately after the tool returns; it can take a minute or so for the cancellation to be reflected." @@ -89,7 +85,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return &Tool{ Config: cfg, - Source: ds, manifest: tools.Manifest{Description: desc, Parameters: allParameters.Manifest()}, mcpManifest: mcpManifest, Parameters: allParameters, @@ -99,17 +94,19 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // Tool is the implementation of the tool. 
type Tool struct { Config - - Source *serverlessspark.Source - manifest tools.Manifest mcpManifest tools.McpManifest Parameters parameters.Parameters } // Invoke executes the tool's operation. -func (t *Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - client, err := t.Source.GetOperationsClient(ctx) +func (t *Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + + client, err := source.GetOperationsClient(ctx) if err != nil { return nil, fmt.Errorf("failed to get operations client: %w", err) } @@ -125,7 +122,7 @@ func (t *Tool) Invoke(ctx context.Context, params parameters.ParamValues, access } req := &longrunningpb.CancelOperationRequest{ - Name: fmt.Sprintf("projects/%s/locations/%s/operations/%s", t.Source.Project, t.Source.Location, operation), + Name: fmt.Sprintf("projects/%s/locations/%s/operations/%s", source.GetProject(), source.GetLocation(), operation), } err = client.CancelOperation(ctx, req) @@ -152,15 +149,15 @@ func (t *Tool) Authorized(services []string) bool { return tools.IsAuthorized(t.AuthRequired, services) } -func (t *Tool) RequiresClientAuthorization() bool { +func (t *Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { // Client OAuth not supported, rely on ADCs. - return false + return false, nil } func (t *Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/serverlessspark/serverlesssparkcreatepysparkbatch/serverlesssparkcreatepysparkbatch.go b/internal/tools/serverlessspark/serverlesssparkcreatepysparkbatch/serverlesssparkcreatepysparkbatch.go new file mode 100644 index 0000000000..d07fa01b92 --- /dev/null +++ b/internal/tools/serverlessspark/serverlesssparkcreatepysparkbatch/serverlesssparkcreatepysparkbatch.go @@ -0,0 +1,92 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+package serverlesssparkcreatepysparkbatch
+
+import (
+	"context"
+	"fmt"
+
+	dataproc "cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
+	"github.com/goccy/go-yaml"
+	"github.com/googleapis/genai-toolbox/internal/sources"
+	"github.com/googleapis/genai-toolbox/internal/tools"
+	"github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/createbatch"
+	"github.com/googleapis/genai-toolbox/internal/util/parameters"
+)
+
+const kind = "serverless-spark-create-pyspark-batch"
+
+func init() {
+	if !tools.Register(kind, newConfig) {
+		panic(fmt.Sprintf("tool kind %q already registered", kind))
+	}
+}
+
+func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
+	baseCfg, err := createbatch.NewConfig(ctx, name, decoder)
+	if err != nil {
+		return nil, err
+	}
+	return Config{baseCfg}, nil
+}
+
+type Config struct {
+	createbatch.Config
+}
+
+// validate interface
+var _ tools.ToolConfig = Config{}
+
+// ToolConfigKind returns the unique name for this tool.
+func (cfg Config) ToolConfigKind() string {
+	return kind
+}
+
+// Initialize creates a new Tool instance.
+func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
+	return createbatch.NewTool(cfg.Config, cfg, srcs, &PySparkBatchBuilder{})
+}
+
+type PySparkBatchBuilder struct{}
+
+func (b *PySparkBatchBuilder) Parameters() parameters.Parameters {
+	return parameters.Parameters{
+		parameters.NewStringParameterWithRequired("mainFile", "The path to the main Python file, as a gs://... URI.", true),
+		parameters.NewArrayParameterWithRequired("args", "Optional. A list of arguments passed to the main file.", false, parameters.NewStringParameter("arg", "An argument.")),
+		parameters.NewStringParameterWithRequired("version", "Optional. The Serverless runtime version to execute with.", false),
+	}
+}
+
+func (b *PySparkBatchBuilder) BuildBatch(params parameters.ParamValues) (*dataproc.Batch, error) {
+	paramMap := params.AsMap()
+
+	mainFile := paramMap["mainFile"].(string)
+
+	batch := &dataproc.Batch{
+		BatchConfig: &dataproc.Batch_PysparkBatch{
+			PysparkBatch: &dataproc.PySparkBatch{
+				MainPythonFileUri: mainFile,
+			},
+		},
+	}
+
+	if args, ok := paramMap["args"].([]any); ok {
+		for _, arg := range args {
+			batch.GetPysparkBatch().Args = append(batch.GetPysparkBatch().Args, fmt.Sprintf("%v", arg))
+		}
+	}
+
+	return batch, nil
+}
diff --git a/internal/tools/serverlessspark/serverlesssparkcreatepysparkbatch/serverlesssparkcreatepysparkbatch_test.go b/internal/tools/serverlessspark/serverlesssparkcreatepysparkbatch/serverlesssparkcreatepysparkbatch_test.go
new file mode 100644
index 0000000000..b1081228dc
--- /dev/null
+++ b/internal/tools/serverlessspark/serverlesssparkcreatepysparkbatch/serverlesssparkcreatepysparkbatch_test.go
@@ -0,0 +1,30 @@
+// Copyright 2025 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +package serverlesssparkcreatepysparkbatch_test + +import ( + "testing" + + "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/createbatch" + "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkcreatepysparkbatch" + "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/testutils" +) + +func TestParseFromYaml(t *testing.T) { + testutils.RunParseFromYAMLTests(t, "serverless-spark-create-pyspark-batch", func(c createbatch.Config) tools.ToolConfig { + return serverlesssparkcreatepysparkbatch.Config{Config: c} + }) +} diff --git a/internal/tools/serverlessspark/serverlesssparkcreatesparkbatch/serverlesssparkcreatesparkbatch.go b/internal/tools/serverlessspark/serverlesssparkcreatesparkbatch/serverlesssparkcreatesparkbatch.go new file mode 100644 index 0000000000..e16b1904e2 --- /dev/null +++ b/internal/tools/serverlessspark/serverlesssparkcreatesparkbatch/serverlesssparkcreatesparkbatch.go @@ -0,0 +1,113 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package serverlesssparkcreatesparkbatch + +import ( + "context" + "fmt" + + dataproc "cloud.google.com/go/dataproc/v2/apiv1/dataprocpb" + "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/createbatch" + "github.com/googleapis/genai-toolbox/internal/util/parameters" +) + +const kind = "serverless-spark-create-spark-batch" + +func init() { + if !tools.Register(kind, newConfig) { + panic(fmt.Sprintf("tool kind %q already registered", kind)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { + baseCfg, err := createbatch.NewConfig(ctx, name, decoder) + if err != nil { + return nil, err + } + return Config{baseCfg}, nil +} + +type Config struct { + createbatch.Config +} + +// validate interface +var _ tools.ToolConfig = Config{} + +// ToolConfigKind returns the unique name for this tool. +func (cfg Config) ToolConfigKind() string { + return kind +} + +// Initialize creates a new Tool instance. +func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { + return createbatch.NewTool(cfg.Config, cfg, srcs, &SparkBatchBuilder{}) +} + +type SparkBatchBuilder struct{} + +func (b *SparkBatchBuilder) Parameters() parameters.Parameters { + return parameters.Parameters{ + parameters.NewStringParameterWithRequired("mainJarFile", "Optional. The gs:// URI of the jar file that contains the main class. Exactly one of mainJarFile or mainClass must be specified.", false), + parameters.NewStringParameterWithRequired("mainClass", "Optional. The name of the driver's main class. Exactly one of mainJarFile or mainClass must be specified.", false), + parameters.NewArrayParameterWithRequired("jarFiles", "Optional. 
A list of gs:// URIs of jar files to add to the CLASSPATHs of the Spark driver and tasks.", false, parameters.NewStringParameter("jarFile", "A jar file URI.")), + parameters.NewArrayParameterWithRequired("args", "Optional. A list of arguments passed to the driver.", false, parameters.NewStringParameter("arg", "An argument.")), + parameters.NewStringParameterWithRequired("version", "Optional. The Serverless runtime version to execute with.", false), + } +} + +func (b *SparkBatchBuilder) BuildBatch(params parameters.ParamValues) (*dataproc.Batch, error) { + paramMap := params.AsMap() + + mainJar, _ := paramMap["mainJarFile"].(string) + mainClass, _ := paramMap["mainClass"].(string) + + if mainJar == "" && mainClass == "" { + return nil, fmt.Errorf("must provide either mainJarFile or mainClass") + } + if mainJar != "" && mainClass != "" { + return nil, fmt.Errorf("cannot provide both mainJarFile and mainClass") + } + + sparkBatch := &dataproc.SparkBatch{} + if mainJar != "" { + sparkBatch.Driver = &dataproc.SparkBatch_MainJarFileUri{MainJarFileUri: mainJar} + } else { + sparkBatch.Driver = &dataproc.SparkBatch_MainClass{MainClass: mainClass} + } + + if jarFileUris, ok := paramMap["jarFiles"].([]any); ok { + for _, uri := range jarFileUris { + sparkBatch.JarFileUris = append(sparkBatch.JarFileUris, fmt.Sprintf("%v", uri)) + } + } else if mainClass != "" { + return nil, fmt.Errorf("jarFiles is required when mainClass is provided") + } + + if args, ok := paramMap["args"].([]any); ok { + for _, arg := range args { + sparkBatch.Args = append(sparkBatch.Args, fmt.Sprintf("%v", arg)) + } + } + + return &dataproc.Batch{ + BatchConfig: &dataproc.Batch_SparkBatch{ + SparkBatch: sparkBatch, + }, + }, nil +} diff --git a/internal/tools/serverlessspark/serverlesssparkcreatesparkbatch/serverlesssparkcreatesparkbatch_test.go b/internal/tools/serverlessspark/serverlesssparkcreatesparkbatch/serverlesssparkcreatesparkbatch_test.go new file mode 100644 index 0000000000..92d6b07f14 --- /dev/null +++ b/internal/tools/serverlessspark/serverlesssparkcreatesparkbatch/serverlesssparkcreatesparkbatch_test.go @@ -0,0 +1,30 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package serverlesssparkcreatesparkbatch_test + +import ( + "testing" + + "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/createbatch" + "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/serverlesssparkcreatesparkbatch" + "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/testutils" +) + +func TestParseFromYaml(t *testing.T) { + testutils.RunParseFromYAMLTests(t, "serverless-spark-create-spark-batch", func(c createbatch.Config) tools.ToolConfig { + return serverlesssparkcreatesparkbatch.Config{Config: c} + }) +} diff --git a/internal/tools/serverlessspark/serverlesssparkgetbatch/serverlesssparkgetbatch.go b/internal/tools/serverlessspark/serverlesssparkgetbatch/serverlesssparkgetbatch.go index 4970dc1217..aebec7c9e4 100644 --- a/internal/tools/serverlessspark/serverlesssparkgetbatch/serverlesssparkgetbatch.go +++ b/internal/tools/serverlessspark/serverlesssparkgetbatch/serverlesssparkgetbatch.go @@ -20,11 +20,12 @@ import ( "fmt" "strings" + dataproc "cloud.google.com/go/dataproc/v2/apiv1" "cloud.google.com/go/dataproc/v2/apiv1/dataprocpb" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/serverlessspark" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/common" "github.com/googleapis/genai-toolbox/internal/util/parameters" "google.golang.org/protobuf/encoding/protojson" ) @@ -45,6 +46,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + GetBatchControllerClient() *dataproc.BatchControllerClient + GetProject() string + GetLocation() string +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -63,16 +70,6 @@ func (cfg Config) ToolConfigKind() string { // Initialize creates a new Tool instance. func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("source %q not found", cfg.Source) - } - - ds, ok := rawS.(*serverlessspark.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, serverlessspark.SourceKind) - } - desc := cfg.Description if desc == "" { desc = "Gets a Serverless Spark (aka Dataproc Serverless) batch" @@ -91,7 +88,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, - Source: ds, manifest: tools.Manifest{Description: desc, Parameters: allParameters.Manifest()}, mcpManifest: mcpManifest, Parameters: allParameters, @@ -101,17 +97,19 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // Tool is the implementation of the tool. type Tool struct { Config - - Source *serverlessspark.Source - manifest tools.Manifest mcpManifest tools.McpManifest Parameters parameters.Parameters } // Invoke executes the tool's operation. 
-func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - client := t.Source.GetBatchControllerClient() +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + + client := source.GetBatchControllerClient() paramMap := params.AsMap() name, ok := paramMap["name"].(string) @@ -124,7 +122,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } req := &dataprocpb.GetBatchRequest{ - Name: fmt.Sprintf("projects/%s/locations/%s/batches/%s", t.Source.Project, t.Source.Location, name), + Name: fmt.Sprintf("projects/%s/locations/%s/batches/%s", source.GetProject(), source.GetLocation(), name), } batchPb, err := client.GetBatch(ctx, req) @@ -142,9 +140,23 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("failed to unmarshal batch JSON: %w", err) } - return result, nil -} + consoleUrl, err := common.BatchConsoleURLFromProto(batchPb) + if err != nil { + return nil, fmt.Errorf("error generating console url: %v", err) + } + logsUrl, err := common.BatchLogsURLFromProto(batchPb) + if err != nil { + return nil, fmt.Errorf("error generating logs url: %v", err) + } + wrappedResult := map[string]any{ + "consoleUrl": consoleUrl, + "logsUrl": logsUrl, + "batch": result, + } + + return wrappedResult, nil +} func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { return parameters.ParseParams(t.Parameters, data, claims) } @@ -161,15 +173,15 @@ func (t Tool) Authorized(services []string) bool { return tools.IsAuthorized(t.AuthRequired, services) } -func (t Tool) RequiresClientAuthorization() bool { +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { // Client OAuth not supported, rely on ADCs. 
- return false + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/serverlessspark/serverlesssparklistbatches/serverlesssparklistbatches.go b/internal/tools/serverlessspark/serverlesssparklistbatches/serverlesssparklistbatches.go index 0dddfd8db8..bc8bea2caa 100644 --- a/internal/tools/serverlessspark/serverlesssparklistbatches/serverlesssparklistbatches.go +++ b/internal/tools/serverlessspark/serverlesssparklistbatches/serverlesssparklistbatches.go @@ -19,11 +19,12 @@ import ( "fmt" "time" + dataproc "cloud.google.com/go/dataproc/v2/apiv1" "cloud.google.com/go/dataproc/v2/apiv1/dataprocpb" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/serverlessspark" "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/common" "github.com/googleapis/genai-toolbox/internal/util/parameters" "google.golang.org/api/iterator" ) @@ -44,6 +45,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T return actual, nil } +type compatibleSource interface { + GetBatchControllerClient() *dataproc.BatchControllerClient + GetProject() string + GetLocation() string +} + type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -62,16 +69,6 @@ func (cfg Config) ToolConfigKind() string { // Initialize creates a new Tool instance. func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("source %q not found", cfg.Source) - } - - ds, ok := rawS.(*serverlessspark.Source) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be `%s`", kind, serverlessspark.SourceKind) - } - desc := cfg.Description if desc == "" { desc = "Lists available Serverless Spark (aka Dataproc Serverless) batches" @@ -92,7 +89,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) return Tool{ Config: cfg, - Source: ds, manifest: tools.Manifest{Description: desc, Parameters: allParameters.Manifest()}, mcpManifest: mcpManifest, Parameters: allParameters, @@ -102,9 +98,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) // Tool is the implementation of the tool. type Tool struct { Config - - Source *serverlessspark.Source - manifest tools.Manifest mcpManifest tools.McpManifest Parameters parameters.Parameters @@ -124,13 +117,20 @@ type Batch struct { Creator string `json:"creator"` CreateTime string `json:"createTime"` Operation string `json:"operation"` + ConsoleURL string `json:"consoleUrl"` + LogsURL string `json:"logsUrl"` } // Invoke executes the tool's operation. 
-func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { - client := t.Source.GetBatchControllerClient() +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } - parent := fmt.Sprintf("projects/%s/locations/%s", t.Source.Project, t.Source.Location) + client := source.GetBatchControllerClient() + + parent := fmt.Sprintf("projects/%s/locations/%s", source.GetProject(), source.GetLocation()) req := &dataprocpb.ListBatchesRequest{ Parent: parent, OrderBy: "create_time desc", @@ -159,15 +159,26 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("failed to list batches: %w", err) } - batches := ToBatches(batchPbs) + batches, err := ToBatches(batchPbs) + if err != nil { + return nil, err + } return ListBatchesResponse{Batches: batches, NextPageToken: nextPageToken}, nil } // ToBatches converts a slice of protobuf Batch messages to a slice of Batch structs. -func ToBatches(batchPbs []*dataprocpb.Batch) []Batch { +func ToBatches(batchPbs []*dataprocpb.Batch) ([]Batch, error) { batches := make([]Batch, 0, len(batchPbs)) for _, batchPb := range batchPbs { + consoleUrl, err := common.BatchConsoleURLFromProto(batchPb) + if err != nil { + return nil, fmt.Errorf("error generating console url: %v", err) + } + logsUrl, err := common.BatchLogsURLFromProto(batchPb) + if err != nil { + return nil, fmt.Errorf("error generating logs url: %v", err) + } batch := Batch{ Name: batchPb.Name, UUID: batchPb.Uuid, @@ -175,10 +186,12 @@ func ToBatches(batchPbs []*dataprocpb.Batch) []Batch { Creator: batchPb.Creator, CreateTime: batchPb.CreateTime.AsTime().Format(time.RFC3339), Operation: batchPb.Operation, + ConsoleURL: consoleUrl, + LogsURL: logsUrl, } batches = append(batches, batch) } - return batches + return batches, nil } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { @@ -197,15 +210,15 @@ func (t Tool) Authorized(services []string) bool { return tools.IsAuthorized(t.AuthRequired, services) } -func (t Tool) RequiresClientAuthorization() bool { +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { // Client OAuth not supported, rely on ADCs. - return false + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/serverlessspark/testutils/testutils.go b/internal/tools/serverlessspark/testutils/testutils.go new file mode 100644 index 0000000000..a8a2cfc5a3 --- /dev/null +++ b/internal/tools/serverlessspark/testutils/testutils.go @@ -0,0 +1,149 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testutils + +import ( + "fmt" + "strings" + "testing" + + dataproc "cloud.google.com/go/dataproc/v2/apiv1/dataprocpb" + "github.com/goccy/go-yaml" + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/tools/serverlessspark/createbatch" + "google.golang.org/protobuf/testing/protocmp" +) + +// RunParseFromYAMLTests runs a suite of tests for parsing tool configurations from YAML. +func RunParseFromYAMLTests(t *testing.T, kind string, newConfig func(c createbatch.Config) tools.ToolConfig) { + t.Helper() + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + tcs := []struct { + desc string + in string + want server.ToolConfigs + wantErr string + }{ + { + desc: "basic example", + in: fmt.Sprintf(` + tools: + example_tool: + kind: %s + source: my-instance + description: some description + `, kind), + want: server.ToolConfigs{ + "example_tool": newConfig(createbatch.Config{ + Name: "example_tool", + Kind: kind, + Source: "my-instance", + Description: "some description", + AuthRequired: []string{}, + }), + }, + }, + { + desc: "detailed config", + in: fmt.Sprintf(` + tools: + example_tool: + kind: %s + source: my-instance + description: some description + runtimeConfig: + properties: + "spark.driver.memory": "1g" + environmentConfig: + executionConfig: + networkUri: "my-network" + `, kind), + want: server.ToolConfigs{ + "example_tool": newConfig(createbatch.Config{ + Name: "example_tool", + Kind: kind, + Source: "my-instance", + Description: "some description", + RuntimeConfig: &dataproc.RuntimeConfig{ + Properties: map[string]string{"spark.driver.memory": "1g"}, + }, + EnvironmentConfig: &dataproc.EnvironmentConfig{ + ExecutionConfig: &dataproc.ExecutionConfig{ + Network: &dataproc.ExecutionConfig_NetworkUri{NetworkUri: "my-network"}, + }, + }, + AuthRequired: []string{}, + }), + }, + }, + { + desc: "invalid runtime config", + in: fmt.Sprintf(` + tools: + example_tool: + kind: %s + source: my-instance + description: some description + runtimeConfig: + invalidField: true + `, kind), + wantErr: "unmarshal runtimeConfig", + }, + { + desc: "invalid environment config", + in: fmt.Sprintf(` + tools: + example_tool: + kind: %s + source: my-instance + description: some description + environmentConfig: + invalidField: true + `, kind), + wantErr: "unmarshal environmentConfig", + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Tools server.ToolConfigs `yaml:"tools"` + }{} + err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got, yaml.Strict()) + if tc.wantErr != "" { + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), tc.wantErr) { + t.Fatalf("expected error to contain %q, got %q", tc.wantErr, err) + } + return + } + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + + if diff := cmp.Diff(tc.want, got.Tools, 
protocmp.Transform()); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } +} diff --git a/internal/tools/singlestore/singlestoreexecutesql/singlestoreexecutesql.go b/internal/tools/singlestore/singlestoreexecutesql/singlestoreexecutesql.go index d0a891b815..7ab352b195 100644 --- a/internal/tools/singlestore/singlestoreexecutesql/singlestoreexecutesql.go +++ b/internal/tools/singlestore/singlestoreexecutesql/singlestoreexecutesql.go @@ -21,7 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/singlestore" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon" "github.com/googleapis/genai-toolbox/internal/util" @@ -48,11 +47,6 @@ type compatibleSource interface { SingleStorePool() *sql.DB } -// validate compatible sources are still compatible -var _ compatibleSource = &singlestore.Source{} - -var compatibleSources = [...]string{singlestore.SourceKind} - // Config represents the configuration for the singlestore-execute-sql tool. type Config struct { Name string `yaml:"name" validate:"required"` @@ -72,18 +66,6 @@ func (cfg Config) ToolConfigKind() string { // Initialize sets up the Tool using the provided sources map. func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - sqlParameter := parameters.NewStringParameter("sql", "The sql to execute.") params := parameters.Parameters{sqlParameter} @@ -93,7 +75,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, Parameters: params, - Pool: s.SingleStorePool(), manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -107,7 +88,6 @@ var _ tools.Tool = Tool{} type Tool struct { Config Parameters parameters.Parameters `yaml:"parameters"` - Pool *sql.DB manifest tools.Manifest mcpManifest tools.McpManifest } @@ -117,7 +97,12 @@ func (t Tool) ToConfig() tools.ToolConfig { } // Invoke executes the provided SQL query using the tool's database connection and returns the results. 
-func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() sql, ok := paramsMap["sql"].(string) if !ok { @@ -131,7 +116,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } logger.DebugContext(ctx, "executing `%s` tool query: %s", kind, sql) - results, err := t.Pool.QueryContext(ctx, sql) + results, err := source.SingleStorePool().QueryContext(ctx, sql) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -199,10 +184,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/singlestore/singlestoresql/singlestoresql.go b/internal/tools/singlestore/singlestoresql/singlestoresql.go index 4a7047d27f..55adfe2dbf 100644 --- a/internal/tools/singlestore/singlestoresql/singlestoresql.go +++ b/internal/tools/singlestore/singlestoresql/singlestoresql.go @@ -21,7 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/singlestore" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -47,11 +46,6 @@ type compatibleSource interface { SingleStorePool() *sql.DB } -// validate compatible sources are still compatible -var _ compatibleSource = &singlestore.Source{} - -var compatibleSources = [...]string{singlestore.SourceKind} - // Config defines the configuration for a SingleStore SQL tool. type Config struct { Name string `yaml:"name" validate:"required"` @@ -85,18 +79,6 @@ func (cfg Config) ToolConfigKind() string { // tools.Tool - the initialized tool instance. // error - an error if the source is missing, incompatible, or setup fails. 
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters) if err != nil { return nil, err @@ -108,7 +90,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, AllParams: allParameters, - Pool: s.SingleStorePool(), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -122,7 +103,6 @@ var _ tools.Tool = Tool{} type Tool struct { Config AllParams parameters.Parameters `yaml:"allParams"` - Pool *sql.DB manifest tools.Manifest mcpManifest tools.McpManifest } @@ -145,7 +125,12 @@ func (t Tool) ToConfig() tools.ToolConfig { // Returns: // - A slice of maps, where each map represents a row with column names as keys. // - An error if template resolution, parameter extraction, query execution, or result processing fails. -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { @@ -158,7 +143,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } sliceParams := newParams.AsSlice() - results, err := t.Pool.QueryContext(ctx, newStatement, sliceParams...) + results, err := source.SingleStorePool().QueryContext(ctx, newStatement, sliceParams...) 
if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -226,10 +211,10 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/spanner/spannerexecutesql/spannerexecutesql.go b/internal/tools/spanner/spannerexecutesql/spannerexecutesql.go index 5cd1f1230e..f0c4ce2460 100644 --- a/internal/tools/spanner/spannerexecutesql/spannerexecutesql.go +++ b/internal/tools/spanner/spannerexecutesql/spannerexecutesql.go @@ -21,7 +21,6 @@ import ( "cloud.google.com/go/spanner" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - spannerdb "github.com/googleapis/genai-toolbox/internal/sources/spanner" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/orderedmap" @@ -50,11 +49,6 @@ type compatibleSource interface { DatabaseDialect() string } -// validate compatible sources are still compatible -var _ compatibleSource = &spannerdb.Source{} - -var compatibleSources = [...]string{spannerdb.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -72,18 +66,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - sqlParameter := parameters.NewStringParameter("sql", "The sql to execute.") params := parameters.Parameters{sqlParameter} @@ -93,8 +75,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, Parameters: params, - Client: s.SpannerClient(), - dialect: s.DatabaseDialect(), manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -107,8 +87,6 @@ var _ tools.Tool = Tool{} type Tool struct { Config Parameters parameters.Parameters `yaml:"parameters"` - Client *spanner.Client - dialect string manifest tools.Manifest mcpManifest tools.McpManifest } @@ -137,7 +115,12 @@ func processRows(iter *spanner.RowIterator) ([]any, error) { return out, nil } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() sql, ok := paramsMap["sql"].(string) if !ok { @@ -156,10 +139,10 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT stmt := 
spanner.Statement{SQL: sql} if t.ReadOnly { - iter := t.Client.Single().Query(ctx, stmt) + iter := source.SpannerClient().Single().Query(ctx, stmt) results, opErr = processRows(iter) } else { - _, opErr = t.Client.ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { + _, opErr = source.SpannerClient().ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { var err error iter := txn.Query(ctx, stmt) results, err = processRows(iter) @@ -193,14 +176,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/spanner/spannerlistgraphs/spannerlistgraphs.go b/internal/tools/spanner/spannerlistgraphs/spannerlistgraphs.go index 434fe93d18..b9e94408e2 100644 --- a/internal/tools/spanner/spannerlistgraphs/spannerlistgraphs.go +++ b/internal/tools/spanner/spannerlistgraphs/spannerlistgraphs.go @@ -23,7 +23,6 @@ import ( "cloud.google.com/go/spanner" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - spannerdb "github.com/googleapis/genai-toolbox/internal/sources/spanner" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "google.golang.org/api/iterator" @@ -50,11 +49,6 @@ type compatibleSource interface { DatabaseDialect() string } -// validate compatible sources are still compatible -var _ compatibleSource = &spannerdb.Source{} - -var compatibleSources = [...]string{spannerdb.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -71,23 +65,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - - // verify the dialect is GoogleSQL - if strings.ToLower(s.DatabaseDialect()) != "googlesql" { - return nil, fmt.Errorf("invalid source dialect for %q tool: source dialect must be GoogleSQL", kind) - } - // Define parameters for the tool allParameters := parameters.Parameters{ parameters.NewStringParameterWithDefault( @@ -104,7 +81,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) description := cfg.Description if description == "" { - description = "Lists detailed graph schema information (node tables, edge tables, labels and property declarations) as JSON for user-created graphs. Filters by a comma-separated list of names. If names are omitted, lists all graphs in user schemas." + description = "Lists detailed graph schema information (node tables, edge tables, labels and property declarations) as JSON for user-created graphs. 
Filters by a comma-separated list of graph names. If names are omitted, lists all graphs. The output can be 'simple' (graph names only) or 'detailed' (full schema)." } mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, allParameters, nil) @@ -112,8 +89,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, AllParams: allParameters, - Client: s.SpannerClient(), - dialect: s.DatabaseDialect(), manifest: tools.Manifest{Description: description, Parameters: allParameters.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -126,8 +101,6 @@ var _ tools.Tool = Tool{} type Tool struct { Config AllParams parameters.Parameters `yaml:"allParams"` - Client *spanner.Client - dialect string manifest tools.Manifest mcpManifest tools.McpManifest } @@ -165,7 +138,17 @@ func processRows(iter *spanner.RowIterator) ([]any, error) { return out, nil } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + + // Check dialect here at RUNTIME instead of startup + if strings.ToLower(source.DatabaseDialect()) != "googlesql" { + return nil, fmt.Errorf("operation not supported: The 'spanner-list-graphs' tool is only available for GoogleSQL dialect databases. Your current database dialect is '%s'", source.DatabaseDialect()) + } + paramsMap := params.AsMap() graphNames, _ := paramsMap["graph_names"].(string) @@ -185,7 +168,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } // Execute the query (read-only) - iter := t.Client.Single().Query(ctx, stmt) + iter := source.SpannerClient().Single().Query(ctx, stmt) results, err := processRows(iter) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) @@ -210,16 +193,16 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } // GoogleSQL statement for listing graphs diff --git a/internal/tools/spanner/spannerlisttables/spannerlisttables.go b/internal/tools/spanner/spannerlisttables/spannerlisttables.go index 24c054e454..bd41479fed 100644 --- a/internal/tools/spanner/spannerlisttables/spannerlisttables.go +++ b/internal/tools/spanner/spannerlisttables/spannerlisttables.go @@ -23,7 +23,6 @@ import ( "cloud.google.com/go/spanner" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - spannerdb "github.com/googleapis/genai-toolbox/internal/sources/spanner" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "google.golang.org/api/iterator" @@ -50,11 +49,6 @@ type compatibleSource interface { DatabaseDialect() string } -// validate compatible sources are still 
compatible -var _ compatibleSource = &spannerdb.Source{} - -var compatibleSources = [...]string{spannerdb.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -71,18 +65,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - // Define parameters for the tool allParameters := parameters.Parameters{ parameters.NewStringParameterWithDefault( @@ -99,7 +81,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) description := cfg.Description if description == "" { - description = "Lists detailed schema information (object type, columns, constraints, indexes) as JSON for user-created tables. Filters by a comma-separated list of names. If names are omitted, lists all tables in user schemas." + description = "Lists detailed schema information (object type, columns, constraints, indexes) as JSON for user-created tables (ordinary or partitioned). Filters by a comma-separated list of names. If names are omitted, lists all tables in user schemas. The output can be 'simple' (table names only) or 'detailed' (full schema)." } mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, allParameters, nil) @@ -107,8 +89,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, AllParams: allParameters, - Client: s.SpannerClient(), - dialect: s.DatabaseDialect(), manifest: tools.Manifest{Description: description, Parameters: allParameters.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -121,15 +101,13 @@ var _ tools.Tool = Tool{} type Tool struct { Config AllParams parameters.Parameters `yaml:"allParams"` - Client *spanner.Client - dialect string manifest tools.Manifest mcpManifest tools.McpManifest } // processRows iterates over the spanner.RowIterator and converts each row to a map[string]any. 
func processRows(iter *spanner.RowIterator) ([]any, error) { - var out []any + out := []any{} defer iter.Stop() for { @@ -160,8 +138,8 @@ func processRows(iter *spanner.RowIterator) ([]any, error) { return out, nil } -func (t Tool) getStatement() string { - switch strings.ToLower(t.dialect) { +func (t Tool) getStatement(source compatibleSource) string { + switch strings.ToLower(source.DatabaseDialect()) { case "postgresql": return postgresqlStatement case "googlesql": @@ -172,11 +150,16 @@ func (t Tool) getStatement() string { } } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() // Get the appropriate SQL statement based on dialect - statement := t.getStatement() + statement := t.getStatement(source) // Prepare parameters based on dialect var stmtParams map[string]interface{} @@ -187,7 +170,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT outputFormat = "detailed" } - switch strings.ToLower(t.dialect) { + switch strings.ToLower(source.DatabaseDialect()) { case "postgresql": // PostgreSQL uses positional parameters ($1, $2) stmtParams = map[string]interface{}{ @@ -202,7 +185,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT "output_format": outputFormat, } default: - return nil, fmt.Errorf("unsupported dialect: %s", t.dialect) + return nil, fmt.Errorf("unsupported dialect: %s", source.DatabaseDialect()) } stmt := spanner.Statement{ @@ -211,7 +194,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } // Execute the query (read-only) - iter := t.Client.Single().Query(ctx, stmt) + iter := source.SpannerClient().Single().Query(ctx, stmt) results, err := processRows(iter) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) @@ -236,16 +219,16 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } // PostgreSQL statement for listing tables diff --git a/internal/tools/spanner/spannersql/spannersql.go b/internal/tools/spanner/spannersql/spannersql.go index 9c5deebc0e..d1b7c1ab54 100644 --- a/internal/tools/spanner/spannersql/spannersql.go +++ b/internal/tools/spanner/spannersql/spannersql.go @@ -22,7 +22,6 @@ import ( "cloud.google.com/go/spanner" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - spannerdb "github.com/googleapis/genai-toolbox/internal/sources/spanner" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "google.golang.org/api/iterator" @@ -49,11 +48,6 @@ type compatibleSource interface { DatabaseDialect() string } -// validate compatible sources are still 
compatible -var _ compatibleSource = &spannerdb.Source{} - -var compatibleSources = [...]string{spannerdb.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -74,18 +68,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters) if err != nil { return nil, err @@ -97,8 +79,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, AllParams: allParameters, - Client: s.SpannerClient(), - dialect: s.DatabaseDialect(), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -111,8 +91,6 @@ var _ tools.Tool = Tool{} type Tool struct { Config AllParams parameters.Parameters `yaml:"allParams"` - Client *spanner.Client - dialect string manifest tools.Manifest mcpManifest tools.McpManifest } @@ -152,7 +130,12 @@ func processRows(iter *spanner.RowIterator) ([]any, error) { return out, nil } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { @@ -187,7 +170,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT newParams[i] = parameters.ParamValue{Name: name, Value: value} } - mapParams, err := getMapParams(newParams, t.dialect) + mapParams, err := getMapParams(newParams, source.DatabaseDialect()) if err != nil { return nil, fmt.Errorf("fail to get map params: %w", err) } @@ -200,10 +183,10 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } if t.ReadOnly { - iter := t.Client.Single().Query(ctx, stmt) + iter := source.SpannerClient().Single().Query(ctx, stmt) results, opErr = processRows(iter) } else { - _, opErr = t.Client.ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { + _, opErr = source.SpannerClient().ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { iter := txn.Query(ctx, stmt) results, err = processRows(iter) if err != nil { @@ -236,14 +219,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) 
GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/sqlite/sqliteexecutesql/sqliteexecutesql.go b/internal/tools/sqlite/sqliteexecutesql/sqliteexecutesql.go index 113d1e6d84..e2c03a224a 100644 --- a/internal/tools/sqlite/sqliteexecutesql/sqliteexecutesql.go +++ b/internal/tools/sqlite/sqliteexecutesql/sqliteexecutesql.go @@ -22,7 +22,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/sqlite" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/orderedmap" @@ -49,11 +48,6 @@ type compatibleSource interface { SQLiteDB() *sql.DB } -// validate compatible sources are still compatible -var _ compatibleSource = &sqlite.Source{} - -var compatibleSources = [...]string{sqlite.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -70,18 +64,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - sqlParameter := parameters.NewStringParameter("sql", "The sql to execute.") params := parameters.Parameters{sqlParameter} mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil) @@ -90,7 +72,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, Parameters: params, - DB: s.SQLiteDB(), manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -102,14 +83,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - Parameters parameters.Parameters `yaml:"parameters"` - - DB *sql.DB + Parameters parameters.Parameters `yaml:"parameters"` manifest tools.Manifest mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + sql, ok := params.AsMap()["sql"].(string) if !ok { return nil, fmt.Errorf("missing or invalid 'sql' parameter") @@ -125,7 +109,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sql)) - results, err := t.DB.QueryContext(ctx, sql) + results, err := source.SQLiteDB().QueryContext(ctx, sql) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -201,14 +185,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr 
tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/sqlite/sqliteexecutesql/sqliteexecutesql_test.go b/internal/tools/sqlite/sqliteexecutesql/sqliteexecutesql_test.go index da4a0a7139..63079a883e 100644 --- a/internal/tools/sqlite/sqliteexecutesql/sqliteexecutesql_test.go +++ b/internal/tools/sqlite/sqliteexecutesql/sqliteexecutesql_test.go @@ -15,19 +15,13 @@ package sqliteexecutesql_test import ( - "context" - "database/sql" - "reflect" "testing" yaml "github.com/goccy/go-yaml" "github.com/google/go-cmp/cmp" "github.com/googleapis/genai-toolbox/internal/server" "github.com/googleapis/genai-toolbox/internal/testutils" - "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/sqlite/sqliteexecutesql" - "github.com/googleapis/genai-toolbox/internal/util/orderedmap" - "github.com/googleapis/genai-toolbox/internal/util/parameters" _ "modernc.org/sqlite" ) @@ -81,251 +75,3 @@ func TestParseFromYamlExecuteSql(t *testing.T) { } } - -func setupTestDB(t *testing.T) *sql.DB { - db, err := sql.Open("sqlite", ":memory:") - if err != nil { - t.Fatalf("Failed to open in-memory database: %v", err) - } - return db -} - -func TestTool_Invoke(t *testing.T) { - ctx, err := testutils.ContextWithNewLogger() - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - - type fields struct { - Name string - Kind string - AuthRequired []string - Parameters parameters.Parameters - DB *sql.DB - } - type args struct { - ctx context.Context - params parameters.ParamValues - accessToken tools.AccessToken - } - tests := []struct { - name string - fields fields - args args - want any - wantErr bool - }{ - { - name: "create table", - fields: fields{ - DB: setupTestDB(t), - }, - args: args{ - ctx: ctx, - params: []parameters.ParamValue{ - {Name: "sql", Value: "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, age INTEGER)"}, - }, - }, - want: nil, - wantErr: false, - }, - { - name: "insert data", - fields: fields{ - DB: setupTestDB(t), - }, - args: args{ - ctx: ctx, - params: []parameters.ParamValue{ - {Name: "sql", Value: "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, age INTEGER); INSERT INTO users (id, name, age) VALUES (1, 'Alice', 30), (2, 'Bob', 25)"}, - }, - }, - want: nil, - wantErr: false, - }, - { - name: "select data", - fields: fields{ - DB: func() *sql.DB { - db := setupTestDB(t) - if _, err := db.Exec("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, age INTEGER); INSERT INTO users (id, name, age) VALUES (1, 'Alice', 30), (2, 'Bob', 25)"); err != nil { - t.Fatalf("Failed to set up database for select: %v", err) - } - return db - }(), - }, - args: args{ - ctx: ctx, - params: []parameters.ParamValue{ - {Name: "sql", Value: "SELECT * FROM users"}, - }, - }, - want: []any{ - orderedmap.Row{ - Columns: []orderedmap.Column{ - {Name: "id", Value: int64(1)}, - {Name: "name", Value: "Alice"}, - {Name: "age", Value: int64(30)}, - }, - }, - orderedmap.Row{ - Columns: []orderedmap.Column{ - {Name: "id", Value: int64(2)}, - {Name: "name", Value: "Bob"}, - {Name: "age", Value: int64(25)}, - }, - }, - }, - wantErr: false, - }, - { - name: "drop table", - fields: fields{ - DB: func() *sql.DB { - db := setupTestDB(t) - if _, err := db.Exec("CREATE TABLE users (id 
INTEGER PRIMARY KEY, name TEXT, age INTEGER)"); err != nil { - t.Fatalf("Failed to set up database for drop: %v", err) - } - return db - }(), - }, - args: args{ - ctx: ctx, - params: []parameters.ParamValue{ - {Name: "sql", Value: "DROP TABLE users"}, - }, - }, - want: nil, - wantErr: false, - }, - { - name: "invalid sql", - fields: fields{ - DB: setupTestDB(t), - }, - args: args{ - ctx: ctx, - params: []parameters.ParamValue{ - {Name: "sql", Value: "SELECT * FROM non_existent_table"}, - }, - }, - want: nil, - wantErr: true, - }, - { - name: "empty sql", - fields: fields{ - DB: setupTestDB(t), - }, - args: args{ - ctx: ctx, - params: []parameters.ParamValue{ - {Name: "sql", Value: ""}, - }, - }, - want: nil, - wantErr: true, - }, - { - name: "data types", - fields: fields{ - DB: func() *sql.DB { - db := setupTestDB(t) - if _, err := db.Exec("CREATE TABLE data_types (id INTEGER PRIMARY KEY, null_col TEXT, blob_col BLOB)"); err != nil { - t.Fatalf("Failed to set up database for data types: %v", err) - } - if _, err := db.Exec("INSERT INTO data_types (id, null_col, blob_col) VALUES (1, NULL, ?)", []byte{1, 2, 3}); err != nil { - t.Fatalf("Failed to insert data for data types: %v", err) - } - return db - }(), - }, - args: args{ - ctx: ctx, - params: []parameters.ParamValue{ - {Name: "sql", Value: "SELECT * FROM data_types"}, - }, - }, - want: []any{ - orderedmap.Row{ - Columns: []orderedmap.Column{ - {Name: "id", Value: int64(1)}, - {Name: "null_col", Value: nil}, - {Name: "blob_col", Value: []byte{1, 2, 3}}, - }, - }, - }, - wantErr: false, - }, - { - name: "join operation", - fields: fields{ - DB: func() *sql.DB { - db := setupTestDB(t) - if _, err := db.Exec("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, age INTEGER)"); err != nil { - t.Fatalf("Failed to set up database for join: %v", err) - } - if _, err := db.Exec("INSERT INTO users (id, name, age) VALUES (1, 'Alice', 30), (2, 'Bob', 25)"); err != nil { - t.Fatalf("Failed to insert data for join: %v", err) - } - if _, err := db.Exec("CREATE TABLE orders (id INTEGER PRIMARY KEY, user_id INTEGER, item TEXT)"); err != nil { - t.Fatalf("Failed to set up database for join: %v", err) - } - if _, err := db.Exec("INSERT INTO orders (id, user_id, item) VALUES (1, 1, 'Laptop'), (2, 2, 'Keyboard')"); err != nil { - t.Fatalf("Failed to insert data for join: %v", err) - } - return db - }(), - }, - args: args{ - ctx: ctx, - params: []parameters.ParamValue{ - {Name: "sql", Value: "SELECT u.name, o.item FROM users u JOIN orders o ON u.id = o.user_id"}, - }, - }, - want: []any{ - orderedmap.Row{ - Columns: []orderedmap.Column{ - {Name: "name", Value: "Alice"}, - {Name: "item", Value: "Laptop"}, - }, - }, - orderedmap.Row{ - Columns: []orderedmap.Column{ - {Name: "name", Value: "Bob"}, - {Name: "item", Value: "Keyboard"}, - }, - }, - }, - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - tr := sqliteexecutesql.Tool{ - Config: sqliteexecutesql.Config{ - Name: tt.fields.Name, - Kind: tt.fields.Kind, - AuthRequired: tt.fields.AuthRequired, - }, - Parameters: tt.fields.Parameters, - DB: tt.fields.DB, - } - got, err := tr.Invoke(tt.args.ctx, tt.args.params, tt.args.accessToken) - if (err != nil) != tt.wantErr { - t.Errorf("Tool.Invoke() error = %v, wantErr %v", err, tt.wantErr) - return - } - isEqual := false - if got != nil && len(got.([]any)) == 0 && len(tt.want.([]any)) == 0 { - isEqual = true // Special case for empty slices, since DeepEqual returns false - } else { - isEqual = reflect.DeepEqual(got, 
tt.want) - } - - if !isEqual { - t.Errorf("Tool.Invoke() = %+v, want %v", got, tt.want) - } - }) - } -} diff --git a/internal/tools/sqlite/sqlitesql/sqlitesql.go b/internal/tools/sqlite/sqlitesql/sqlitesql.go index 39f2e5fb68..e715252dc4 100644 --- a/internal/tools/sqlite/sqlitesql/sqlitesql.go +++ b/internal/tools/sqlite/sqlitesql/sqlitesql.go @@ -22,7 +22,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/sqlite" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -47,11 +46,6 @@ type compatibleSource interface { SQLiteDB() *sql.DB } -// validate compatible sources are still compatible -var _ compatibleSource = &sqlite.Source{} - -var compatibleSources = [...]string{sqlite.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -71,18 +65,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters) if err != nil { return nil, err @@ -94,7 +76,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, AllParams: allParameters, - Db: s.SQLiteDB(), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -106,14 +87,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters `yaml:"allParams"` - - Db *sql.DB + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { @@ -126,7 +110,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } // Execute the SQL query with parameters - rows, err := t.Db.QueryContext(ctx, newStatement, newParams.AsSlice()...) + rows, err := source.SQLiteDB().QueryContext(ctx, newStatement, newParams.AsSlice()...) 
if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -200,14 +184,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/sqlite/sqlitesql/sqlitesql_test.go b/internal/tools/sqlite/sqlitesql/sqlitesql_test.go index 47befc6d57..eea6fddf4f 100644 --- a/internal/tools/sqlite/sqlitesql/sqlitesql_test.go +++ b/internal/tools/sqlite/sqlitesql/sqlitesql_test.go @@ -15,16 +15,12 @@ package sqlitesql_test import ( - "context" - "database/sql" - "reflect" "testing" yaml "github.com/goccy/go-yaml" "github.com/google/go-cmp/cmp" "github.com/googleapis/genai-toolbox/internal/server" "github.com/googleapis/genai-toolbox/internal/testutils" - "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/tools/sqlite/sqlitesql" "github.com/googleapis/genai-toolbox/internal/util/parameters" _ "modernc.org/sqlite" @@ -179,148 +175,3 @@ func TestParseFromYamlWithTemplateSqlite(t *testing.T) { }) } } - -func setupTestDB(t *testing.T) *sql.DB { - db, err := sql.Open("sqlite", ":memory:") - if err != nil { - t.Fatalf("Failed to open in-memory database: %v", err) - } - - createTable := ` - CREATE TABLE users ( - id INTEGER PRIMARY KEY, - name TEXT, - age INTEGER - );` - if _, err := db.Exec(createTable); err != nil { - t.Fatalf("Failed to create table: %v", err) - } - - insertData := ` - INSERT INTO users (id, name, age) VALUES - (1, 'Alice', 30), - (2, 'Bob', 25);` - if _, err := db.Exec(insertData); err != nil { - t.Fatalf("Failed to insert data: %v", err) - } - - return db -} - -func TestTool_Invoke(t *testing.T) { - type fields struct { - Name string - Kind string - AuthRequired []string - Parameters parameters.Parameters - TemplateParameters parameters.Parameters - AllParams parameters.Parameters - Db *sql.DB - Statement string - } - type args struct { - ctx context.Context - params parameters.ParamValues - accessToken tools.AccessToken - } - tests := []struct { - name string - fields fields - args args - want any - wantErr bool - }{ - { - name: "simple select", - fields: fields{ - Db: setupTestDB(t), - Statement: "SELECT * FROM users", - }, - args: args{ - ctx: context.Background(), - }, - want: []any{ - map[string]any{"id": int64(1), "name": "Alice", "age": int64(30)}, - map[string]any{"id": int64(2), "name": "Bob", "age": int64(25)}, - }, - wantErr: false, - }, - { - name: "select with parameter", - fields: fields{ - Db: setupTestDB(t), - Statement: "SELECT * FROM users WHERE name = ?", - Parameters: []parameters.Parameter{ - parameters.NewStringParameter("name", "user name"), - }, - }, - args: args{ - ctx: context.Background(), - params: []parameters.ParamValue{ - {Name: "name", Value: "Alice"}, - }, - }, - want: []any{ - map[string]any{"id": int64(1), "name": "Alice", "age": int64(30)}, - }, - wantErr: false, - }, - { - name: "select with template parameter", - fields: fields{ - Db: setupTestDB(t), - Statement: "SELECT * FROM {{.tableName}}", - TemplateParameters: []parameters.Parameter{ - 
parameters.NewStringParameter("tableName", "table name"), - }, - }, - args: args{ - ctx: context.Background(), - params: []parameters.ParamValue{ - {Name: "tableName", Value: "users"}, - }, - }, - want: []any{ - map[string]any{"id": int64(1), "name": "Alice", "age": int64(30)}, - map[string]any{"id": int64(2), "name": "Bob", "age": int64(25)}, - }, - wantErr: false, - }, - { - name: "invalid sql", - fields: fields{ - Db: setupTestDB(t), - Statement: "SELECT * FROM non_existent_table", - }, - args: args{ - ctx: context.Background(), - }, - want: nil, - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - tr := sqlitesql.Tool{ - Config: sqlitesql.Config{ - Name: tt.fields.Name, - Kind: tt.fields.Kind, - AuthRequired: tt.fields.AuthRequired, - Statement: tt.fields.Statement, - Parameters: tt.fields.Parameters, - TemplateParameters: tt.fields.TemplateParameters, - }, - AllParams: tt.fields.AllParams, - Db: tt.fields.Db, - } - got, err := tr.Invoke(tt.args.ctx, tt.args.params, tt.args.accessToken) - if (err != nil) != tt.wantErr { - t.Errorf("Tool.Invoke() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("Tool.Invoke() = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/internal/tools/tidb/tidbexecutesql/tidbexecutesql.go b/internal/tools/tidb/tidbexecutesql/tidbexecutesql.go index 38eecca684..b452de841d 100644 --- a/internal/tools/tidb/tidbexecutesql/tidbexecutesql.go +++ b/internal/tools/tidb/tidbexecutesql/tidbexecutesql.go @@ -21,7 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/tidb" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" @@ -47,11 +46,6 @@ type compatibleSource interface { TiDBPool() *sql.DB } -// validate compatible sources are still compatible -var _ compatibleSource = &tidb.Source{} - -var compatibleSources = [...]string{tidb.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -68,18 +62,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - sqlParameter := parameters.NewStringParameter("sql", "The sql to execute.") params := parameters.Parameters{sqlParameter} @@ -89,7 +71,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, Parameters: params, - Pool: s.TiDBPool(), manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -101,14 +82,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - Parameters parameters.Parameters `yaml:"parameters"` - - Pool *sql.DB + Parameters parameters.Parameters `yaml:"parameters"` manifest tools.Manifest mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) 
Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() sql, ok := paramsMap["sql"].(string) if !ok { @@ -122,7 +106,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sql)) - results, err := t.Pool.QueryContext(ctx, sql) + results, err := source.TiDBPool().QueryContext(ctx, sql) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -194,14 +178,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/tidb/tidbsql/tidbsql.go b/internal/tools/tidb/tidbsql/tidbsql.go index 6148978e7e..f35d0a61db 100644 --- a/internal/tools/tidb/tidbsql/tidbsql.go +++ b/internal/tools/tidb/tidbsql/tidbsql.go @@ -22,7 +22,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/tidb" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -47,11 +46,6 @@ type compatibleSource interface { TiDBPool() *sql.DB } -// validate compatible sources are still compatible -var _ compatibleSource = &tidb.Source{} - -var compatibleSources = [...]string{tidb.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -71,18 +65,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters) if err != nil { return nil, err @@ -94,7 +76,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, AllParams: allParameters, - Pool: s.TiDBPool(), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -106,14 +87,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters `yaml:"allParams"` - - Pool *sql.DB + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr 
tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { @@ -126,7 +110,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } sliceParams := newParams.AsSlice() - results, err := t.Pool.QueryContext(ctx, newStatement, sliceParams...) + results, err := source.TiDBPool().QueryContext(ctx, newStatement, sliceParams...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -206,14 +190,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/tools.go b/internal/tools/tools.go index 6985dadd24..7283655f0c 100644 --- a/internal/tools/tools.go +++ b/internal/tools/tools.go @@ -85,14 +85,21 @@ func (token AccessToken) ParseBearerToken() (string, error) { } type Tool interface { - Invoke(context.Context, parameters.ParamValues, AccessToken) (any, error) + Invoke(context.Context, SourceProvider, parameters.ParamValues, AccessToken) (any, error) ParseParams(map[string]any, map[string]map[string]any) (parameters.ParamValues, error) Manifest() Manifest McpManifest() McpManifest Authorized([]string) bool - RequiresClientAuthorization() bool + RequiresClientAuthorization(SourceProvider) (bool, error) ToConfig() ToolConfig - GetAuthTokenHeaderName() string + GetAuthTokenHeaderName(SourceProvider) (string, error) +} + +// SourceProvider defines the minimal view of the server.ResourceManager +// that the Tool package needs. +// This is implemented to prevent import cycles. +type SourceProvider interface { + GetSource(sourceName string) (sources.Source, bool) } // Manifest is the representation of tools sent to Client SDKs. 
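
The hunks above replace eager source binding at config time with lazy lookup at invoke time: a tool keeps only its source name, and `Invoke` resolves the underlying pool through the new `SourceProvider` view using the generic helper added in the following hunk. A minimal, self-contained sketch of that pattern (the `fakeSource`, `poolSource`, and `provider` types are hypothetical stand-ins, not part of this change):

```go
package main

import "fmt"

// Stand-ins for sources.Source and the server's ResourceManager.
type Source interface{ SourceKind() string }
type provider map[string]Source

func (p provider) GetSource(name string) (Source, bool) { s, ok := p[name]; return s, ok }

// Per-tool compatibility contract, analogous to each tool's compatibleSource interface.
type poolSource interface{ Pool() string }

type fakeSource struct{}

func (fakeSource) SourceKind() string { return "fake" }
func (fakeSource) Pool() string       { return "connected-pool" }

// Mirrors tools.GetCompatibleSource: look the source up by name, then assert
// it satisfies the tool's interface; both failure modes become invocation errors.
func getCompatibleSource[T any](p provider, sourceName, toolName, toolKind string) (T, error) {
	var zero T
	s, ok := p.GetSource(sourceName)
	if !ok {
		return zero, fmt.Errorf("unable to retrieve source %q for tool %q", sourceName, toolName)
	}
	typed, ok := s.(T)
	if !ok {
		return zero, fmt.Errorf("invalid source for %q tool: source %q is not a compatible type", toolKind, sourceName)
	}
	return typed, nil
}

func main() {
	p := provider{"my-instance": fakeSource{}}
	// Resolution now happens per invocation instead of once at Initialize.
	src, err := getCompatibleSource[poolSource](p, "my-instance", "my-tool", "fake-sql")
	if err != nil {
		panic(err)
	}
	fmt.Println(src.Pool())
}
```
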
@@ -150,3 +157,16 @@ func IsAuthorized(authRequiredSources []string, verifiedAuthServices []string) b } return false } + +func GetCompatibleSource[T any](resourceMgr SourceProvider, sourceName, toolName, toolKind string) (T, error) { + var zero T + s, ok := resourceMgr.GetSource(sourceName) + if !ok { + return zero, fmt.Errorf("unable to retrieve source %q for tool %q", sourceName, toolName) + } + source, ok := s.(T) + if !ok { + return zero, fmt.Errorf("invalid source for %q tool: source %q is not a compatible type", toolKind, sourceName) + } + return source, nil +} diff --git a/internal/tools/toolsets.go b/internal/tools/toolsets.go index 895e5989b4..b429ef5b19 100644 --- a/internal/tools/toolsets.go +++ b/internal/tools/toolsets.go @@ -46,9 +46,9 @@ func (t ToolsetConfig) Initialize(serverVersion string, toolsMap map[string]Tool var toolset Toolset toolset.Name = t.Name if !IsValidName(toolset.Name) { - return toolset, fmt.Errorf("invalid toolset name: %s", t) + return toolset, fmt.Errorf("invalid toolset name: %s", toolset.Name) } - toolset.Tools = make([]*Tool, len(t.ToolNames)) + toolset.Tools = make([]*Tool, 0, len(t.ToolNames)) toolset.Manifest = ToolsetManifest{ ServerVersion: serverVersion, ToolsManifest: make(map[string]Manifest), @@ -56,7 +56,7 @@ func (t ToolsetConfig) Initialize(serverVersion string, toolsMap map[string]Tool for _, toolName := range t.ToolNames { tool, ok := toolsMap[toolName] if !ok { - return toolset, fmt.Errorf("tool does not exist: %s", t) + return toolset, fmt.Errorf("tool does not exist: %s", toolName) } toolset.Tools = append(toolset.Tools, &tool) toolset.Manifest.ToolsManifest[toolName] = tool.Manifest() diff --git a/internal/tools/trino/trinoexecutesql/trinoexecutesql.go b/internal/tools/trino/trinoexecutesql/trinoexecutesql.go index cb1ad24f31..f9f396bd03 100644 --- a/internal/tools/trino/trinoexecutesql/trinoexecutesql.go +++ b/internal/tools/trino/trinoexecutesql/trinoexecutesql.go @@ -21,7 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/trino" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -46,11 +45,6 @@ type compatibleSource interface { TrinoDB() *sql.DB } -// validate compatible sources are still compatible -var _ compatibleSource = &trino.Source{} - -var compatibleSources = [...]string{trino.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -67,18 +61,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - sqlParameter := parameters.NewStringParameter("sql", "The SQL query to execute against the Trino database.") params := parameters.Parameters{sqlParameter} @@ -88,7 +70,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, Parameters: params, - Db: s.TrinoDB(), manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ 
-100,21 +81,24 @@ var _ tools.Tool = Tool{} type Tool struct { Config - Parameters parameters.Parameters `yaml:"parameters"` - - Db *sql.DB + Parameters parameters.Parameters `yaml:"parameters"` manifest tools.Manifest mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + sliceParams := params.AsSlice() sql, ok := sliceParams[0].(string) if !ok { return nil, fmt.Errorf("unable to cast sql parameter: %v", sliceParams[0]) } - results, err := t.Db.QueryContext(ctx, sql) + results, err := source.TrinoDB().QueryContext(ctx, sql) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -179,14 +163,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/trino/trinosql/trinosql.go b/internal/tools/trino/trinosql/trinosql.go index 776f1509fe..7dd06d505c 100644 --- a/internal/tools/trino/trinosql/trinosql.go +++ b/internal/tools/trino/trinosql/trinosql.go @@ -21,7 +21,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/trino" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -46,11 +45,6 @@ type compatibleSource interface { TrinoDB() *sql.DB } -// validate compatible sources are still compatible -var _ compatibleSource = &trino.Source{} - -var compatibleSources = [...]string{trino.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -70,18 +64,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters) if err != nil { return nil, fmt.Errorf("unable to process parameters: %w", err) @@ -93,7 +75,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, AllParams: allParameters, - Db: s.TrinoDB(), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -105,14 +86,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters 
`yaml:"allParams"` - - Db *sql.DB + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { @@ -123,7 +107,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("unable to extract standard params %w", err) } sliceParams := newParams.AsSlice() - results, err := t.Db.QueryContext(ctx, newStatement, sliceParams...) + results, err := source.TrinoDB().QueryContext(ctx, newStatement, sliceParams...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -188,14 +172,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/utility/wait/wait.go b/internal/tools/utility/wait/wait.go index 43480dcf89..5b931ebcaf 100644 --- a/internal/tools/utility/wait/wait.go +++ b/internal/tools/utility/wait/wait.go @@ -80,7 +80,7 @@ type Tool struct { mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { paramsMap := params.AsMap() durationStr, ok := paramsMap["duration"].(string) @@ -114,14 +114,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return true } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/valkey/valkey.go b/internal/tools/valkey/valkey.go index 5d6d46a36f..8f9d90c264 100644 --- a/internal/tools/valkey/valkey.go +++ b/internal/tools/valkey/valkey.go @@ -19,7 +19,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - valkeysrc "github.com/googleapis/genai-toolbox/internal/sources/valkey" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/valkey-io/valkey-go" @@ -45,11 +44,6 @@ type compatibleSource interface { ValkeyClient() valkey.Client } -// validate compatible sources are still 
compatible -var _ compatibleSource = &valkeysrc.Source{} - -var compatibleSources = [...]string{valkeysrc.SourceKind, valkeysrc.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -68,24 +62,11 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, cfg.Parameters, nil) // finish tool setup t := Tool{ Config: cfg, - Client: s.ValkeyClient(), manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest(), AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -97,13 +78,16 @@ var _ tools.Tool = Tool{} type Tool struct { Config - - Client valkey.Client manifest tools.Manifest mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + // Replace parameters commands, err := replaceCommandsParams(t.Commands, t.Parameters, params) if err != nil { @@ -114,7 +98,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT builtCmds := make(valkey.Commands, len(commands)) for i, cmd := range commands { - builtCmds[i] = t.Client.B().Arbitrary(cmd...).Build() + builtCmds[i] = source.ValkeyClient().B().Arbitrary(cmd...).Build() } if len(builtCmds) == 0 { @@ -122,7 +106,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT } // Execute commands - responses := t.Client.DoMulti(ctx, builtCmds...) + responses := source.ValkeyClient().DoMulti(ctx, builtCmds...) 
// Parse responses out := make([]any, len(t.Commands)) @@ -193,14 +177,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/tools/yugabytedbsql/yugabytedbsql.go b/internal/tools/yugabytedbsql/yugabytedbsql.go index 9bfb90e4ff..3b774ac366 100644 --- a/internal/tools/yugabytedbsql/yugabytedbsql.go +++ b/internal/tools/yugabytedbsql/yugabytedbsql.go @@ -20,7 +20,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/yugabytedb" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/yugabyte/pgx/v5/pgxpool" @@ -46,8 +45,6 @@ type compatibleSource interface { YugabyteDBPool() *pgxpool.Pool } -var compatibleSources = [...]string{yugabytedb.SourceKind} - type Config struct { Name string `yaml:"name" validate:"required"` Kind string `yaml:"kind" validate:"required"` @@ -67,18 +64,6 @@ func (cfg Config) ToolConfigKind() string { } func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { - // verify source exists - rawS, ok := srcs[cfg.Source] - if !ok { - return nil, fmt.Errorf("no source named %q configured", cfg.Source) - } - - // verify the source is compatible - s, ok := rawS.(compatibleSource) - if !ok { - return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) - } - allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters) if err != nil { return nil, err @@ -90,7 +75,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, AllParams: allParameters, - Pool: s.YugabyteDBPool(), manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, mcpManifest: mcpManifest, } @@ -102,14 +86,17 @@ var _ tools.Tool = Tool{} type Tool struct { Config - AllParams parameters.Parameters `yaml:"allParams"` - - Pool *pgxpool.Pool + AllParams parameters.Parameters `yaml:"allParams"` manifest tools.Manifest mcpManifest tools.McpManifest } -func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) + if err != nil { + return nil, err + } + paramsMap := params.AsMap() newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) if err != nil { @@ -121,7 +108,7 @@ func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessT return nil, fmt.Errorf("unable to extract standard params %w", err) } sliceParams := newParams.AsSlice() - results, err := t.Pool.Query(ctx, newStatement, sliceParams...) 
+ results, err := source.YugabyteDBPool().Query(ctx, newStatement, sliceParams...) if err != nil { return nil, fmt.Errorf("unable to execute query: %w", err) } @@ -165,14 +152,14 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool { return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) } -func (t Tool) RequiresClientAuthorization() bool { - return false +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil } func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func (t Tool) GetAuthTokenHeaderName() string { - return "Authorization" +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil } diff --git a/internal/util/util.go b/internal/util/util.go index 9b0f269ce7..657fe8bf29 100644 --- a/internal/util/util.go +++ b/internal/util/util.go @@ -20,6 +20,7 @@ import ( "errors" "fmt" "io" + "net/http" "strings" "github.com/go-playground/validator/v10" @@ -119,6 +120,30 @@ func UserAgentFromContext(ctx context.Context) (string, error) { } } +type UserAgentRoundTripper struct { + userAgent string + next http.RoundTripper +} + +func NewUserAgentRoundTripper(ua string, next http.RoundTripper) *UserAgentRoundTripper { + return &UserAgentRoundTripper{ + userAgent: ua, + next: next, + } +} + +func (rt *UserAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + // create a deep copy of the request + newReq := req.Clone(req.Context()) + ua := newReq.Header.Get("User-Agent") + if ua == "" { + newReq.Header.Set("User-Agent", rt.userAgent) + } else { + newReq.Header.Set("User-Agent", ua+" "+rt.userAgent) + } + return rt.next.RoundTrip(newReq) +} + func NewStrictDecoder(v interface{}) (*yaml.Decoder, error) { b, err := yaml.Marshal(v) if err != nil { diff --git a/server.json b/server.json index 722bfb305e..fe2dfd9a82 100644 --- a/server.json +++ b/server.json @@ -14,11 +14,11 @@ "url": "https://github.com/googleapis/genai-toolbox", "source": "github" }, - "version": "0.21.0", + "version": "0.24.0", "packages": [ { "registryType": "oci", - "identifier": "us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:0.21.0", + "identifier": "us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:0.24.0", "transport": { "type": "streamable-http", "url": "http://{host}:{port}/mcp" diff --git a/tests/alloydb/alloydb_integration_test.go b/tests/alloydb/alloydb_integration_test.go index 0e9d64dfb8..5028f892e4 100644 --- a/tests/alloydb/alloydb_integration_test.go +++ b/tests/alloydb/alloydb_integration_test.go @@ -488,14 +488,12 @@ func runAlloyDBListUsersTest(t *testing.T, vars map[string]string) { name string requestBody io.Reader wantContains string - wantCount int wantStatusCode int }{ { name: "list users success", requestBody: bytes.NewBufferString(fmt.Sprintf(`{"project": "%s", "location": "%s", "cluster": "%s"}`, vars["project"], vars["location"], vars["cluster"])), wantContains: fmt.Sprintf("projects/%s/locations/%s/clusters/%s/users/%s", vars["project"], vars["location"], vars["cluster"], AlloyDBUser), - wantCount: 3, // NOTE: If users are added or removed in the test project, update the number of users here must be updated for this test to pass wantStatusCode: http.StatusOK, }, { @@ -567,10 +565,6 @@ func runAlloyDBListUsersTest(t *testing.T, vars map[string]string) { sort.Strings(got) - if len(got) != tc.wantCount { - t.Errorf("user count mismatch:\n got: %v\nwant: %v", len(got), tc.wantCount) - } - found := false for _, g := 
range got { if g == tc.wantContains { diff --git a/tests/alloydbpg/alloydb_pg_integration_test.go b/tests/alloydbpg/alloydb_pg_integration_test.go index 13e8288f4b..d7a903ac41 100644 --- a/tests/alloydbpg/alloydb_pg_integration_test.go +++ b/tests/alloydbpg/alloydb_pg_integration_test.go @@ -181,7 +181,7 @@ func TestAlloyDBPgToolEndpoints(t *testing.T) { // Run Postgres prebuilt tool tests tests.RunPostgresListTablesTest(t, tableNameParam, tableNameAuth, AlloyDBPostgresUser) - tests.RunPostgresListViewsTest(t, ctx, pool, tableNameParam) + tests.RunPostgresListViewsTest(t, ctx, pool) tests.RunPostgresListSchemasTest(t, ctx, pool) tests.RunPostgresListActiveQueriesTest(t, ctx, pool) tests.RunPostgresListAvailableExtensionsTest(t) @@ -195,6 +195,12 @@ func TestAlloyDBPgToolEndpoints(t *testing.T) { tests.RunPostgresLongRunningTransactionsTest(t, ctx, pool) tests.RunPostgresListQueryStatsTest(t, ctx, pool) tests.RunPostgresGetColumnCardinalityTest(t, ctx, pool) + tests.RunPostgresListTableStatsTest(t, ctx, pool) + tests.RunPostgresListPublicationTablesTest(t, ctx, pool) + tests.RunPostgresListTableSpacesTest(t) + tests.RunPostgresListPgSettingsTest(t, ctx, pool) + tests.RunPostgresListDatabaseStatsTest(t, ctx, pool) + tests.RunPostgresListRolesTest(t, ctx, pool) } // Test connection with different IP type diff --git a/tests/clickhouse/clickhouse_integration_test.go b/tests/clickhouse/clickhouse_integration_test.go index 0e81402c6a..058e4d1b1a 100644 --- a/tests/clickhouse/clickhouse_integration_test.go +++ b/tests/clickhouse/clickhouse_integration_test.go @@ -15,9 +15,12 @@ package clickhouse import ( + "bytes" "context" "database/sql" + "encoding/json" "fmt" + "net/http" "os" "regexp" "strings" @@ -26,16 +29,9 @@ import ( _ "github.com/ClickHouse/clickhouse-go/v2" "github.com/google/uuid" - "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/sources/clickhouse" "github.com/googleapis/genai-toolbox/internal/testutils" - clickhouseexecutesql "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouseexecutesql" - clickhouselistdatabases "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouselistdatabases" - clickhouselisttables "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouselisttables" - clickhousesql "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhousesql" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/googleapis/genai-toolbox/tests" - "go.opentelemetry.io/otel/trace/noop" ) var ( @@ -384,150 +380,125 @@ func TestClickHouseSQLTool(t *testing.T) { t.Fatalf("Failed to insert test data: %v", err) } - t.Run("SimpleSelect", func(t *testing.T) { - toolConfig := clickhousesql.Config{ - Name: "test-select", - Kind: "clickhouse-sql", - Source: "test-clickhouse", - Description: "Test select query", - Statement: fmt.Sprintf("SELECT * FROM %s ORDER BY id", tableName), - } - - source := createMockSource(t, pool) - sourcesMap := map[string]sources.Source{ - "test-clickhouse": source, - } - - tool, err := toolConfig.Initialize(sourcesMap) - if err != nil { - t.Fatalf("Failed to initialize tool: %v", err) - } - - result, err := tool.Invoke(ctx, parameters.ParamValues{}, "") - if err != nil { - t.Fatalf("Failed to invoke tool: %v", err) - } - - resultSlice, ok := result.([]any) - if !ok { - t.Fatalf("Expected result to be []any, got %T", result) - } - - if len(resultSlice) != 3 { - t.Errorf("Expected 3 results, got %d", len(resultSlice)) - } - }) - - 
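
Stepping back to the `UserAgentRoundTripper` added in `internal/util/util.go` above: it clones each outgoing request and either sets the toolbox User-Agent or appends it to one the caller already supplied. A hedged usage sketch (the version string and echo server are illustrative only):

```go
package main

import (
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"

	"github.com/googleapis/genai-toolbox/internal/util"
)

func main() {
	// Echo back whatever User-Agent the server receives.
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprint(w, r.UserAgent())
	}))
	defer srv.Close()

	// Wrap the default transport so every request carries the toolbox UA.
	client := &http.Client{
		Transport: util.NewUserAgentRoundTripper("genai-toolbox/0.24.0", http.DefaultTransport),
	}

	req, _ := http.NewRequest(http.MethodGet, srv.URL, nil)
	req.Header.Set("User-Agent", "my-client/1.0") // an existing UA gets the toolbox UA appended
	resp, err := client.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	body, _ := io.ReadAll(resp.Body)
	fmt.Println(string(body)) // "my-client/1.0 genai-toolbox/0.24.0"
}
```
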
t.Run("ParameterizedQuery", func(t *testing.T) { - toolConfig := clickhousesql.Config{ - Name: "test-param-query", - Kind: "clickhouse-sql", - Source: "test-clickhouse", - Description: "Test parameterized query", - Statement: fmt.Sprintf("SELECT * FROM %s WHERE age > ? ORDER BY id", tableName), - Parameters: parameters.Parameters{ - parameters.NewIntParameter("min_age", "Minimum age"), + toolsFile := map[string]any{ + "sources": map[string]any{ + "my-instance": getClickHouseVars(t), + }, + "tools": map[string]any{ + "test-select": map[string]any{ + "kind": ClickHouseToolKind, + "source": "my-instance", + "description": "Test select query", + "statement": fmt.Sprintf("SELECT * FROM %s ORDER BY id", tableName), }, - } - - source := createMockSource(t, pool) - sourcesMap := map[string]sources.Source{ - "test-clickhouse": source, - } - - tool, err := toolConfig.Initialize(sourcesMap) - if err != nil { - t.Fatalf("Failed to initialize tool: %v", err) - } - - params := parameters.ParamValues{ - {Name: "min_age", Value: 28}, - } - - result, err := tool.Invoke(ctx, params, "") - if err != nil { - t.Fatalf("Failed to invoke tool: %v", err) - } - - resultSlice, ok := result.([]any) - if !ok { - t.Fatalf("Expected result to be []any, got %T", result) - } - - if len(resultSlice) != 2 { - t.Errorf("Expected 2 results (Bob and Charlie), got %d", len(resultSlice)) - } - }) - - t.Run("EmptyResult", func(t *testing.T) { - toolConfig := clickhousesql.Config{ - Name: "test-empty-result", - Kind: "clickhouse-sql", - Source: "test-clickhouse", - Description: "Test query with no results", - Statement: fmt.Sprintf("SELECT * FROM %s WHERE id = ?", tableName), - Parameters: parameters.Parameters{ - parameters.NewIntParameter("id", "Record ID"), + "test-param-query": map[string]any{ + "kind": ClickHouseToolKind, + "source": "my-instance", + "description": "Test parameterized query", + "statement": fmt.Sprintf("SELECT * FROM %s WHERE age > ? ORDER BY id", tableName), + "parameters": []parameters.Parameter{ + parameters.NewIntParameter("min_age", "Minimum age"), + }, }, - } + "test-empty-result": map[string]any{ + "kind": ClickHouseToolKind, + "source": "my-instance", + "description": "Test query with no results", + "statement": fmt.Sprintf("SELECT * FROM %s WHERE id = ?", tableName), + "parameters": []parameters.Parameter{ + parameters.NewIntParameter("id", "Record ID"), + }, + }, + "test-invalid-sql": map[string]any{ + "kind": ClickHouseToolKind, + "source": "my-instance", + "description": "Test invalid SQL", + "statement": "SELEC * FROM nonexistent_table", // Typo in SELECT + }, + }, + } - source := createMockSource(t, pool) - sourcesMap := map[string]sources.Source{ - "test-clickhouse": source, - } + var args []string + cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...) 
+ if err != nil { + t.Fatalf("command initialization returned an error: %s", err) + } + defer cleanup() - tool, err := toolConfig.Initialize(sourcesMap) - if err != nil { - t.Fatalf("Failed to initialize tool: %v", err) - } + waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out) + if err != nil { + t.Logf("toolbox command logs: \n%s", out) + t.Fatalf("toolbox didn't start successfully: %s", err) + } - params := parameters.ParamValues{ - {Name: "id", Value: 999}, // Non-existent ID - } - - result, err := tool.Invoke(ctx, params, "") - if err != nil { - t.Fatalf("Failed to invoke tool: %v", err) - } - - // ClickHouse returns empty slice for no results, not nil - if resultSlice, ok := result.([]any); ok { - if len(resultSlice) != 0 { - t.Errorf("Expected empty result for non-existent record, got %d results", len(resultSlice)) + tcs := []struct { + name string + toolName string + requestBody []byte + resultSliceLen int + isErr bool + }{ + { + name: "SimpleSelect", + toolName: "test-select", + requestBody: []byte(`{}`), + resultSliceLen: 3, + }, + { + name: "ParameterizedQuery", + toolName: "test-param-query", + requestBody: []byte(`{"min_age": 28}`), + resultSliceLen: 2, + }, + { + name: "EmptyResult", + toolName: "test-empty-result", + requestBody: []byte(`{"id": 999}`), // non-existent id + resultSliceLen: 0, + }, + { + name: "InvalidSQL", + toolName: "test-invalid-sql", + requestBody: []byte(``), + isErr: true, + }, + } + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + api := fmt.Sprintf("http://127.0.0.1:5000/api/tool/%s/invoke", tc.toolName) + resp, respBody := tests.RunRequest(t, http.MethodPost, api, bytes.NewBuffer(tc.requestBody), nil) + if resp.StatusCode != http.StatusOK { + if tc.isErr { + return + } + t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody)) } - } else if result != nil { - t.Errorf("Expected empty slice or nil result for empty query, got %v", result) - } - }) - t.Run("InvalidSQL", func(t *testing.T) { - toolConfig := clickhousesql.Config{ - Name: "test-invalid-sql", - Kind: "clickhouse-sql", - Source: "test-clickhouse", - Description: "Test invalid SQL", - Statement: "SELEC * FROM nonexistent_table", // Typo in SELECT - } + var body map[string]interface{} + err := json.Unmarshal(respBody, &body) + if err != nil { + t.Fatalf("error parsing response body") + } - source := createMockSource(t, pool) - sourcesMap := map[string]sources.Source{ - "test-clickhouse": source, - } + got, ok := body["result"].(string) + if !ok { + t.Fatalf("unable to find result in response body") + } + t.Logf("result is %s", got) - tool, err := toolConfig.Initialize(sourcesMap) - if err != nil { - t.Fatalf("Failed to initialize tool: %v", err) - } + var res []any + err = json.Unmarshal([]byte(got), &res) + if err != nil { + t.Fatalf("error parsing result") + } - _, err = tool.Invoke(ctx, parameters.ParamValues{}, "") - if err == nil { - t.Error("Expected error for invalid SQL, got nil") - } - - if !strings.Contains(err.Error(), "Syntax error") && !strings.Contains(err.Error(), "SELEC") { - t.Errorf("Expected syntax error message, got: %v", err) - } - }) + if len(res) != tc.resultSliceLen { + t.Errorf("Expected %d results, got %d", tc.resultSliceLen, len(res)) + } + }) + } t.Logf("✅ clickhouse-sql tool tests completed successfully") } @@ -545,224 +516,108 @@ func TestClickHouseExecuteSQLTool(t *testing.T) { tableName := 
"test_exec_sql_" + strings.ReplaceAll(uuid.New().String(), "-", "") - t.Run("CreateTable", func(t *testing.T) { - toolConfig := clickhouseexecutesql.Config{ - Name: "test-create-table", - Kind: "clickhouse-execute-sql", - Source: "test-clickhouse", - Description: "Test create table", - } + toolsFile := map[string]any{ + "sources": map[string]any{ + "my-instance": getClickHouseVars(t), + }, + "tools": map[string]any{ + "execute-sql-tool": map[string]any{ + "kind": "clickhouse-execute-sql", + "source": "my-instance", + "description": "Test create table", + }, + }, + } - source := createMockSource(t, pool) - sourcesMap := map[string]sources.Source{ - "test-clickhouse": source, - } + var args []string + cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...) + if err != nil { + t.Fatalf("command initialization returned an error: %s", err) + } + defer cleanup() - tool, err := toolConfig.Initialize(sourcesMap) - if err != nil { - t.Fatalf("Failed to initialize tool: %v", err) - } + waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out) + if err != nil { + t.Logf("toolbox command logs: \n%s", out) + t.Fatalf("toolbox didn't start successfully: %s", err) + } + tcs := []struct { + name string + sql string + resultSliceLen int + isErr bool + }{ + { + name: "CreateTable", + sql: fmt.Sprintf(`CREATE TABLE %s (id UInt32, data String) ENGINE = Memory`, tableName), + resultSliceLen: 0, + }, + { + name: "InsertData", + sql: fmt.Sprintf("INSERT INTO %s (id, data) VALUES (1, 'test1'), (2, 'test2')", tableName), + resultSliceLen: 0, + }, + { + name: "SelectData", + sql: fmt.Sprintf("SELECT * FROM %s ORDER BY id", tableName), + resultSliceLen: 2, + }, + { + name: "DropTable", + sql: fmt.Sprintf("DROP TABLE IF EXISTS %s", tableName), + resultSliceLen: 0, + }, + { + name: "MissingSQL", + sql: "", + isErr: true, + }, - createSQL := fmt.Sprintf(` - CREATE TABLE %s ( - id UInt32, - data String - ) ENGINE = Memory - `, tableName) - - params := parameters.ParamValues{ - {Name: "sql", Value: createSQL}, - } - - result, err := tool.Invoke(ctx, params, "") - if err != nil { - t.Fatalf("Failed to create table: %v", err) - } - - // CREATE TABLE should return nil or empty slice (no rows) - if resultSlice, ok := result.([]any); ok { - if len(resultSlice) != 0 { - t.Errorf("Expected empty result for CREATE TABLE, got %d results", len(resultSlice)) + { + name: "SQLInjectionAttempt", + sql: "SELECT 1; DROP TABLE system.users; SELECT 2", + isErr: true, + }, + } + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + param := fmt.Sprintf(`{"sql": "%s"}`, tc.sql) + api := "http://127.0.0.1:5000/api/tool/execute-sql-tool/invoke" + resp, respBody := tests.RunRequest(t, http.MethodPost, api, bytes.NewBuffer([]byte(param)), nil) + if resp.StatusCode != http.StatusOK { + if tc.isErr { + return + } + t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody)) } - } else if result != nil { - t.Errorf("Expected nil or empty slice for CREATE TABLE, got %v", result) - } - }) - - t.Run("InsertData", func(t *testing.T) { - toolConfig := clickhouseexecutesql.Config{ - Name: "test-insert", - Kind: "clickhouse-execute-sql", - Source: "test-clickhouse", - Description: "Test insert data", - } - - source := createMockSource(t, pool) - sourcesMap := map[string]sources.Source{ - "test-clickhouse": source, - } - - tool, err := toolConfig.Initialize(sourcesMap) - if err != nil { - 
t.Fatalf("Failed to initialize tool: %v", err) - } - - insertSQL := fmt.Sprintf("INSERT INTO %s (id, data) VALUES (1, 'test1'), (2, 'test2')", tableName) - params := parameters.ParamValues{ - {Name: "sql", Value: insertSQL}, - } - - result, err := tool.Invoke(ctx, params, "") - if err != nil { - t.Fatalf("Failed to insert data: %v", err) - } - - // INSERT should return nil or empty slice - if resultSlice, ok := result.([]any); ok { - if len(resultSlice) != 0 { - t.Errorf("Expected empty result for INSERT, got %d results", len(resultSlice)) + if tc.isErr { + t.Fatalf("expecting an error from server") } - } else if result != nil { - t.Errorf("Expected nil or empty slice for INSERT, got %v", result) - } - }) - t.Run("SelectData", func(t *testing.T) { - toolConfig := clickhouseexecutesql.Config{ - Name: "test-select", - Kind: "clickhouse-execute-sql", - Source: "test-clickhouse", - Description: "Test select data", - } - - source := createMockSource(t, pool) - sourcesMap := map[string]sources.Source{ - "test-clickhouse": source, - } - - tool, err := toolConfig.Initialize(sourcesMap) - if err != nil { - t.Fatalf("Failed to initialize tool: %v", err) - } - - selectSQL := fmt.Sprintf("SELECT * FROM %s ORDER BY id", tableName) - params := parameters.ParamValues{ - {Name: "sql", Value: selectSQL}, - } - - result, err := tool.Invoke(ctx, params, "") - if err != nil { - t.Fatalf("Failed to select data: %v", err) - } - - resultSlice, ok := result.([]any) - if !ok { - t.Fatalf("Expected result to be []any, got %T", result) - } - - if len(resultSlice) != 2 { - t.Errorf("Expected 2 results, got %d", len(resultSlice)) - } - }) - - t.Run("DropTable", func(t *testing.T) { - toolConfig := clickhouseexecutesql.Config{ - Name: "test-drop-table", - Kind: "clickhouse-execute-sql", - Source: "test-clickhouse", - Description: "Test drop table", - } - - source := createMockSource(t, pool) - sourcesMap := map[string]sources.Source{ - "test-clickhouse": source, - } - - tool, err := toolConfig.Initialize(sourcesMap) - if err != nil { - t.Fatalf("Failed to initialize tool: %v", err) - } - - dropSQL := fmt.Sprintf("DROP TABLE IF EXISTS %s", tableName) - params := parameters.ParamValues{ - {Name: "sql", Value: dropSQL}, - } - - result, err := tool.Invoke(ctx, params, "") - if err != nil { - t.Fatalf("Failed to drop table: %v", err) - } - - // DROP TABLE should return nil or empty slice - if resultSlice, ok := result.([]any); ok { - if len(resultSlice) != 0 { - t.Errorf("Expected empty result for DROP TABLE, got %d results", len(resultSlice)) + var body map[string]interface{} + err := json.Unmarshal(respBody, &body) + if err != nil { + t.Fatalf("error parsing response body") } - } else if result != nil { - t.Errorf("Expected nil or empty slice for DROP TABLE, got %v", result) - } - }) - t.Run("MissingSQL", func(t *testing.T) { - toolConfig := clickhouseexecutesql.Config{ - Name: "test-missing-sql", - Kind: "clickhouse-execute-sql", - Source: "test-clickhouse", - Description: "Test missing SQL parameter", - } + got, ok := body["result"].(string) + if !ok { + t.Fatalf("unable to find result in response body") + } - source := createMockSource(t, pool) - sourcesMap := map[string]sources.Source{ - "test-clickhouse": source, - } + var res []any + err = json.Unmarshal([]byte(got), &res) + if err != nil { + t.Fatalf("error parsing result") + } - tool, err := toolConfig.Initialize(sourcesMap) - if err != nil { - t.Fatalf("Failed to initialize tool: %v", err) - } - - // Pass empty SQL parameter - this should cause an error - params 
:= parameters.ParamValues{ - {Name: "sql", Value: ""}, - } - - _, err = tool.Invoke(ctx, params, "") - if err == nil { - t.Error("Expected error for empty SQL parameter, got nil") - } else { - t.Logf("Got expected error for empty SQL parameter: %v", err) - } - }) - - t.Run("SQLInjectionAttempt", func(t *testing.T) { - toolConfig := clickhouseexecutesql.Config{ - Name: "test-sql-injection", - Kind: "clickhouse-execute-sql", - Source: "test-clickhouse", - Description: "Test SQL injection attempt", - } - - source := createMockSource(t, pool) - sourcesMap := map[string]sources.Source{ - "test-clickhouse": source, - } - - tool, err := toolConfig.Initialize(sourcesMap) - if err != nil { - t.Fatalf("Failed to initialize tool: %v", err) - } - - // Try to execute multiple statements (should fail or execute safely) - injectionSQL := "SELECT 1; DROP TABLE system.users; SELECT 2" - params := parameters.ParamValues{ - {Name: "sql", Value: injectionSQL}, - } - - _, err = tool.Invoke(ctx, params, "") - // This should either fail or only execute the first statement - // dont check the specific error as behavior may vary - _ = err // We're not checking the error intentionally - }) + if len(res) != tc.resultSliceLen { + t.Errorf("Expected %d results, got %d", tc.resultSliceLen, len(res)) + } + }) + } t.Logf("✅ clickhouse-execute-sql tool tests completed successfully") } @@ -778,6 +633,49 @@ func TestClickHouseEdgeCases(t *testing.T) { } defer pool.Close() + tableName := "test_nulls_" + strings.ReplaceAll(uuid.New().String(), "-", "") + toolsFile := map[string]any{ + "sources": map[string]any{ + "my-instance": getClickHouseVars(t), + }, + "tools": map[string]any{ + "execute-sql-tool": map[string]any{ + "kind": "clickhouse-execute-sql", + "source": "my-instance", + "description": "Test create table", + }, + "test-null-values": map[string]any{ + "kind": "clickhouse-sql", + "source": "my-instance", + "description": "Test null values", + "statement": fmt.Sprintf("SELECT * FROM %s ORDER BY id", tableName), + }, + "test-concurrent": map[string]any{ + "kind": "clickhouse-sql", + "source": "my-instance", + "description": "Test concurrent queries", + "statement": "SELECT number FROM system.numbers LIMIT ?", + "parameters": []parameters.Parameter{ + parameters.NewIntParameter("limit", "Limit"), + }, + }, + }, + } + + var args []string + cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...) 
+ if err != nil { + t.Fatalf("command initialization returned an error: %s", err) + } + defer cleanup() + + waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out) + if err != nil { + t.Logf("toolbox command logs: \n%s", out) + t.Fatalf("toolbox didn't start successfully: %s", err) + } t.Run("VeryLongQuery", func(t *testing.T) { // Create a very long but valid query var conditions []string @@ -786,42 +684,37 @@ func TestClickHouseEdgeCases(t *testing.T) { } longQuery := "SELECT 1 WHERE " + strings.Join(conditions, " AND ") - toolConfig := clickhouseexecutesql.Config{ - Name: "test-long-query", - Kind: "clickhouse-execute-sql", - Source: "test-clickhouse", - Description: "Test very long query", + api := "http://127.0.0.1:5000/api/tool/execute-sql-tool/invoke" + param := fmt.Sprintf(`{"sql": "%s"}`, longQuery) + resp, respBody := tests.RunRequest(t, http.MethodPost, api, bytes.NewBuffer([]byte(param)), nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody)) } - source := createMockSource(t, pool) - sourcesMap := map[string]sources.Source{ - "test-clickhouse": source, - } - - tool, err := toolConfig.Initialize(sourcesMap) + var body map[string]interface{} + err := json.Unmarshal(respBody, &body) if err != nil { - t.Fatalf("Failed to initialize tool: %v", err) + t.Fatalf("error parsing response body") } - params := parameters.ParamValues{ - {Name: "sql", Value: longQuery}, + got, ok := body["result"].(string) + if !ok { + t.Fatalf("unable to find result in response body") } - result, err := tool.Invoke(ctx, params, "") + var res []any + err = json.Unmarshal([]byte(got), &res) if err != nil { - t.Fatalf("Failed to execute long query: %v", err) + t.Fatalf("error parsing result") } // Should return [{1:1}] - if resultSlice, ok := result.([]any); ok { - if len(resultSlice) != 1 { - t.Errorf("Expected 1 result from long query, got %d", len(resultSlice)) - } + if len(res) != 1 { + t.Errorf("Expected 1 result from long query, got %d", len(res)) } }) t.Run("NullValues", func(t *testing.T) { - tableName := "test_nulls_" + strings.ReplaceAll(uuid.New().String(), "-", "") createSQL := fmt.Sprintf(` CREATE TABLE %s ( id UInt32, @@ -844,40 +737,35 @@ func TestClickHouseEdgeCases(t *testing.T) { t.Fatalf("Failed to insert null value: %v", err) } - toolConfig := clickhousesql.Config{ - Name: "test-null-values", - Kind: "clickhouse-sql", - Source: "test-clickhouse", - Description: "Test null values", - Statement: fmt.Sprintf("SELECT * FROM %s ORDER BY id", tableName), + api := "http://127.0.0.1:5000/api/tool/test-null-values/invoke" + resp, respBody := tests.RunRequest(t, http.MethodPost, api, bytes.NewBuffer([]byte(`{}`)), nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody)) } - source := createMockSource(t, pool) - sourcesMap := map[string]sources.Source{ - "test-clickhouse": source, - } - - tool, err := toolConfig.Initialize(sourcesMap) + var body map[string]interface{} + err := json.Unmarshal(respBody, &body) if err != nil { - t.Fatalf("Failed to initialize tool: %v", err) + t.Fatalf("error parsing response body") } - result, err := tool.Invoke(ctx, parameters.ParamValues{}, "") - if err != nil { - t.Fatalf("Failed to select null values: %v", err) - } - - resultSlice, ok := result.([]any) + got, ok := 
body["result"].(string) if !ok { - t.Fatalf("Expected result to be []any, got %T", result) + t.Fatalf("unable to find result in response body") } - if len(resultSlice) != 2 { - t.Errorf("Expected 2 results, got %d", len(resultSlice)) + var res []any + err = json.Unmarshal([]byte(got), &res) + if err != nil { + t.Fatalf("error parsing result") + } + + if len(res) != 2 { + t.Errorf("Expected 2 result from long query, got %d", len(res)) } // Check that null is properly handled - if firstRow, ok := resultSlice[0].(map[string]any); ok { + if firstRow, ok := res[0].(map[string]any); ok { if _, hasNullableField := firstRow["nullable_field"]; !hasNullableField { t.Error("Expected nullable_field in result") } @@ -885,47 +773,38 @@ func TestClickHouseEdgeCases(t *testing.T) { }) t.Run("ConcurrentQueries", func(t *testing.T) { - toolConfig := clickhousesql.Config{ - Name: "test-concurrent", - Kind: "clickhouse-sql", - Source: "test-clickhouse", - Description: "Test concurrent queries", - Statement: "SELECT number FROM system.numbers LIMIT ?", - Parameters: parameters.Parameters{ - parameters.NewIntParameter("limit", "Limit"), - }, - } - - source := createMockSource(t, pool) - sourcesMap := map[string]sources.Source{ - "test-clickhouse": source, - } - - tool, err := toolConfig.Initialize(sourcesMap) - if err != nil { - t.Fatalf("Failed to initialize tool: %v", err) - } - // Run multiple queries concurrently done := make(chan bool, 5) for i := 0; i < 5; i++ { go func(n int) { defer func() { done <- true }() - params := parameters.ParamValues{ - {Name: "limit", Value: n + 1}, + params := fmt.Sprintf(`{"limit": %d}`, n+1) + api := "http://127.0.0.1:5000/api/tool/test-concurrent/invoke" + resp, respBody := tests.RunRequest(t, http.MethodPost, api, bytes.NewBuffer([]byte(params)), nil) + if resp.StatusCode != http.StatusOK { + t.Errorf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody)) } - result, err := tool.Invoke(ctx, params, "") + var body map[string]interface{} + err := json.Unmarshal(respBody, &body) if err != nil { - t.Errorf("Concurrent query %d failed: %v", n, err) - return + t.Errorf("error parsing response body") } - if resultSlice, ok := result.([]any); ok { - if len(resultSlice) != n+1 { - t.Errorf("Query %d: expected %d results, got %d", n, n+1, len(resultSlice)) - } + got, ok := body["result"].(string) + if !ok { + t.Errorf("unable to find result in response body") + } + + var res []any + err = json.Unmarshal([]byte(got), &res) + if err != nil { + t.Errorf("error parsing result") + } + + if len(res) != n+1 { + t.Errorf("Query %d: expected %d results, got %d", n, n+1, len(res)) } }(i) } @@ -939,25 +818,6 @@ func TestClickHouseEdgeCases(t *testing.T) { t.Logf("✅ Edge case tests completed successfully") } -func createMockSource(t *testing.T, pool *sql.DB) sources.Source { - config := clickhouse.Config{ - Host: ClickHouseHost, - Port: ClickHousePort, - Database: ClickHouseDatabase, - User: ClickHouseUser, - Password: ClickHousePass, - Protocol: ClickHouseProtocol, - Secure: false, - } - - source, err := config.Initialize(context.Background(), noop.NewTracerProvider().Tracer("")) - if err != nil { - t.Fatalf("Failed to initialize source: %v", err) - } - - return source -} - // getClickHouseSQLParamToolInfo returns statements and param for my-tool clickhouse-sql kind func getClickHouseSQLParamToolInfo(tableName string) (string, string, string, string, string, string, []any) { createStatement := fmt.Sprintf("CREATE TABLE %s (id UInt32, name String) ENGINE = Memory", 
tableName) @@ -1036,44 +896,70 @@ func TestClickHouseListDatabasesTool(t *testing.T) { _, _ = pool.ExecContext(ctx, fmt.Sprintf("DROP DATABASE IF EXISTS %s", testDBName)) }() + toolsFile := map[string]any{ + "sources": map[string]any{ + "my-instance": getClickHouseVars(t), + }, + "tools": map[string]any{ + "test-list-databases": map[string]any{ + "kind": "clickhouse-list-databases", + "source": "my-instance", + "description": "Test listing databases", + }, + "test-invalid-source": map[string]any{ + "kind": "clickhouse-list-databases", + "source": "non-existent-source", + "description": "Test with invalid source", + }, + }, + } + + var args []string + cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...) + if err != nil { + t.Fatalf("command initialization returned an error: %s", err) + } + defer cleanup() + + waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out) + if err != nil { + t.Logf("toolbox command logs: \n%s", out) + t.Fatalf("toolbox didn't start successfully: %s", err) + } + t.Run("ListDatabases", func(t *testing.T) { - toolConfig := clickhouselistdatabases.Config{ - Name: "test-list-databases", - Kind: "clickhouse-list-databases", - Source: "test-clickhouse", - Description: "Test listing databases", + api := "http://127.0.0.1:5000/api/tool/test-list-databases/invoke" + resp, respBody := tests.RunRequest(t, http.MethodPost, api, bytes.NewBuffer([]byte(`{}`)), nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody)) } - source := createMockSource(t, pool) - sourcesMap := map[string]sources.Source{ - "test-clickhouse": source, - } - - tool, err := toolConfig.Initialize(sourcesMap) + var body map[string]interface{} + err := json.Unmarshal(respBody, &body) if err != nil { - t.Fatalf("Failed to initialize tool: %v", err) + t.Fatalf("error parsing response body") } - params := parameters.ParamValues{} - - result, err := tool.Invoke(ctx, params, "") - if err != nil { - t.Fatalf("Failed to list databases: %v", err) - } - - databases, ok := result.([]map[string]any) + databases, ok := body["result"].(string) if !ok { - t.Fatalf("Expected result to be []map[string]any, got %T", result) + t.Fatalf("unable to find result in response body") + } + var res []map[string]any + err = json.Unmarshal([]byte(databases), &res) + if err != nil { + t.Errorf("error parsing result") } // Should contain at least the default database and our test database - system and default - if len(databases) < 2 { - t.Errorf("Expected at least 2 databases, got %d", len(databases)) + if len(res) < 2 { + t.Errorf("Expected at least 2 databases, got %d", len(res)) } found := false foundDefault := false - for _, db := range databases { + for _, db := range res { if name, ok := db["name"].(string); ok { if name == testDBName { found = true @@ -1095,21 +981,12 @@ func TestClickHouseListDatabasesTool(t *testing.T) { }) t.Run("ListDatabasesWithInvalidSource", func(t *testing.T) { - toolConfig := clickhouselistdatabases.Config{ - Name: "test-invalid-source", - Kind: "clickhouse-list-databases", - Source: "non-existent-source", - Description: "Test with invalid source", + api := "http://127.0.0.1:5000/api/tool/test-invalid-source/invoke" + resp, _ := tests.RunRequest(t, http.MethodPost, api, bytes.NewBuffer([]byte(`{}`)), nil) + if resp.StatusCode == http.StatusOK { + t.Fatalf("expected error for non-existent source, but 
got 200 OK") } - sourcesMap := map[string]sources.Source{} - - _, err := toolConfig.Initialize(sourcesMap) - if err == nil { - t.Error("Expected error for non-existent source, got nil") - } else { - t.Logf("Got expected error for invalid source: %v", err) - } }) t.Logf("✅ clickhouse-list-databases tool tests completed successfully") @@ -1148,46 +1025,71 @@ func TestClickHouseListTablesTool(t *testing.T) { t.Fatalf("Failed to create test table 2: %v", err) } + toolsFile := map[string]any{ + "sources": map[string]any{ + "my-instance": getClickHouseVars(t), + }, + "tools": map[string]any{ + "test-list-tables": map[string]any{ + "kind": "clickhouse-list-tables", + "source": "my-instance", + "description": "Test listing tables", + }, + "test-invalid-source": map[string]any{ + "kind": "clickhouse-list-tables", + "source": "non-existent-source", + "description": "Test with invalid source", + }, + }, + } + + var args []string + cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...) + if err != nil { + t.Fatalf("command initialization returned an error: %s", err) + } + defer cleanup() + + waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out) + if err != nil { + t.Logf("toolbox command logs: \n%s", out) + t.Fatalf("toolbox didn't start successfully: %s", err) + } + t.Run("ListTables", func(t *testing.T) { - toolConfig := clickhouselisttables.Config{ - Name: "test-list-tables", - Kind: "clickhouse-list-tables", - Source: "test-clickhouse", - Description: "Test listing tables", + api := "http://127.0.0.1:5000/api/tool/test-list-tables/invoke" + params := fmt.Sprintf(`{"database": "%s"}`, testDBName) + resp, respBody := tests.RunRequest(t, http.MethodPost, api, bytes.NewBuffer([]byte(params)), nil) + if resp.StatusCode != http.StatusOK { + t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody)) } - source := createMockSource(t, pool) - sourcesMap := map[string]sources.Source{ - "test-clickhouse": source, - } - - tool, err := toolConfig.Initialize(sourcesMap) + var body map[string]interface{} + err := json.Unmarshal(respBody, &body) if err != nil { - t.Fatalf("Failed to initialize tool: %v", err) + t.Fatalf("error parsing response body") } - params := parameters.ParamValues{ - {Name: "database", Value: testDBName}, - } - - result, err := tool.Invoke(ctx, params, "") - if err != nil { - t.Fatalf("Failed to list tables: %v", err) - } - - tables, ok := result.([]map[string]any) + tables, ok := body["result"].(string) if !ok { - t.Fatalf("Expected result to be []map[string]any, got %T", result) + t.Fatalf("Expected result to be []map[string]any, got %T", tables) + } + var res []map[string]any + err = json.Unmarshal([]byte(tables), &res) + if err != nil { + t.Errorf("error parsing result") } // Should contain exactly 2 tables that we created - if len(tables) != 2 { - t.Errorf("Expected 2 tables, got %d", len(tables)) + if len(res) != 2 { + t.Errorf("Expected 2 tables, got %d", len(res)) } foundTable1 := false foundTable2 := false - for _, table := range tables { + for _, table := range res { if name, ok := table["name"].(string); ok { if name == testTable1 { foundTable1 = true @@ -1215,48 +1117,18 @@ func TestClickHouseListTablesTool(t *testing.T) { }) t.Run("ListTablesWithMissingDatabase", func(t *testing.T) { - toolConfig := clickhouselisttables.Config{ - Name: "test-list-tables-missing-db", - Kind: "clickhouse-list-tables", - Source: 
"test-clickhouse", - Description: "Test listing tables without database parameter", - } - - source := createMockSource(t, pool) - sourcesMap := map[string]sources.Source{ - "test-clickhouse": source, - } - - tool, err := toolConfig.Initialize(sourcesMap) - if err != nil { - t.Fatalf("Failed to initialize tool: %v", err) - } - - params := parameters.ParamValues{} - - _, err = tool.Invoke(ctx, params, "") - if err == nil { - t.Error("Expected error for missing database parameter, got nil") - } else { - t.Logf("Got expected error for missing database: %v", err) + api := "http://127.0.0.1:5000/api/tool/test-list-tables/invoke" + resp, _ := tests.RunRequest(t, http.MethodPost, api, bytes.NewBuffer([]byte(`{}`)), nil) + if resp.StatusCode == http.StatusOK { + t.Error("Expected error for missing database parameter, but got 200 OK") } }) t.Run("ListTablesWithInvalidSource", func(t *testing.T) { - toolConfig := clickhouselisttables.Config{ - Name: "test-invalid-source", - Kind: "clickhouse-list-tables", - Source: "non-existent-source", - Description: "Test with invalid source", - } - - sourcesMap := map[string]sources.Source{} - - _, err := toolConfig.Initialize(sourcesMap) - if err == nil { - t.Error("Expected error for non-existent source, got nil") - } else { - t.Logf("Got expected error for invalid source: %v", err) + api := "http://127.0.0.1:5000/api/tool/test-invalid-source/invoke" + resp, _ := tests.RunRequest(t, http.MethodPost, api, bytes.NewBuffer([]byte(`{}`)), nil) + if resp.StatusCode == http.StatusOK { + t.Error("Expected error for non-existent source, but got 200 OK") } }) diff --git a/tests/cloudgda/cloud_gda_integration_test.go b/tests/cloudgda/cloud_gda_integration_test.go new file mode 100644 index 0000000000..3a7c8ad07f --- /dev/null +++ b/tests/cloudgda/cloud_gda_integration_test.go @@ -0,0 +1,233 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package cloudgda_test + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "regexp" + "strings" + "testing" + "time" + + "github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools/cloudgda" + "github.com/googleapis/genai-toolbox/tests" +) + +var ( + cloudGdaToolKind = "cloud-gemini-data-analytics-query" +) + +type cloudGdaTransport struct { + transport http.RoundTripper + url *url.URL +} + +func (t *cloudGdaTransport) RoundTrip(req *http.Request) (*http.Response, error) { + if strings.HasPrefix(req.URL.String(), "https://geminidataanalytics.googleapis.com") { + req.URL.Scheme = t.url.Scheme + req.URL.Host = t.url.Host + } + return t.transport.RoundTrip(req) +} + +type masterHandler struct { + t *testing.T +} + +func (h *masterHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if !strings.Contains(r.UserAgent(), "genai-toolbox/") { + h.t.Errorf("User-Agent header not found") + } + + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Verify URL structure + // Expected: /v1beta/projects/{project}/locations/global:queryData + if !strings.Contains(r.URL.Path, ":queryData") || !strings.Contains(r.URL.Path, "locations/global") { + h.t.Errorf("unexpected URL path: %s", r.URL.Path) + http.Error(w, "Not found", http.StatusNotFound) + return + } + + var reqBody cloudgda.QueryDataRequest + if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { + h.t.Fatalf("failed to decode request body: %v", err) + } + + if reqBody.Prompt == "" { + http.Error(w, "missing prompt", http.StatusBadRequest) + return + } + + response := map[string]any{ + "queryResult": "SELECT * FROM table;", + "naturalLanguageAnswer": "Here is the answer.", + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(response); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +func TestCloudGdaToolEndpoints(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + handler := &masterHandler{t: t} + server := httptest.NewServer(handler) + defer server.Close() + + serverURL, err := url.Parse(server.URL) + if err != nil { + t.Fatalf("failed to parse server URL: %v", err) + } + + originalTransport := http.DefaultClient.Transport + if originalTransport == nil { + originalTransport = http.DefaultTransport + } + http.DefaultClient.Transport = &cloudGdaTransport{ + transport: originalTransport, + url: serverURL, + } + t.Cleanup(func() { + http.DefaultClient.Transport = originalTransport + }) + + var args []string + toolsFile := getCloudGdaToolsConfig() + cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...) + if err != nil { + t.Fatalf("command initialization returned an error: %s", err) + } + defer cleanup() + + waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out) + if err != nil { + t.Logf("toolbox command logs: \n%s", out) + t.Fatalf("toolbox didn't start successfully: %s", err) + } + + toolName := "cloud-gda-query" + + // 1. 
RunToolGetTestByName + expectedManifest := map[string]any{ + toolName: map[string]any{ + "description": "Test GDA Tool", + "parameters": []any{ + map[string]any{ + "name": "prompt", + "type": "string", + "description": "The natural language question to ask.", + "required": true, + "authSources": []any{}, + }, + }, + "authRequired": []any{}, + }, + } + tests.RunToolGetTestByName(t, toolName, expectedManifest) + + // 2. RunToolInvokeParametersTest + params := []byte(`{"prompt": "test question"}`) + tests.RunToolInvokeParametersTest(t, toolName, params, "\"queryResult\":\"SELECT * FROM table;\"") + + // 3. Manual MCP Tool Call Test + // Initialize MCP session + sessionId := tests.RunInitialize(t, "2024-11-05") + + // Construct MCP Request + mcpReq := jsonrpc.JSONRPCRequest{ + Jsonrpc: "2.0", + Id: "test-mcp-call", + Request: jsonrpc.Request{ + Method: "tools/call", + }, + Params: map[string]any{ + "name": toolName, + "arguments": map[string]any{ + "prompt": "test question", + }, + }, + } + reqBytes, _ := json.Marshal(mcpReq) + + headers := map[string]string{} + if sessionId != "" { + headers["Mcp-Session-Id"] = sessionId + } + + // Send Request + resp, respBody := tests.RunRequest(t, http.MethodPost, "http://127.0.0.1:5000/mcp", bytes.NewBuffer(reqBytes), headers) + + if resp.StatusCode != http.StatusOK { + t.Fatalf("MCP request failed with status %d: %s", resp.StatusCode, string(respBody)) + } + + // Check Response + respStr := string(respBody) + if !strings.Contains(respStr, "SELECT * FROM table;") { + t.Errorf("MCP response does not contain expected query result: %s", respStr) + } +} + +func getCloudGdaToolsConfig() map[string]any { + // Mocked responses and a dummy `projectId` are used in this integration + // test due to limited project-specific allowlisting. API functionality is + // verified via internal monitoring; this test specifically validates the + // integration flow between the source and the tool. 
+ return map[string]any{ + "sources": map[string]any{ + "my-gda-source": map[string]any{ + "kind": "cloud-gemini-data-analytics", + "projectId": "test-project", + }, + }, + "tools": map[string]any{ + "cloud-gda-query": map[string]any{ + "kind": cloudGdaToolKind, + "source": "my-gda-source", + "description": "Test GDA Tool", + "location": "us-central1", + "context": map[string]any{ + "datasourceReferences": map[string]any{ + "spannerReference": map[string]any{ + "databaseReference": map[string]any{ + "projectId": "test-project", + "instanceId": "test-instance", + "databaseId": "test-db", + "engine": "GOOGLE_SQL", + }, + }, + }, + }, + }, + }, + } +} diff --git a/tests/cloudmonitoring/cloud_monitoring_integration_test.go b/tests/cloudmonitoring/cloud_monitoring_integration_test.go index e451e9686c..f5833244a6 100644 --- a/tests/cloudmonitoring/cloud_monitoring_integration_test.go +++ b/tests/cloudmonitoring/cloud_monitoring_integration_test.go @@ -53,8 +53,6 @@ func TestTool_Invoke(t *testing.T) { Description: "Test Cloudmonitoring Tool", }, AllParams: parameters.Parameters{}, - BaseURL: server.URL, - Client: &http.Client{}, } // Define the test parameters @@ -64,7 +62,7 @@ func TestTool_Invoke(t *testing.T) { } // Invoke the tool - result, err := tool.Invoke(context.Background(), params, "") + result, err := tool.Invoke(context.Background(), nil, params, "") if err != nil { t.Fatalf("Invoke() error = %v", err) } @@ -99,8 +97,6 @@ func TestTool_Invoke_Error(t *testing.T) { Description: "Test Cloudmonitoring Tool", }, AllParams: parameters.Parameters{}, - BaseURL: server.URL, - Client: &http.Client{}, } // Define the test parameters @@ -110,7 +106,7 @@ func TestTool_Invoke_Error(t *testing.T) { } // Invoke the tool - _, err := tool.Invoke(context.Background(), params, "") + _, err := tool.Invoke(context.Background(), nil, params, "") if err == nil { t.Fatal("Invoke() error = nil, want error") } diff --git a/tests/cloudsql/cloud_sql_clone_instance_test.go b/tests/cloudsql/cloud_sql_clone_instance_test.go new file mode 100644 index 0000000000..f41062cb9a --- /dev/null +++ b/tests/cloudsql/cloud_sql_clone_instance_test.go @@ -0,0 +1,244 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package cloudsql + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "reflect" + "regexp" + "strings" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/tests" + sqladmin "google.golang.org/api/sqladmin/v1" + + _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcloneinstance" +) + +var ( + cloneInstanceToolKind = "cloud-sql-clone-instance" +) + +type cloneInstanceTransport struct { + transport http.RoundTripper + url *url.URL +} + +func (t *cloneInstanceTransport) RoundTrip(req *http.Request) (*http.Response, error) { + if strings.HasPrefix(req.URL.String(), "https://sqladmin.googleapis.com") { + req.URL.Scheme = t.url.Scheme + req.URL.Host = t.url.Host + } + return t.transport.RoundTrip(req) +} + +type masterCloneInstanceHandler struct { + t *testing.T +} + +func (h *masterCloneInstanceHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if !strings.Contains(r.UserAgent(), "genai-toolbox/") { + h.t.Errorf("User-Agent header not found") + } + var body sqladmin.InstancesCloneRequest + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + h.t.Fatalf("failed to decode request body: %v", err) + } else { + h.t.Logf("Received request body: %+v", body) + } + + var expectedBody sqladmin.InstancesCloneRequest + var response any + var statusCode int + + switch body.CloneContext.DestinationInstanceName { + case "cloned-instance": + expectedBody = sqladmin.InstancesCloneRequest{ + CloneContext: &sqladmin.CloneContext{ + DestinationInstanceName: "cloned-instance", + }, + } + response = map[string]any{"name": "op1", "status": "PENDING"} + statusCode = http.StatusOK + case "cloned-pitr-instance": + expectedBody = sqladmin.InstancesCloneRequest{ + CloneContext: &sqladmin.CloneContext{ + DestinationInstanceName: "cloned-pitr-instance", + PointInTime: "2025-11-04T10:00:00Z", + }, + } + response = map[string]any{"name": "op2", "status": "PENDING"} + statusCode = http.StatusOK + default: + http.Error(w, fmt.Sprintf("unhandled destination instance name: %s", body.CloneContext.DestinationInstanceName), http.StatusInternalServerError) + return + } + + if diff := cmp.Diff(expectedBody, body); diff != "" { + h.t.Errorf("unexpected request body (-want +got):\n%s", diff) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(statusCode) + if err := json.NewEncoder(w).Encode(response); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +func TestCloneInstanceToolEndpoints(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + handler := &masterCloneInstanceHandler{t: t} + server := httptest.NewServer(handler) + defer server.Close() + + serverURL, err := url.Parse(server.URL) + if err != nil { + t.Fatalf("failed to parse server URL: %v", err) + } + + originalTransport := http.DefaultClient.Transport + if originalTransport == nil { + originalTransport = http.DefaultTransport + } + http.DefaultClient.Transport = &cloneInstanceTransport{ + transport: originalTransport, + url: serverURL, + } + t.Cleanup(func() { + http.DefaultClient.Transport = originalTransport + }) + + var args []string + toolsFile := getCloneInstanceToolsConfig() + cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...) 
+ if err != nil { + t.Fatalf("command initialization returned an error: %s", err) + } + defer cleanup() + + waitCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out) + if err != nil { + t.Logf("toolbox command logs: \n%s", out) + t.Fatalf("toolbox didn't start successfully: %s", err) + } + + tcs := []struct { + name string + toolName string + body string + want string + expectError bool + errorStatus int + }{ + { + name: "successful clone instance", + toolName: "clone-instance", + body: `{"project": "p1", "sourceInstanceName": "source-instance", "destinationInstanceName": "cloned-instance"}`, + want: `{"name":"op1","status":"PENDING"}`, + }, + { + name: "successful pitr clone instance", + toolName: "clone-instance", + body: `{"project": "p1", "sourceInstanceName": "source-instance", "destinationInstanceName": "cloned-pitr-instance", "pointInTime": "2025-11-04T10:00:00Z"}`, + want: `{"name":"op2","status":"PENDING"}`, + }, + { + name: "missing destination instance name", + toolName: "clone-instance", + body: `{"project": "p1", "sourceInstanceName": "source-instance"}`, + expectError: true, + errorStatus: http.StatusBadRequest, + }, + } + + for _, tc := range tcs { + tc := tc + t.Run(tc.name, func(t *testing.T) { + api := fmt.Sprintf("http://127.0.0.1:5000/api/tool/%s/invoke", tc.toolName) + req, err := http.NewRequest(http.MethodPost, api, bytes.NewBufferString(tc.body)) + if err != nil { + t.Fatalf("unable to create request: %s", err) + } + req.Header.Add("Content-type", "application/json") + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("unable to send request: %s", err) + } + defer resp.Body.Close() + + if tc.expectError { + if resp.StatusCode != tc.errorStatus { + bodyBytes, _ := io.ReadAll(resp.Body) + t.Fatalf("expected status %d but got %d: %s", tc.errorStatus, resp.StatusCode, string(bodyBytes)) + } + return + } + + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes)) + } + + var result struct { + Result string `json:"result"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + + var got, want map[string]any + if err := json.Unmarshal([]byte(result.Result), &got); err != nil { + t.Fatalf("failed to unmarshal result: %v", err) + } + if err := json.Unmarshal([]byte(tc.want), &want); err != nil { + t.Fatalf("failed to unmarshal want: %v", err) + } + + if !reflect.DeepEqual(got, want) { + t.Fatalf("unexpected result: got %+v, want %+v", got, want) + } + }) + } +} + +func getCloneInstanceToolsConfig() map[string]any { + return map[string]any{ + "sources": map[string]any{ + "my-cloud-sql-source": map[string]any{ + "kind": "cloud-sql-admin", + }, + }, + "tools": map[string]any{ + "clone-instance": map[string]any{ + "kind": cloneInstanceToolKind, + "source": "my-cloud-sql-source", + }, + }, + } +} diff --git a/tests/cloudsqlmysql/cloud_sql_mysql_integration_test.go b/tests/cloudsqlmysql/cloud_sql_mysql_integration_test.go index 551cb6ffe8..55b3035868 100644 --- a/tests/cloudsqlmysql/cloud_sql_mysql_integration_test.go +++ b/tests/cloudsqlmysql/cloud_sql_mysql_integration_test.go @@ -160,8 +160,10 @@ func TestCloudSQLMySQLToolEndpoints(t *testing.T) { tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam) // Run specific MySQL tool tests - 
tests.RunMySQLListTablesTest(t, CloudSQLMySQLDatabase, tableNameParam, tableNameAuth) + const expectedOwner = "'toolbox-identity'@'%'" + tests.RunMySQLListTablesTest(t, CloudSQLMySQLDatabase, tableNameParam, tableNameAuth, expectedOwner) tests.RunMySQLListActiveQueriesTest(t, ctx, pool) + tests.RunMySQLGetQueryPlanTest(t, ctx, pool, CloudSQLMySQLDatabase, tableNameParam) } // Test connection with different IP type @@ -191,3 +193,104 @@ func TestCloudSQLMySQLIpConnection(t *testing.T) { }) } } + +func TestCloudSQLMySQLIAMConnection(t *testing.T) { + getCloudSQLMySQLVars(t) + // service account email used for IAM should trim the suffix + serviceAccountEmail, _, _ := strings.Cut(tests.ServiceAccountEmail, "@") + + noPassSourceConfig := map[string]any{ + "kind": CloudSQLMySQLSourceKind, + "project": CloudSQLMySQLProject, + "instance": CloudSQLMySQLInstance, + "region": CloudSQLMySQLRegion, + "database": CloudSQLMySQLDatabase, + "user": serviceAccountEmail, + } + noUserSourceConfig := map[string]any{ + "kind": CloudSQLMySQLSourceKind, + "project": CloudSQLMySQLProject, + "instance": CloudSQLMySQLInstance, + "region": CloudSQLMySQLRegion, + "database": CloudSQLMySQLDatabase, + "password": "random", + } + noUserNoPassSourceConfig := map[string]any{ + "kind": CloudSQLMySQLSourceKind, + "project": CloudSQLMySQLProject, + "instance": CloudSQLMySQLInstance, + "region": CloudSQLMySQLRegion, + "database": CloudSQLMySQLDatabase, + } + tcs := []struct { + name string + sourceConfig map[string]any + isErr bool + }{ + { + name: "no user no pass", + sourceConfig: noUserNoPassSourceConfig, + isErr: false, + }, + { + name: "no password", + sourceConfig: noPassSourceConfig, + isErr: false, + }, + { + name: "no user", + sourceConfig: noUserSourceConfig, + isErr: true, + }, + } + for i, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + // Generate a UNIQUE source name for this test case. + // It ensures the app registers a unique driver name + // like "cloudsql-mysql-iam-test-0", preventing conflicts. 
+ uniqueSourceName := fmt.Sprintf("iam-test-%d", i) + + // Construct the tools config manually (Copied from RunSourceConnectionTest) + toolsFile := map[string]any{ + "sources": map[string]any{ + uniqueSourceName: tc.sourceConfig, + }, + "tools": map[string]any{ + "my-simple-tool": map[string]any{ + "kind": CloudSQLMySQLToolKind, + "source": uniqueSourceName, + "description": "Simple tool to test end to end functionality.", + "statement": "SELECT 1;", + }, + }, + } + + // Start the Toolbox Command + cmd, cleanup, err := tests.StartCmd(ctx, toolsFile) + if err != nil { + t.Fatalf("command initialization returned an error: %s", err) + } + defer cleanup() + + // Wait for the server to be ready + waitCtx, waitCancel := context.WithTimeout(ctx, 10*time.Second) + defer waitCancel() + + out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out) + if err != nil { + if tc.isErr { + return + } + t.Logf("toolbox command logs: \n%s", out) + t.Fatalf("Connection test failure: toolbox didn't start successfully: %s", err) + } + + if tc.isErr { + t.Fatalf("Expected error but test passed.") + } + }) + } +} diff --git a/tests/cloudsqlpg/cloud_sql_pg_integration_test.go b/tests/cloudsqlpg/cloud_sql_pg_integration_test.go index d2960c98be..4879f19035 100644 --- a/tests/cloudsqlpg/cloud_sql_pg_integration_test.go +++ b/tests/cloudsqlpg/cloud_sql_pg_integration_test.go @@ -165,7 +165,7 @@ func TestCloudSQLPgSimpleToolEndpoints(t *testing.T) { // Run Postgres prebuilt tool tests tests.RunPostgresListTablesTest(t, tableNameParam, tableNameAuth, CloudSQLPostgresUser) - tests.RunPostgresListViewsTest(t, ctx, pool, tableNameParam) + tests.RunPostgresListViewsTest(t, ctx, pool) tests.RunPostgresListSchemasTest(t, ctx, pool) tests.RunPostgresListActiveQueriesTest(t, ctx, pool) tests.RunPostgresListAvailableExtensionsTest(t) @@ -179,6 +179,12 @@ func TestCloudSQLPgSimpleToolEndpoints(t *testing.T) { tests.RunPostgresLongRunningTransactionsTest(t, ctx, pool) tests.RunPostgresListQueryStatsTest(t, ctx, pool) tests.RunPostgresGetColumnCardinalityTest(t, ctx, pool) + tests.RunPostgresListTableStatsTest(t, ctx, pool) + tests.RunPostgresListPublicationTablesTest(t, ctx, pool) + tests.RunPostgresListTableSpacesTest(t) + tests.RunPostgresListPgSettingsTest(t, ctx, pool) + tests.RunPostgresListDatabaseStatsTest(t, ctx, pool) + tests.RunPostgresListRolesTest(t, ctx, pool) } // Test connection with different IP type diff --git a/tests/common.go b/tests/common.go index aaa3dfbf96..5ada5a6b32 100644 --- a/tests/common.go +++ b/tests/common.go @@ -207,6 +207,12 @@ func AddPostgresPrebuiltConfig(t *testing.T, config map[string]any) map[string]a PostgresReplicationStatsToolKind = "postgres-replication-stats" PostgresListQueryStatsToolKind = "postgres-list-query-stats" PostgresGetColumnCardinalityToolKind = "postgres-get-column-cardinality" + PostgresListTableStats = "postgres-list-table-stats" + PostgresListPublicationTablesToolKind = "postgres-list-publication-tables" + PostgresListTablespacesToolKind = "postgres-list-tablespaces" + PostgresListPGSettingsToolKind = "postgres-list-pg-settings" + PostgresListDatabaseStatsToolKind = "postgres-list-database-stats" + PostgresListRolesToolKind = "postgres-list-roles" ) tools, ok := config["tools"].(map[string]any) @@ -223,34 +229,28 @@ func AddPostgresPrebuiltConfig(t *testing.T, config map[string]any) map[string]a "source": "my-instance", "description": "Lists active queries in the database.", } - tools["list_installed_extensions"] = map[string]any{ 
"kind": PostgresListInstalledExtensionsToolKind, "source": "my-instance", "description": "Lists installed extensions in the database.", } - tools["list_available_extensions"] = map[string]any{ "kind": PostgresListAvailableExtensionsToolKind, "source": "my-instance", "description": "Lists available extensions in the database.", } - tools["list_views"] = map[string]any{ "kind": PostgresListViewsToolKind, "source": "my-instance", } - tools["list_schemas"] = map[string]any{ "kind": PostgresListSchemasToolKind, "source": "my-instance", } - tools["database_overview"] = map[string]any{ "kind": PostgresDatabaseOverviewToolKind, "source": "my-instance", } - tools["list_triggers"] = map[string]any{ "kind": PostgresListTriggersToolKind, "source": "my-instance", @@ -259,12 +259,14 @@ func AddPostgresPrebuiltConfig(t *testing.T, config map[string]any) map[string]a "kind": PostgresListIndexesToolKind, "source": "my-instance", } - tools["list_sequences"] = map[string]any{ "kind": PostgresListSequencesToolKind, "source": "my-instance", } - + tools["list_publication_tables"] = map[string]any{ + "kind": PostgresListPublicationTablesToolKind, + "source": "my-instance", + } tools["long_running_transactions"] = map[string]any{ "kind": PostgresLongRunningTransactionsToolKind, "source": "my-instance", @@ -281,12 +283,33 @@ func AddPostgresPrebuiltConfig(t *testing.T, config map[string]any) map[string]a "kind": PostgresListQueryStatsToolKind, "source": "my-instance", } - tools["get_column_cardinality"] = map[string]any{ "kind": PostgresGetColumnCardinalityToolKind, "source": "my-instance", } + tools["list_table_stats"] = map[string]any{ + "kind": PostgresListTableStats, + "source": "my-instance", + } + + tools["list_tablespaces"] = map[string]any{ + "kind": PostgresListTablespacesToolKind, + "source": "my-instance", + } + tools["list_pg_settings"] = map[string]any{ + "kind": PostgresListPGSettingsToolKind, + "source": "my-instance", + } + tools["list_database_stats"] = map[string]any{ + "kind": PostgresListDatabaseStatsToolKind, + "source": "my-instance", + } + + tools["list_roles"] = map[string]any{ + "kind": PostgresListRolesToolKind, + "source": "my-instance", + } config["tools"] = tools return config } @@ -425,6 +448,11 @@ func AddMySQLPrebuiltToolConfig(t *testing.T, config map[string]any) map[string] "source": "my-instance", "description": "Lists table fragmentation in the database.", } + tools["get_query_plan"] = map[string]any{ + "kind": "mysql-get-query-plan", + "source": "my-instance", + "description": "Gets the query plan for a SQL statement.", + } config["tools"] = tools return config } @@ -870,7 +898,7 @@ func TestCloudSQLMySQL_IPTypeParsingFromYAML(t *testing.T) { project: my-project region: my-region instance: my-instance - ipType: private + ipType: private database: my_db user: my_user password: my_pass @@ -910,7 +938,7 @@ func TestCloudSQLMySQL_IPTypeParsingFromYAML(t *testing.T) { // Finds and drops all tables in a postgres database. func CleanupPostgresTables(t *testing.T, ctx context.Context, pool *pgxpool.Pool) { query := ` - SELECT table_name FROM information_schema.tables + SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_type = 'BASE TABLE';` rows, err := pool.Query(ctx, query) @@ -943,7 +971,7 @@ func CleanupPostgresTables(t *testing.T, ctx context.Context, pool *pgxpool.Pool // Finds and drops all tables in a mysql database. 
func CleanupMySQLTables(t *testing.T, ctx context.Context, pool *sql.DB) { query := ` - SELECT table_name FROM information_schema.tables + SELECT table_name FROM information_schema.tables WHERE table_schema = DATABASE() AND table_type = 'BASE TABLE';` rows, err := pool.QueryContext(ctx, query) diff --git a/tests/looker/looker_integration_test.go b/tests/looker/looker_integration_test.go index b7272c5915..06ee5c0277 100644 --- a/tests/looker/looker_integration_test.go +++ b/tests/looker/looker_integration_test.go @@ -30,6 +30,9 @@ import ( "github.com/googleapis/genai-toolbox/internal/testutils" "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/tests" + + "github.com/looker-open-source/sdk-codegen/go/rtl" + v4 "github.com/looker-open-source/sdk-codegen/go/sdk/v4" ) var ( @@ -139,11 +142,31 @@ func TestLooker(t *testing.T) { "source": "my-instance", "description": "Simple tool to test end to end functionality.", }, + "make_look": map[string]any{ + "kind": "looker-make-look", + "source": "my-instance", + "description": "Simple tool to test end to end functionality.", + }, "get_dashboards": map[string]any{ "kind": "looker-get-dashboards", "source": "my-instance", "description": "Simple tool to test end to end functionality.", }, + "make_dashboard": map[string]any{ + "kind": "looker-make-dashboard", + "source": "my-instance", + "description": "Simple tool to test end to end functionality.", + }, + "add_dashboard_filter": map[string]any{ + "kind": "looker-add-dashboard-filter", + "source": "my-instance", + "description": "Simple tool to test end to end functionality.", + }, + "add_dashboard_element": map[string]any{ + "kind": "looker-add-dashboard-element", + "source": "my-instance", + "description": "Simple tool to test end to end functionality.", + }, "conversational_analytics": map[string]any{ "kind": "looker-conversational-analytics", "source": "my-instance", @@ -678,6 +701,116 @@ func TestLooker(t *testing.T) { }, }, ) + tests.RunToolGetTestByName(t, "make_look", + map[string]any{ + "make_look": map[string]any{ + "description": "Simple tool to test end to end functionality.", + "authRequired": []any{}, + "parameters": []any{ + map[string]any{ + "authSources": []any{}, + "description": "The model containing the explore.", + "name": "model", + "required": true, + "type": "string", + }, + map[string]any{ + "authSources": []any{}, + "description": "The explore to be queried.", + "name": "explore", + "required": true, + "type": "string", + }, + map[string]any{ + "authSources": []any{}, + "description": "The fields to be retrieved.", + "items": map[string]any{ + "authSources": []any{}, + "description": "A field to be returned in the query", + "name": "field", + "required": true, + "type": "string", + }, + "name": "fields", + "required": true, + "type": "array", + }, + map[string]any{ + "additionalProperties": true, + "authSources": []any{}, + "description": "The filters for the query", + "name": "filters", + "required": false, + "type": "object", + }, + map[string]any{ + "authSources": []any{}, + "description": "The query pivots (must be included in fields as well).", + "items": map[string]any{ + "authSources": []any{}, + "description": "A field to be used as a pivot in the query", + "name": "pivot_field", + "required": false, + "type": "string", + }, + "name": "pivots", + "required": false, + "type": "array", + }, + map[string]any{ + "authSources": []any{}, + "description": "The sorts like \"field.id desc 0\".", + "items": map[string]any{ + "authSources": 
[]any{}, + "description": "A field to be used as a sort in the query", + "name": "sort_field", + "required": false, + "type": "string", + }, + "name": "sorts", + "required": false, + "type": "array", + }, + map[string]any{ + "authSources": []any{}, + "description": "The row limit.", + "name": "limit", + "required": false, + "type": "integer", + }, + map[string]any{ + "authSources": []any{}, + "description": "The query timezone.", + "name": "tz", + "required": false, + "type": "string", + }, + map[string]any{ + "authSources": []any{}, + "description": "The title of the Look", + "name": "title", + "required": true, + "type": "string", + }, + map[string]any{ + "authSources": []any{}, + "description": "The description of the Look", + "name": "description", + "required": false, + "type": "string", + }, + map[string]any{ + "additionalProperties": true, + "authSources": []any{}, + "description": "The visualization config for the query", + "name": "vis_config", + "required": false, + "type": "object", + }, + }, + }, + }, + ) tests.RunToolGetTestByName(t, "get_dashboards", map[string]any{ "get_dashboards": map[string]any{ @@ -716,6 +849,235 @@ func TestLooker(t *testing.T) { }, }, ) + tests.RunToolGetTestByName(t, "make_dashboard", + map[string]any{ + "make_dashboard": map[string]any{ + "description": "Simple tool to test end to end functionality.", + "authRequired": []any{}, + "parameters": []any{ + map[string]any{ + "authSources": []any{}, + "description": "The title of the Dashboard", + "name": "title", + "required": true, + "type": "string", + }, + map[string]any{ + "authSources": []any{}, + "description": "The description of the Dashboard", + "name": "description", + "required": false, + "type": "string", + }, + }, + }, + }, + ) + tests.RunToolGetTestByName(t, "add_dashboard_filter", + map[string]any{ + "add_dashboard_filter": map[string]any{ + "description": "Simple tool to test end to end functionality.", + "authRequired": []any{}, + "parameters": []any{ + map[string]any{ + "authSources": []any{}, + "description": "The id of the dashboard where this filter will exist", + "name": "dashboard_id", + "required": true, + "type": "string", + }, + map[string]any{ + "authSources": []any{}, + "description": "The name of the Dashboard Filter", + "name": "name", + "required": true, + "type": "string", + }, + map[string]any{ + "authSources": []any{}, + "description": "The title of the Dashboard Filter", + "name": "title", + "required": true, + "type": "string", + }, + map[string]any{ + "authSources": []any{}, + "description": "The filter_type of the Dashboard Filter: date_filter, number_filter, string_filter, field_filter (default field_filter)", + "name": "filter_type", + "required": false, + "type": "string", + }, + map[string]any{ + "authSources": []any{}, + "description": "The default_value of the Dashboard Filter (optional)", + "name": "default_value", + "required": false, + "type": "string", + }, + map[string]any{ + "authSources": []any{}, + "description": "The model of a field type Dashboard Filter (required if type field)", + "name": "model", + "required": false, + "type": "string", + }, + map[string]any{ + "authSources": []any{}, + "description": "The explore of a field type Dashboard Filter (required if type field)", + "name": "explore", + "required": false, + "type": "string", + }, + map[string]any{ + "authSources": []any{}, + "description": "The dimension of a field type Dashboard Filter (required if type field)", + "name": "dimension", + "required": false, + "type": "string", + }, + 
map[string]any{ + "authSources": []any{}, + "description": "The Dashboard Filter should allow multiple values (default true)", + "name": "allow_multiple_values", + "required": false, + "type": "boolean", + }, + map[string]any{ + "authSources": []any{}, + "description": "The Dashboard Filter is required to run dashboard (default false)", + "name": "required", + "required": false, + "type": "boolean", + }, + }, + }, + }, + ) + tests.RunToolGetTestByName(t, "add_dashboard_element", + map[string]any{ + "add_dashboard_element": map[string]any{ + "description": "Simple tool to test end to end functionality.", + "authRequired": []any{}, + "parameters": []any{ + map[string]any{ + "authSources": []any{}, + "description": "The model containing the explore.", + "name": "model", + "required": true, + "type": "string", + }, + map[string]any{ + "authSources": []any{}, + "description": "The explore to be queried.", + "name": "explore", + "required": true, + "type": "string", + }, + map[string]any{ + "authSources": []any{}, + "description": "The fields to be retrieved.", + "items": map[string]any{ + "authSources": []any{}, + "description": "A field to be returned in the query", + "name": "field", + "required": true, + "type": "string", + }, + "name": "fields", + "required": true, + "type": "array", + }, + map[string]any{ + "additionalProperties": true, + "authSources": []any{}, + "description": "The filters for the query", + "name": "filters", + "required": false, + "type": "object", + }, + map[string]any{ + "authSources": []any{}, + "description": "The query pivots (must be included in fields as well).", + "items": map[string]any{ + "authSources": []any{}, + "description": "A field to be used as a pivot in the query", + "name": "pivot_field", + "required": false, + "type": "string", + }, + "name": "pivots", + "required": false, + "type": "array", + }, + map[string]any{ + "authSources": []any{}, + "description": "The sorts like \"field.id desc 0\".", + "items": map[string]any{ + "authSources": []any{}, + "description": "A field to be used as a sort in the query", + "name": "sort_field", + "required": false, + "type": "string", + }, + "name": "sorts", + "required": false, + "type": "array", + }, + map[string]any{ + "authSources": []any{}, + "description": "The row limit.", + "name": "limit", + "required": false, + "type": "integer", + }, + map[string]any{ + "authSources": []any{}, + "description": "The query timezone.", + "name": "tz", + "required": false, + "type": "string", + }, + map[string]any{ + "authSources": []any{}, + "description": "The id of the dashboard where this tile will exist", + "name": "dashboard_id", + "required": true, + "type": "string", + }, + map[string]any{ + "authSources": []any{}, + "description": "The title of the Dashboard Element", + "name": "title", + "required": false, + "type": "string", + }, + map[string]any{ + "additionalProperties": true, + "authSources": []any{}, + "description": "The visualization config for the query", + "name": "vis_config", + "required": false, + "type": "object", + }, + map[string]any{ + "authSources": []any{}, + "description": `An array of dashboard filters like [{"dashboard_filter_name": "name", "field": "view_name.field_name"}, ...]`, + "items": map[string]any{ + "additionalProperties": true, + "authSources": []any{}, + "description": `A dashboard filter like {"dashboard_filter_name": "name", "field": "view_name.field_name"}`, + "name": "dashboard_filter", + "required": false, + "type": "object", + }, + "name": "dashboard_filters", + "required": 
false, + "type": "array", + }, + }, + }, + }, + ) tests.RunToolGetTestByName(t, "conversational_analytics", map[string]any{ "conversational_analytics": map[string]any{ @@ -1200,8 +1562,6 @@ func TestLooker(t *testing.T) { wantResult = "null" tests.RunToolInvokeParametersTest(t, "get_dashboards", []byte(`{"title": "FOO", "desc": "BAR"}`), wantResult) - runConversationalAnalytics(t, "system__activity", "content_usage") - wantResult = "\"Connection\":\"thelook\"" tests.RunToolInvokeParametersTest(t, "health_pulse", []byte(`{"action": "check_db_connections"}`), wantResult) @@ -1261,6 +1621,16 @@ func TestLooker(t *testing.T) { wantResult = "/login/embed?t=" // testing for specific substring, since url is dynamic tests.RunToolInvokeParametersTest(t, "generate_embed_url", []byte(`{"type": "dashboards", "id": "1"}`), wantResult) + + runConversationalAnalytics(t, "system__activity", "content_usage") + + deleteLook := testMakeLook(t) + defer deleteLook() + + dashboardId, deleteDashboard := testMakeDashboard(t) + defer deleteDashboard() + testAddDashboardFilter(t, dashboardId) + testAddDashboardElement(t, dashboardId) } func runConversationalAnalytics(t *testing.T, modelName, exploreName string) { @@ -1325,3 +1695,122 @@ func runConversationalAnalytics(t *testing.T, modelName, exploreName string) { }) } } + +func newLookerTestSDK(t *testing.T) *v4.LookerSDK { + t.Helper() + cfg := rtl.ApiSettings{ + BaseUrl: LookerBaseUrl, + ApiVersion: "4.0", + VerifySsl: LookerVerifySsl == "true", + Timeout: 120, + ClientId: LookerClientId, + ClientSecret: LookerClientSecret, + } + return v4.NewLookerSDK(rtl.NewAuthSession(cfg)) +} + +func testMakeLook(t *testing.T) func() { + var id string + t.Run("TestMakeLook", func(t *testing.T) { + reqBody := []byte(`{"model": "system__activity", "explore": "look", "fields": ["look.count"], "title": "TestLook"}`) + + url := "http://127.0.0.1:5000/api/tool/make_look/invoke" + resp, bodyBytes := tests.RunRequest(t, http.MethodPost, url, bytes.NewBuffer(reqBody), nil) + + if resp.StatusCode != 200 { + t.Fatalf("unexpected status code: got %d, want %d. Body: %s", resp.StatusCode, 200, string(bodyBytes)) + } + + var respBody map[string]interface{} + if err := json.Unmarshal(bodyBytes, &respBody); err != nil { + t.Fatalf("error parsing response body: %v", err) + } + + result := respBody["result"].(string) + if err := json.Unmarshal([]byte(result), &respBody); err != nil { + t.Fatalf("error parsing result body: %v", err) + } + + var ok bool + if id, ok = respBody["id"].(string); !ok || id == "" { + t.Fatalf("didn't get TestLook id, got %s", string(bodyBytes)) + } + }) + + return func() { + sdk := newLookerTestSDK(t) + + if _, err := sdk.DeleteLook(id, nil); err != nil { + t.Fatalf("error deleting look: %v", err) + } + t.Logf("deleted Look %s", id) + } +} + +func testAddDashboardFilter(t *testing.T, dashboardId string) { + t.Run("TestAddDashboardFilter", func(t *testing.T) { + reqBody := []byte(fmt.Sprintf(`{"dashboard_id": "%s", "model": "system__activity", "explore": "look", "dimension": "look.created_year", "name": "test_filter", "title": "TestDashboardFilter"}`, dashboardId)) + + url := "http://127.0.0.1:5000/api/tool/add_dashboard_filter/invoke" + resp, bodyBytes := tests.RunRequest(t, http.MethodPost, url, bytes.NewBuffer(reqBody), nil) + + if resp.StatusCode != 200 { + t.Fatalf("unexpected status code: got %d, want %d. 
Body: %s", resp.StatusCode, 200, string(bodyBytes)) + } + + t.Logf("got %s", string(bodyBytes)) + }) +} + +func testAddDashboardElement(t *testing.T, dashboardId string) { + t.Run("TestAddDashboardElement", func(t *testing.T) { + reqBody := []byte(fmt.Sprintf(`{"dashboard_id": "%s", "model": "system__activity", "explore": "look", "fields": ["look.count"], "title": "TestDashboardElement"}`, dashboardId)) + + url := "http://127.0.0.1:5000/api/tool/add_dashboard_element/invoke" + resp, bodyBytes := tests.RunRequest(t, http.MethodPost, url, bytes.NewBuffer(reqBody), nil) + + if resp.StatusCode != 200 { + t.Fatalf("unexpected status code: got %d, want %d. Body: %s", resp.StatusCode, 200, string(bodyBytes)) + } + + t.Logf("got %s", string(bodyBytes)) + }) +} + +func testMakeDashboard(t *testing.T) (string, func()) { + var id string + t.Run("TestMakeDashboard", func(t *testing.T) { + reqBody := []byte(`{"title": "TestDashboard"}`) + + url := "http://127.0.0.1:5000/api/tool/make_dashboard/invoke" + resp, bodyBytes := tests.RunRequest(t, http.MethodPost, url, bytes.NewBuffer(reqBody), nil) + + if resp.StatusCode != 200 { + t.Fatalf("unexpected status code: got %d, want %d. Body: %s", resp.StatusCode, 200, string(bodyBytes)) + } + + var respBody map[string]interface{} + if err := json.Unmarshal(bodyBytes, &respBody); err != nil { + t.Fatalf("error parsing response body: %v", err) + } + + result := respBody["result"].(string) + if err := json.Unmarshal([]byte(result), &respBody); err != nil { + t.Fatalf("error parsing result body: %v", err) + } + + var ok bool + if id, ok = respBody["id"].(string); !ok || id == "" { + t.Fatalf("didn't get TestDashboard id, got %s", string(bodyBytes)) + } + }) + + return id, func() { + sdk := newLookerTestSDK(t) + + if _, err := sdk.DeleteDashboard(id, nil); err != nil { + t.Fatalf("error deleting dashboard: %v", err) + } + t.Logf("deleted Dashboard %s", id) + } +} diff --git a/tests/mariadb/mariadb_integration_test.go b/tests/mariadb/mariadb_integration_test.go new file mode 100644 index 0000000000..60d734ace7 --- /dev/null +++ b/tests/mariadb/mariadb_integration_test.go @@ -0,0 +1,343 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package mariadb + +import ( + "bytes" + "context" + "database/sql" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "regexp" + "strings" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/google/uuid" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/tests" +) + +var ( + MariaDBSourceKind = "mysql" + MariaDBToolKind = "mysql-sql" + MariaDBDatabase = os.Getenv("MARIADB_DATABASE") + MariaDBHost = os.Getenv("MARIADB_HOST") + MariaDBPort = os.Getenv("MARIADB_PORT") + MariaDBUser = os.Getenv("MARIADB_USER") + MariaDBPass = os.Getenv("MARIADB_PASS") +) + +func getMariaDBVars(t *testing.T) map[string]any { + switch "" { + case MariaDBDatabase: + t.Fatal("'MARIADB_DATABASE' not set") + case MariaDBHost: + t.Fatal("'MARIADB_HOST' not set") + case MariaDBPort: + t.Fatal("'MARIADB_PORT' not set") + case MariaDBUser: + t.Fatal("'MARIADB_USER' not set") + case MariaDBPass: + t.Fatal("'MARIADB_PASS' not set") + } + + return map[string]any{ + "kind": MariaDBSourceKind, + "host": MariaDBHost, + "port": MariaDBPort, + "database": MariaDBDatabase, + "user": MariaDBUser, + "password": MariaDBPass, + } +} + +// Copied over from mysql.go +func initMariaDB(host, port, user, pass, dbname string) (*sql.DB, error) { + dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?parseTime=true", user, pass, host, port, dbname) + + // Interact with the driver directly as you normally would + pool, err := sql.Open("mysql", dsn) + if err != nil { + return nil, fmt.Errorf("sql.Open: %w", err) + } + return pool, nil +} + +func TestMySQLToolEndpoints(t *testing.T) { + sourceConfig := getMariaDBVars(t) + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + var args []string + + pool, err := initMariaDB(MariaDBHost, MariaDBPort, MariaDBUser, MariaDBPass, MariaDBDatabase) + if err != nil { + t.Fatalf("unable to create MySQL connection pool: %s", err) + } + + // cleanup test environment + tests.CleanupMySQLTables(t, ctx, pool) + + // create table name with UUID + tableNameParam := "param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "") + tableNameAuth := "auth_table_" + strings.ReplaceAll(uuid.New().String(), "-", "") + tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "") + + // set up data for param tool + createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := tests.GetMySQLParamToolInfo(tableNameParam) + teardownTable1 := tests.SetupMySQLTable(t, ctx, pool, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams) + defer teardownTable1(t) + + // set up data for auth tool + createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := tests.GetMySQLAuthToolInfo(tableNameAuth) + teardownTable2 := tests.SetupMySQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams) + defer teardownTable2(t) + + // Write config into a file and pass it to command + toolsFile := tests.GetToolsConfig(sourceConfig, MariaDBToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt) + toolsFile = tests.AddMySqlExecuteSqlConfig(t, toolsFile) + tmplSelectCombined, tmplSelectFilterCombined := tests.GetMySQLTmplToolStatement() + toolsFile = tests.AddTemplateParamConfig(t, toolsFile, MariaDBToolKind, tmplSelectCombined, tmplSelectFilterCombined, "") + + toolsFile = 
tests.AddMySQLPrebuiltToolConfig(t, toolsFile) + + cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...) + if err != nil { + t.Fatalf("command initialization returned an error: %s", err) + } + defer cleanup() + + waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out) + if err != nil { + t.Logf("toolbox command logs: \n%s", out) + t.Fatalf("toolbox didn't start successfully: %s", err) + } + + // Get configs for tests + select1Want, mcpMyFailToolWant, createTableStatement, mcpSelect1Want := GetMariaDBWants() + + // Run tests + tests.RunToolGetTest(t) + tests.RunToolInvokeTest(t, select1Want, tests.DisableArrayTest()) + tests.RunMCPToolCallMethod(t, mcpMyFailToolWant, mcpSelect1Want) + tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want) + tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam) + + // Run specific MySQL tool tests + RunMariDBListTablesTest(t, MariaDBDatabase, tableNameParam, tableNameAuth) + tests.RunMySQLListActiveQueriesTest(t, ctx, pool) + tests.RunMySQLListTablesMissingUniqueIndexes(t, ctx, pool, MariaDBDatabase) + tests.RunMySQLListTableFragmentationTest(t, MariaDBDatabase, tableNameParam, tableNameAuth) +} + +// RunMariDBListTablesTest run tests against the mysql-list-tables tool +func RunMariDBListTablesTest(t *testing.T, databaseName, tableNameParam, tableNameAuth string) { + type tableInfo struct { + ObjectName string `json:"object_name"` + SchemaName string `json:"schema_name"` + ObjectDetails string `json:"object_details"` + } + + type column struct { + DataType string `json:"data_type"` + ColumnName string `json:"column_name"` + ColumnComment string `json:"column_comment"` + ColumnDefault any `json:"column_default"` + IsNotNullable bool `json:"is_not_nullable"` + OrdinalPosition int `json:"ordinal_position"` + } + + type objectDetails struct { + Owner any `json:"owner"` + Columns []column `json:"columns"` + Comment string `json:"comment"` + Indexes []any `json:"indexes"` + Triggers []any `json:"triggers"` + Constraints []any `json:"constraints"` + ObjectName string `json:"object_name"` + ObjectType string `json:"object_type"` + SchemaName string `json:"schema_name"` + } + + paramTableWant := objectDetails{ + ObjectName: tableNameParam, + SchemaName: databaseName, + ObjectType: "TABLE", + Columns: []column{ + {DataType: "int(11)", ColumnName: "id", IsNotNullable: true, OrdinalPosition: 1, ColumnDefault: nil}, + {DataType: "varchar(255)", ColumnName: "name", OrdinalPosition: 2, ColumnDefault: "NULL"}, + }, + Indexes: []any{map[string]any{"index_columns": []any{"id"}, "index_name": "PRIMARY", "is_primary": true, "is_unique": true}}, + Triggers: []any{}, + Constraints: []any{map[string]any{"constraint_columns": []any{"id"}, "constraint_name": "PRIMARY", "constraint_type": "PRIMARY KEY", "foreign_key_referenced_columns": any(nil), "foreign_key_referenced_table": any(nil), "constraint_definition": ""}}, + } + + authTableWant := objectDetails{ + ObjectName: tableNameAuth, + SchemaName: databaseName, + ObjectType: "TABLE", + Columns: []column{ + {DataType: "int(11)", ColumnName: "id", IsNotNullable: true, OrdinalPosition: 1, ColumnDefault: nil}, + {DataType: "varchar(255)", ColumnName: "name", OrdinalPosition: 2, ColumnDefault: "NULL"}, + {DataType: "varchar(255)", ColumnName: "email", OrdinalPosition: 3, ColumnDefault: "NULL"}, + }, + Indexes: []any{map[string]any{"index_columns": []any{"id"}, "index_name": 
"PRIMARY", "is_primary": true, "is_unique": true}}, + Triggers: []any{}, + Constraints: []any{map[string]any{"constraint_columns": []any{"id"}, "constraint_name": "PRIMARY", "constraint_type": "PRIMARY KEY", "foreign_key_referenced_columns": any(nil), "foreign_key_referenced_table": any(nil), "constraint_definition": ""}}, + } + + invokeTcs := []struct { + name string + requestBody io.Reader + wantStatusCode int + want any + isSimple bool + isAllTables bool + }{ + { + name: "invoke list_tables for all tables detailed output", + requestBody: bytes.NewBufferString(`{"table_names":""}`), + wantStatusCode: http.StatusOK, + want: []objectDetails{authTableWant, paramTableWant}, + isAllTables: true, + }, + { + name: "invoke list_tables detailed output", + requestBody: bytes.NewBufferString(fmt.Sprintf(`{"table_names": "%s"}`, tableNameAuth)), + wantStatusCode: http.StatusOK, + want: []objectDetails{authTableWant}, + }, + { + name: "invoke list_tables simple output", + requestBody: bytes.NewBufferString(fmt.Sprintf(`{"table_names": "%s", "output_format": "simple"}`, tableNameAuth)), + wantStatusCode: http.StatusOK, + want: []map[string]any{{"name": tableNameAuth}}, + isSimple: true, + }, + { + name: "invoke list_tables with multiple table names", + requestBody: bytes.NewBufferString(fmt.Sprintf(`{"table_names": "%s,%s"}`, tableNameParam, tableNameAuth)), + wantStatusCode: http.StatusOK, + want: []objectDetails{authTableWant, paramTableWant}, + }, + { + name: "invoke list_tables with one existing and one non-existent table", + requestBody: bytes.NewBufferString(fmt.Sprintf(`{"table_names": "%s,non_existent_table"}`, tableNameAuth)), + wantStatusCode: http.StatusOK, + want: []objectDetails{authTableWant}, + }, + { + name: "invoke list_tables with non-existent table", + requestBody: bytes.NewBufferString(`{"table_names": "non_existent_table"}`), + wantStatusCode: http.StatusOK, + want: []objectDetails{}, + }, + } + for _, tc := range invokeTcs { + t.Run(tc.name, func(t *testing.T) { + const api = "http://127.0.0.1:5000/api/tool/list_tables/invoke" + resp, body := tests.RunRequest(t, http.MethodPost, api, tc.requestBody, nil) + if resp.StatusCode != tc.wantStatusCode { + t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(body)) + } + if tc.wantStatusCode != http.StatusOK { + return + } + + var bodyWrapper struct { + Result json.RawMessage `json:"result"` + } + if err := json.Unmarshal(body, &bodyWrapper); err != nil { + t.Fatalf("error decoding response wrapper: %v", err) + } + + var resultString string + if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil { + resultString = string(bodyWrapper.Result) + } + + var got any + if tc.isSimple { + var tables []tableInfo + if err := json.Unmarshal([]byte(resultString), &tables); err != nil { + t.Fatalf("failed to unmarshal outer JSON array into []tableInfo: %v", err) + } + details := []map[string]any{} + for _, table := range tables { + var d map[string]any + if err := json.Unmarshal([]byte(table.ObjectDetails), &d); err != nil { + t.Fatalf("failed to unmarshal nested ObjectDetails string: %v", err) + } + details = append(details, d) + } + got = details + } else { + var tables []tableInfo + if err := json.Unmarshal([]byte(resultString), &tables); err != nil { + t.Fatalf("failed to unmarshal outer JSON array into []tableInfo: %v", err) + } + details := []objectDetails{} + for _, table := range tables { + var d objectDetails + if err := json.Unmarshal([]byte(table.ObjectDetails), &d); err != 
nil { + t.Fatalf("failed to unmarshal nested ObjectDetails string: %v", err) + } + details = append(details, d) + } + got = details + } + + opts := []cmp.Option{ + cmpopts.SortSlices(func(a, b objectDetails) bool { return a.ObjectName < b.ObjectName }), + cmpopts.SortSlices(func(a, b column) bool { return a.ColumnName < b.ColumnName }), + cmpopts.SortSlices(func(a, b map[string]any) bool { return a["name"].(string) < b["name"].(string) }), + } + + // Checking only the current database where the test tables are created to avoid brittle tests. + if tc.isAllTables { + filteredGot := []objectDetails{} + if got != nil { + for _, item := range got.([]objectDetails) { + if item.SchemaName == databaseName { + filteredGot = append(filteredGot, item) + } + } + } + got = filteredGot + } + + if diff := cmp.Diff(tc.want, got, opts...); diff != "" { + t.Errorf("Unexpected result: got %#v, want: %#v", got, tc.want) + } + }) + } +} + +// GetMariaDBWants return the expected wants for mariaDB +func GetMariaDBWants() (string, string, string, string) { + select1Want := `[{"1":1}]` + mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: Error 1064 (42000): You have an error in your SQL syntax; check the manual that corresponds to your MariaDB server version for the right syntax to use near 'SELEC 1' at line 1"}],"isError":true}}` + createTableStatement := `"CREATE TABLE t (id INT AUTO_INCREMENT PRIMARY KEY, name TEXT)"` + mcpSelect1Want := `{"jsonrpc":"2.0","id":"invoke my-auth-required-tool","result":{"content":[{"type":"text","text":"{\"1\":1}"}]}}` + return select1Want, mcpMyFailToolWant, createTableStatement, mcpSelect1Want +} diff --git a/tests/mongodb/mongodb_integration_test.go b/tests/mongodb/mongodb_integration_test.go index be0a296065..3d84837de8 100644 --- a/tests/mongodb/mongodb_integration_test.go +++ b/tests/mongodb/mongodb_integration_test.go @@ -102,9 +102,10 @@ func TestMongoDBToolEndpoints(t *testing.T) { // Get configs for tests select1Want := `[{"_id":3,"id":3,"name":"Sid"}]` myToolId3NameAliceWant := `[{"_id":5,"id":3,"name":"Alice"}]` - myToolById4Want := `[{"_id":4,"id":4,"name":null}]` + myToolById4Want := `null` mcpMyFailToolWant := `invalid JSON input: missing colon after key ` - mcpMyToolId3NameAliceWant := `{"jsonrpc":"2.0","id":"my-simple-tool","result":{"content":[{"type":"text","text":"{\"_id\":5,\"id\":3,\"name\":\"Alice\"}"}]}}` + mcpMyToolId3NameAliceWant := `{"jsonrpc":"2.0","id":"my-tool","result":{"content":[{"type":"text","text":"{\"_id\":5,\"id\":3,\"name\":\"Alice\"}"}]}}` + mcpAuthRequiredWant := `{"jsonrpc":"2.0","id":"invoke my-auth-required-tool","result":{"content":[{"type":"text","text":"{\"_id\":3,\"id\":3,\"name\":\"Sid\"}"}]}}` // Run tests tests.RunToolGetTest(t) @@ -115,13 +116,14 @@ func TestMongoDBToolEndpoints(t *testing.T) { ) tests.RunMCPToolCallMethod(t, mcpMyFailToolWant, select1Want, tests.WithMcpMyToolId3NameAliceWant(mcpMyToolId3NameAliceWant), + tests.WithMcpSelect1Want(mcpAuthRequiredWant), ) delete1Want := "1" deleteManyWant := "2" runToolDeleteInvokeTest(t, delete1Want, deleteManyWant) - insert1Want := `["68666e1035bb36bf1b4d47fb"]` + insert1Want := `"68666e1035bb36bf1b4d47fb"` insertManyWant := `["68667a6436ec7d0363668db7","68667a6436ec7d0363668db8","68667a6436ec7d0363668db9"]` runToolInsertInvokeTest(t, insert1Want, insertManyWant) @@ -444,12 +446,15 @@ func runToolAggregateInvokeTest(t *testing.T, aggregate1Want string, aggregateMa func setupMongoDB(t *testing.T, ctx 
context.Context, database *mongo.Database) func(*testing.T) { collectionName := "test_collection" + if err := database.Collection(collectionName).Drop(ctx); err != nil { + t.Logf("Warning: failed to drop collection before setup: %v", err) + } + documents := []map[string]any{ {"_id": 1, "id": 1, "name": "Alice", "email": ServiceAccountEmail}, - {"_id": 1, "id": 2, "name": "FakeAlice", "email": "fakeAlice@gmail.com"}, + {"_id": 14, "id": 2, "name": "FakeAlice", "email": "fakeAlice@gmail.com"}, {"_id": 2, "id": 2, "name": "Jane"}, {"_id": 3, "id": 3, "name": "Sid"}, - {"_id": 4, "id": 4, "name": nil}, {"_id": 5, "id": 3, "name": "Alice", "email": "alice@gmail.com"}, {"_id": 6, "id": 100, "name": "ToBeDeleted", "email": "bob@gmail.com"}, {"_id": 7, "id": 101, "name": "ToBeDeleted", "email": "bob1@gmail.com"}, @@ -498,8 +503,6 @@ func getMongoDBToolsConfig(sourceConfig map[string]any, toolKind string) map[str "filterParams": []any{}, "projectPayload": `{ "_id": 1, "id": 1, "name" : 1 }`, "database": MongoDbDatabase, - "limit": 1, - "sort": `{ "id": 1 }`, }, "my-tool": map[string]any{ "kind": toolKind, @@ -522,6 +525,7 @@ func getMongoDBToolsConfig(sourceConfig map[string]any, toolKind string) map[str }, "projectPayload": `{ "_id": 1, "id": 1, "name" : 1 }`, "database": MongoDbDatabase, + "limit": 10, }, "my-tool-by-id": map[string]any{ "kind": toolKind, @@ -539,6 +543,7 @@ func getMongoDBToolsConfig(sourceConfig map[string]any, toolKind string) map[str }, "projectPayload": `{ "_id": 1, "id": 1, "name" : 1 }`, "database": MongoDbDatabase, + "limit": 10, }, "my-tool-by-name": map[string]any{ "kind": toolKind, @@ -546,7 +551,7 @@ func getMongoDBToolsConfig(sourceConfig map[string]any, toolKind string) map[str "description": "Tool to test invocation with params.", "authRequired": []string{}, "collection": "test_collection", - "filterPayload": `{ "name" : {{ .name }} }`, + "filterPayload": `{ "name" : {{json .name }} }`, "filterParams": []map[string]any{ { "name": "name", @@ -557,6 +562,7 @@ func getMongoDBToolsConfig(sourceConfig map[string]any, toolKind string) map[str }, "projectPayload": `{ "_id": 1, "id": 1, "name" : 1 }`, "database": MongoDbDatabase, + "limit": 10, }, "my-array-tool": map[string]any{ "kind": toolKind, @@ -564,7 +570,7 @@ func getMongoDBToolsConfig(sourceConfig map[string]any, toolKind string) map[str "description": "Tool to test invocation with array.", "authRequired": []string{}, "collection": "test_collection", - "filterPayload": `{ "name": { "$in": {{json .nameArray}} }, "_id": 5 })`, + "filterPayload": `{ "name": { "$in": {{json .nameArray}} }, "_id": 5 }`, "filterParams": []map[string]any{ { "name": "nameArray", @@ -578,6 +584,7 @@ func getMongoDBToolsConfig(sourceConfig map[string]any, toolKind string) map[str }, "projectPayload": `{ "_id": 1, "id": 1, "name" : 1 }`, "database": MongoDbDatabase, + "limit": 10, }, "my-auth-tool": map[string]any{ "kind": toolKind, @@ -601,6 +608,7 @@ func getMongoDBToolsConfig(sourceConfig map[string]any, toolKind string) map[str }, "projectPayload": `{ "_id": 0, "name" : 1 }`, "database": MongoDbDatabase, + "limit": 10, }, "my-auth-required-tool": map[string]any{ "kind": toolKind, @@ -613,6 +621,7 @@ func getMongoDBToolsConfig(sourceConfig map[string]any, toolKind string) map[str "filterPayload": `{ "_id": 3, "id": 3 }`, "filterParams": []any{}, "database": MongoDbDatabase, + "limit": 10, }, "my-fail-tool": map[string]any{ "kind": toolKind, @@ -623,6 +632,7 @@ func getMongoDBToolsConfig(sourceConfig map[string]any, toolKind string) 
map[str "filterPayload": `{ "id" ; 1 }"}`, "filterParams": []any{}, "database": MongoDbDatabase, + "limit": 10, }, "my-delete-one-tool": map[string]any{ "kind": "mongodb-delete-one", diff --git a/tests/mysql/mysql_integration_test.go b/tests/mysql/mysql_integration_test.go index c3e452e463..113767fd1d 100644 --- a/tests/mysql/mysql_integration_test.go +++ b/tests/mysql/mysql_integration_test.go @@ -138,8 +138,10 @@ func TestMySQLToolEndpoints(t *testing.T) { tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam) // Run specific MySQL tool tests - tests.RunMySQLListTablesTest(t, MySQLDatabase, tableNameParam, tableNameAuth) + const expectedOwner = "" + tests.RunMySQLListTablesTest(t, MySQLDatabase, tableNameParam, tableNameAuth, expectedOwner) tests.RunMySQLListActiveQueriesTest(t, ctx, pool) tests.RunMySQLListTablesMissingUniqueIndexes(t, ctx, pool, MySQLDatabase) tests.RunMySQLListTableFragmentationTest(t, MySQLDatabase, tableNameParam, tableNameAuth) + tests.RunMySQLGetQueryPlanTest(t, ctx, pool, MySQLDatabase, tableNameParam) } diff --git a/tests/option.go b/tests/option.go index e1e6735b5c..09f1d1c886 100644 --- a/tests/option.go +++ b/tests/option.go @@ -119,6 +119,7 @@ func EnableClientAuthTest() InvokeTestOption { // MCPTestConfig represents the various configuration options for mcp tool call tests. type MCPTestConfig struct { myToolId3NameAliceWant string + mcpSelect1Want string supportClientAuth bool supportSelect1Auth bool } @@ -149,6 +150,12 @@ func DisableMcpSelect1AuthTest() McpTestOption { } } +func WithMcpSelect1Want(want string) McpTestOption { + return func(c *MCPTestConfig) { + c.mcpSelect1Want = want + } +} + /* Configurations for RunExecuteSqlToolInvokeTest() */ // ExecuteSqlTestConfig represents the various configuration options for RunExecuteSqlToolInvokeTest() diff --git a/tests/oracle/oracle_integration_test.go b/tests/oracle/oracle_integration_test.go index 04f272a1b8..0021679e9e 100644 --- a/tests/oracle/oracle_integration_test.go +++ b/tests/oracle/oracle_integration_test.go @@ -43,6 +43,7 @@ func getOracleVars(t *testing.T) map[string]any { return map[string]any{ "kind": OracleSourceKind, "connectionString": OracleConnStr, + "useOCI": true, "user": OracleUser, "password": OraclePass, } @@ -50,9 +51,11 @@ func getOracleVars(t *testing.T) map[string]any { // Copied over from oracle.go func initOracleConnection(ctx context.Context, user, pass, connStr string) (*sql.DB, error) { - fullConnStr := fmt.Sprintf("oracle://%s:%s@%s", user, pass, connStr) + // Build the full Oracle connection string for godror driver + fullConnStr := fmt.Sprintf(`user="%s" password="%s" connectString="%s"`, + user, pass, connStr) - db, err := sql.Open("oracle", fullConnStr) + db, err := sql.Open("godror", fullConnStr) if err != nil { return nil, fmt.Errorf("unable to open Oracle connection: %w", err) } @@ -116,13 +119,15 @@ func TestOracleSimpleToolEndpoints(t *testing.T) { // Get configs for tests select1Want := "[{\"1\":1}]" - mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: ORA-00900: invalid SQL statement\n error occur at position: 0"}],"isError":true}}` + mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: dpiStmt_execute: ORA-00900: invalid SQL statement"}],"isError":true}}` createTableStatement := `"CREATE TABLE t (id NUMBER GENERATED AS IDENTITY PRIMARY KEY, name VARCHAR2(255))"` mcpSelect1Want := 
`{"jsonrpc":"2.0","id":"invoke my-auth-required-tool","result":{"content":[{"type":"text","text":"{\"1\":1}"}]}}` // Run tests tests.RunToolGetTest(t) tests.RunToolInvokeTest(t, select1Want, + tests.DisableOptionalNullParamTest(), + tests.WithMyToolById4Want("[{\"id\":4,\"name\":\"\"}]"), tests.DisableArrayTest(), ) tests.RunMCPToolCallMethod(t, mcpMyFailToolWant, mcpSelect1Want) diff --git a/tests/postgres/postgres_integration_test.go b/tests/postgres/postgres_integration_test.go index 42715ffb8e..273c6f3014 100644 --- a/tests/postgres/postgres_integration_test.go +++ b/tests/postgres/postgres_integration_test.go @@ -144,7 +144,7 @@ func TestPostgres(t *testing.T) { // Run Postgres prebuilt tool tests tests.RunPostgresListTablesTest(t, tableNameParam, tableNameAuth, PostgresUser) - tests.RunPostgresListViewsTest(t, ctx, pool, tableNameParam) + tests.RunPostgresListViewsTest(t, ctx, pool) tests.RunPostgresListSchemasTest(t, ctx, pool) tests.RunPostgresListActiveQueriesTest(t, ctx, pool) tests.RunPostgresListAvailableExtensionsTest(t) @@ -158,4 +158,10 @@ func TestPostgres(t *testing.T) { tests.RunPostgresReplicationStatsTest(t, ctx, pool) tests.RunPostgresListQueryStatsTest(t, ctx, pool) tests.RunPostgresGetColumnCardinalityTest(t, ctx, pool) + tests.RunPostgresListTableStatsTest(t, ctx, pool) + tests.RunPostgresListPublicationTablesTest(t, ctx, pool) + tests.RunPostgresListTableSpacesTest(t) + tests.RunPostgresListPgSettingsTest(t, ctx, pool) + tests.RunPostgresListDatabaseStatsTest(t, ctx, pool) + tests.RunPostgresListRolesTest(t, ctx, pool) } diff --git a/tests/serverlessspark/serverless_spark_integration_test.go b/tests/serverlessspark/serverless_spark_integration_test.go index f2fa106f8a..c2f245dc4f 100644 --- a/tests/serverlessspark/serverless_spark_integration_test.go +++ b/tests/serverlessspark/serverless_spark_integration_test.go @@ -20,6 +20,7 @@ import ( "encoding/json" "fmt" "io" + "maps" "net/http" "os" "reflect" @@ -47,6 +48,11 @@ var ( serverlessSparkServiceAccount = os.Getenv("SERVERLESS_SPARK_SERVICE_ACCOUNT") ) +const ( + batchURLPrefix = "https://console.cloud.google.com/dataproc/batches/" + logsURLPrefix = "https://console.cloud.google.com/logs/viewer?" 
+) + func getServerlessSparkVars(t *testing.T) map[string]any { switch "" { case serverlessSparkLocation: @@ -66,7 +72,7 @@ func getServerlessSparkVars(t *testing.T) map[string]any { func TestServerlessSparkToolEndpoints(t *testing.T) { sourceConfig := getServerlessSparkVars(t) - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) defer cancel() toolsFile := map[string]any{ @@ -107,6 +113,54 @@ func TestServerlessSparkToolEndpoints(t *testing.T) { "source": "my-spark", "authRequired": []string{"my-google-auth"}, }, + "create-pyspark-batch": map[string]any{ + "kind": "serverless-spark-create-pyspark-batch", + "source": "my-spark", + "environmentConfig": map[string]any{ + "executionConfig": map[string]any{ + "serviceAccount": serverlessSparkServiceAccount, + }, + }, + }, + "create-pyspark-batch-2-3": map[string]any{ + "kind": "serverless-spark-create-pyspark-batch", + "source": "my-spark", + "runtimeConfig": map[string]any{"version": "2.3"}, + "environmentConfig": map[string]any{ + "executionConfig": map[string]any{ + "serviceAccount": serverlessSparkServiceAccount, + }, + }, + }, + "create-pyspark-batch-with-auth": map[string]any{ + "kind": "serverless-spark-create-pyspark-batch", + "source": "my-spark", + "authRequired": []string{"my-google-auth"}, + }, + "create-spark-batch": map[string]any{ + "kind": "serverless-spark-create-spark-batch", + "source": "my-spark", + "environmentConfig": map[string]any{ + "executionConfig": map[string]any{ + "serviceAccount": serverlessSparkServiceAccount, + }, + }, + }, + "create-spark-batch-2-3": map[string]any{ + "kind": "serverless-spark-create-spark-batch", + "source": "my-spark", + "runtimeConfig": map[string]any{"version": "2.3"}, + "environmentConfig": map[string]any{ + "executionConfig": map[string]any{ + "serviceAccount": serverlessSparkServiceAccount, + }, + }, + }, + "create-spark-batch-with-auth": map[string]any{ + "kind": "serverless-spark-create-spark-batch", + "source": "my-spark", + "authRequired": []string{"my-google-auth"}, + }, }, } @@ -220,6 +274,216 @@ func TestServerlessSparkToolEndpoints(t *testing.T) { }) }) + t.Run("create-pyspark-batch", func(t *testing.T) { + t.Parallel() + + t.Run("success", func(t *testing.T) { + t.Parallel() + piPy := "file:///usr/lib/spark/examples/src/main/python/pi.py" + tcs := []struct { + name string + toolName string + request map[string]any + waitForSuccess bool + validate func(t *testing.T, b *dataprocpb.Batch) + }{ + { + name: "no params", + toolName: "create-pyspark-batch", + waitForSuccess: true, + request: map[string]any{"mainFile": piPy}, + }, + // Tests below are just verifying options are set correctly on created batches, + // they don't need to wait for success. 
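+ // Each validate callback below re-fetches the created batch via client.GetBatch and checks the relevant field on the returned proto.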
+ { + name: "with arg", + toolName: "create-pyspark-batch", + request: map[string]any{"mainFile": piPy, "args": []string{"100"}}, + validate: func(t *testing.T, b *dataprocpb.Batch) { + if !cmp.Equal(b.GetPysparkBatch().Args, []string{"100"}) { + t.Errorf("unexpected args: got %v, want %v", b.GetPysparkBatch().Args, []string{"100"}) + } + }, + }, + { + name: "version", + toolName: "create-pyspark-batch", + request: map[string]any{"mainFile": piPy, "version": "2.2"}, + validate: func(t *testing.T, b *dataprocpb.Batch) { + v := b.GetRuntimeConfig().GetVersion() + if v != "2.2" { + t.Errorf("unexpected version: got %v, want 2.2", v) + } + }, + }, + { + name: "version param overrides tool", + toolName: "create-pyspark-batch-2-3", + request: map[string]any{"mainFile": piPy, "version": "2.2"}, + validate: func(t *testing.T, b *dataprocpb.Batch) { + v := b.GetRuntimeConfig().GetVersion() + if v != "2.2" { + t.Errorf("unexpected version: got %v, want 2.2", v) + } + }, + }, + } + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + runCreateSparkBatchTest(t, client, ctx, tc.toolName, tc.request, tc.waitForSuccess, tc.validate) + }) + } + }) + + t.Run("auth", func(t *testing.T) { + t.Parallel() + // Batch creation succeeds even with an invalid main file, but will fail quickly once running. + runAuthTest(t, "create-pyspark-batch-with-auth", map[string]any{"mainFile": "file:///placeholder"}, http.StatusOK) + }) + + t.Run("errors", func(t *testing.T) { + t.Parallel() + tcs := []struct { + name string + request map[string]any + wantMsg string + }{ + { + name: "missing main file", + request: map[string]any{}, + wantMsg: "parameter \\\"mainFile\\\" is required", + }, + } + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + testError(t, "create-pyspark-batch", tc.request, http.StatusBadRequest, tc.wantMsg) + }) + } + }) + }) + + t.Run("create-spark-batch", func(t *testing.T) { + t.Parallel() + + t.Run("success", func(t *testing.T) { + t.Parallel() + tcs := []struct { + name string + toolName string + request map[string]any + waitForSuccess bool + validate func(t *testing.T, b *dataprocpb.Batch) + }{ + { + name: "main class", + toolName: "create-spark-batch", + waitForSuccess: true, + request: javaReq(map[string]any{}), + }, + { + // spark-examples.jar doesn't have a Main-Class, so pick an arbitrary other + // jar that does. Note there's a chance a subminor release of 2.2 will + // upgrade Spark and its dependencies, causing a failure. If that happens, + // find the new ivy jar filename and use that. The alternative would be to + // pin a subminor version, but that's guaranteed to be GC'ed after 1 year, + // whereas 2.2 is old enough it's unlikely to see a Spark version bump. + name: "main jar", + toolName: "create-spark-batch", + waitForSuccess: true, + request: map[string]any{ + "version": "2.2", + "mainJarFile": "file:///usr/lib/spark/jars/ivy-2.5.2.jar", + "args": []string{"-version"}, + }, + }, + // Tests below are just verifying options are set correctly on created batches, + // they don't need to wait for success. 
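+ // javaReq supplies the default SparkPi mainClass and examples jar; each case below overrides or extends those defaults.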
+ { + name: "with arg", + toolName: "create-spark-batch", + request: javaReq(map[string]any{"args": []string{"100"}}), + validate: func(t *testing.T, b *dataprocpb.Batch) { + if !cmp.Equal(b.GetSparkBatch().Args, []string{"100"}) { + t.Errorf("unexpected args: got %v, want %v", b.GetSparkBatch().Args, []string{"100"}) + } + }, + }, + { + name: "version", + toolName: "create-spark-batch", + request: javaReq(map[string]any{"version": "2.2"}), + validate: func(t *testing.T, b *dataprocpb.Batch) { + v := b.GetRuntimeConfig().GetVersion() + if v != "2.2" { + t.Errorf("unexpected version: got %v, want 2.2", v) + } + }, + }, + { + name: "version param overrides tool", + toolName: "create-spark-batch-2-3", + request: javaReq(map[string]any{"version": "2.2"}), + validate: func(t *testing.T, b *dataprocpb.Batch) { + v := b.GetRuntimeConfig().GetVersion() + if v != "2.2" { + t.Errorf("unexpected version: got %v, want 2.2", v) + } + }, + }, + } + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + runCreateSparkBatchTest(t, client, ctx, tc.toolName, tc.request, tc.waitForSuccess, tc.validate) + }) + } + }) + + t.Run("auth", func(t *testing.T) { + t.Parallel() + // Batch creation succeeds even with an invalid main file, but will fail quickly once running. + runAuthTest(t, "create-spark-batch-with-auth", map[string]any{"mainJarFile": "file:///placeholder"}, http.StatusOK) + }) + + t.Run("errors", func(t *testing.T) { + t.Parallel() + tcs := []struct { + name string + request map[string]any + wantMsg string + }{ + { + name: "no main jar or main class", + request: map[string]any{}, + wantMsg: "must provide either mainJarFile or mainClass", + }, + { + name: "both main jar and main class", + request: map[string]any{ + "mainJarFile": "my.jar", + "mainClass": "com.example.MyClass", + }, + wantMsg: "cannot provide both mainJarFile and mainClass", + }, + { + name: "main class without jar files", + request: map[string]any{ + "mainClass": "com.example.MyClass", + }, + wantMsg: "jarFiles is required when mainClass is provided", + }, + } + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + testError(t, "create-spark-batch", tc.request, http.StatusBadRequest, tc.wantMsg) + }) + } + }) + }) + t.Run("cancel-batch", func(t *testing.T) { t.Parallel() t.Run("success", func(t *testing.T) { @@ -299,13 +563,16 @@ func TestServerlessSparkToolEndpoints(t *testing.T) { } func waitForBatch(t *testing.T, client *dataproc.BatchControllerClient, parentCtx context.Context, batch string, desiredStates []dataprocpb.Batch_State, timeout time.Duration) { + t.Logf("waiting %s for batch %s to reach one of %v", timeout, batch, desiredStates) ctx, cancel := context.WithTimeout(parentCtx, timeout) defer cancel() + start := time.Now() + lastLog := start for { select { case <-ctx.Done(): - t.Fatalf("timed out waiting for batch %s to reach one of states %v", batch, desiredStates) + t.Fatalf("timed out waiting for batch %s to reach one of %v", batch, desiredStates) default: } @@ -315,12 +582,18 @@ func waitForBatch(t *testing.T, client *dataproc.BatchControllerClient, parentCt t.Fatalf("failed to get batch %s: %v", batch, err) } + now := time.Now() + if now.Sub(lastLog) >= 30*time.Second { + t.Logf("%s: batch %s is in state %s after %s", t.Name(), batch.Name, batch.State, now.Sub(start)) + lastLog = now + } + if slices.Contains(desiredStates, batch.State) { return } if batch.State == dataprocpb.Batch_FAILED || batch.State == dataprocpb.Batch_CANCELLED || batch.State == 
dataprocpb.Batch_SUCCEEDED { - t.Fatalf("batch op %s is in a terminal state %s, but wanted one of %v. State message: %s", batch, batch.State, desiredStates, batch.StateMessage) + t.Fatalf("batch op %s is in a terminal state %s, but wanted one of %v. State message: %s", batch.Name, batch.State, desiredStates, batch.StateMessage) } time.Sleep(2 * time.Second) } @@ -362,7 +635,7 @@ func createBatch(t *testing.T, client *dataproc.BatchControllerClient, ctx conte } // Wait for the batch to become at least PENDING; it typically takes >10s to go from PENDING to - // RUNNING, giving us plenty of time to cancel it before it completes. + // RUNNING, giving the cancel batch tests plenty of time to cancel it before it completes. waitForBatch(t, client, ctx, meta.Batch, []dataprocpb.Batch_State{dataprocpb.Batch_PENDING, dataprocpb.Batch_RUNNING}, 1*time.Minute) return meta.Batch } @@ -471,6 +744,17 @@ func runListBatchesTest(t *testing.T, client *dataproc.BatchControllerClient, ct if !reflect.DeepEqual(actual, tc.want) { t.Fatalf("unexpected batches: got %+v, want %+v", actual, tc.want) } + + // want has URLs because it's created from Batch instances by the same utility function + // used by the tool internals. Double-check that the URLs are reasonable. + for _, batch := range tc.want { + if !strings.HasPrefix(batch.ConsoleURL, batchURLPrefix) { + t.Errorf("unexpected consoleUrl in batch: %#v", batch) + } + if !strings.HasPrefix(batch.LogsURL, logsURLPrefix) { + t.Errorf("unexpected logsUrl in batch: %#v", batch) + } + } }) } } @@ -499,8 +783,12 @@ func listBatchesRpc(t *testing.T, client *dataproc.BatchControllerClient, ctx co if !exact && (len(batchPbs) == 0 || len(batchPbs) > n) { t.Fatalf("expected between 1 and %d batches, got %d", n, len(batchPbs)) } + batches, err := serverlesssparklistbatches.ToBatches(batchPbs) + if err != nil { + t.Fatalf("failed to convert batches to JSON: %v", err) + } - return serverlesssparklistbatches.ToBatches(batchPbs) + return batches } func runAuthTest(t *testing.T, toolName string, request map[string]any, wantStatus int) { @@ -600,11 +888,27 @@ func runGetBatchTest(t *testing.T, client *dataproc.BatchControllerClient, ctx c if !ok { t.Fatalf("unable to find result in response body") } + var wrappedResult map[string]any + if err := json.Unmarshal([]byte(result), &wrappedResult); err != nil { + t.Fatalf("error unmarshalling result: %s", err) + } + consoleURL, ok := wrappedResult["consoleUrl"].(string) + if !ok || !strings.HasPrefix(consoleURL, batchURLPrefix) { + t.Errorf("unexpected consoleUrl: %v", consoleURL) + } + logsURL, ok := wrappedResult["logsUrl"].(string) + if !ok || !strings.HasPrefix(logsURL, logsURLPrefix) { + t.Errorf("unexpected logsUrl: %v", logsURL) + } + batchJSON, err := json.Marshal(wrappedResult["batch"]) + if err != nil { + t.Fatalf("failed to marshal batch: %v", err) + } // Unmarshal JSON to proto for proto-aware deep comparison. 
var batch dataprocpb.Batch - if err := protojson.Unmarshal([]byte(result), &batch); err != nil { - t.Fatalf("error unmarshalling result: %s", err) + if err := protojson.Unmarshal(batchJSON, &batch); err != nil { + t.Fatalf("error unmarshalling batch from wrapped result: %s", err) } if !cmp.Equal(&batch, tc.want, protocmp.Transform()) { @@ -615,6 +919,83 @@ func runGetBatchTest(t *testing.T, client *dataproc.BatchControllerClient, ctx c } } +func javaReq(req map[string]any) map[string]any { + merged := map[string]any{ + "mainClass": "org.apache.spark.examples.SparkPi", + "jarFiles": []string{"file:///usr/lib/spark/examples/jars/spark-examples.jar"}, + } + maps.Copy(merged, req) + return merged +} + +func runCreateSparkBatchTest( + t *testing.T, + client *dataproc.BatchControllerClient, + ctx context.Context, + toolName string, + request map[string]any, + waitForSuccess bool, + validate func(t *testing.T, b *dataprocpb.Batch), +) { + resp, err := invokeTool(toolName, request, nil) + if err != nil { + t.Fatalf("invokeTool failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes)) + } + + var body map[string]any + if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { + t.Fatalf("error parsing response body: %v", err) + } + + result, ok := body["result"].(string) + if !ok { + t.Fatalf("unable to find result in response body") + } + + var resultMap map[string]any + if err := json.Unmarshal([]byte(result), &resultMap); err != nil { + t.Fatalf("failed to unmarshal result: %v", err) + } + consoleURL, ok := resultMap["consoleUrl"].(string) + if !ok || !strings.HasPrefix(consoleURL, batchURLPrefix) { + t.Errorf("unexpected consoleUrl: %v", consoleURL) + } + logsURL, ok := resultMap["logsUrl"].(string) + if !ok || !strings.HasPrefix(logsURL, logsURLPrefix) { + t.Errorf("unexpected logsUrl: %v", logsURL) + } + metaMap, ok := resultMap["opMetadata"].(map[string]any) + if !ok { + t.Fatalf("unexpected opMetadata: %v", metaMap) + } + metaJson, err := json.Marshal(metaMap) + if err != nil { + t.Fatalf("failed to marshal op metadata to JSON: %s", err) + } + var meta dataprocpb.BatchOperationMetadata + if err := json.Unmarshal([]byte(metaJson), &meta); err != nil { + t.Fatalf("failed to unmarshal result: %v", err) + } + + if validate != nil { + b, err := client.GetBatch(ctx, &dataprocpb.GetBatchRequest{Name: meta.Batch}) + if err != nil { + t.Fatalf("failed to get batch %s: %s", meta.Batch, err) + } + validate(t, b) + } + + if waitForSuccess { + waitForBatch(t, client, ctx, meta.Batch, []dataprocpb.Batch_State{dataprocpb.Batch_SUCCEEDED}, 5*time.Minute) + } +} + func testError(t *testing.T, toolName string, request map[string]any, wantCode int, wantMsg string) { resp, err := invokeTool(toolName, request, nil) if err != nil { diff --git a/tests/spanner/spanner_integration_test.go b/tests/spanner/spanner_integration_test.go index 324738f6cb..4daf87a27e 100644 --- a/tests/spanner/spanner_integration_test.go +++ b/tests/spanner/spanner_integration_test.go @@ -277,7 +277,7 @@ func setupSpannerTable(t *testing.T, ctx context.Context, adminClient *database. 
// tear down test op, err = adminClient.UpdateDatabaseDdl(ctx, &databasepb.UpdateDatabaseDdlRequest{ Database: dbString, - Statements: []string{fmt.Sprintf("DROP TABLE %s", tableName)}, + Statements: []string{fmt.Sprintf("DROP TABLE IF EXISTS %s", tableName)}, }) if err != nil { t.Errorf("unable to start drop %s operation: %s", tableName, err) @@ -310,7 +310,7 @@ func setupSpannerGraph(t *testing.T, ctx context.Context, adminClient *database. // tear down test op, err = adminClient.UpdateDatabaseDdl(ctx, &databasepb.UpdateDatabaseDdlRequest{ Database: dbString, - Statements: []string{fmt.Sprintf("DROP PROPERTY GRAPH %s", graphName)}, + Statements: []string{fmt.Sprintf("DROP PROPERTY GRAPH IF EXISTS %s", graphName)}, }) if err != nil { t.Errorf("unable to start drop %s operation: %s", graphName, err) diff --git a/tests/tool.go b/tests/tool.go index 3ccd4d4f0b..50335206b2 100644 --- a/tests/tool.go +++ b/tests/tool.go @@ -791,6 +791,7 @@ func RunMCPToolCallMethod(t *testing.T, myFailToolWant, select1Want string, opti // Default values for MCPTestConfig configs := &MCPTestConfig{ myToolId3NameAliceWant: `{"jsonrpc":"2.0","id":"my-tool","result":{"content":[{"type":"text","text":"{\"id\":1,\"name\":\"Alice\"}"},{"type":"text","text":"{\"id\":3,\"name\":\"Sid\"}"}]}}`, + mcpSelect1Want: select1Want, supportClientAuth: false, supportSelect1Auth: true, } @@ -920,7 +921,7 @@ func RunMCPToolCallMethod(t *testing.T, myFailToolWant, select1Want string, opti }, }, wantStatusCode: http.StatusOK, - wantBody: select1Want, + wantBody: configs.mcpSelect1Want, }, { name: "MCP Invoke my-auth-required-tool with invalid auth token", @@ -1188,7 +1189,7 @@ func RunPostgresListTablesTest(t *testing.T, tableNameParam, tableNameAuth, user api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", requestBody: bytes.NewBuffer([]byte(`{"table_names": "non_existent_table"}`)), wantStatusCode: http.StatusOK, - want: `null`, + want: `[]`, }, { name: "invoke list_tables with one existing and one non-existent table", @@ -1259,8 +1260,8 @@ func RunPostgresListTablesTest(t *testing.T, tableNameParam, tableNameAuth, user } } -func setUpPostgresViews(t *testing.T, ctx context.Context, pool *pgxpool.Pool, viewName, tableName string) func() { - createView := fmt.Sprintf("CREATE VIEW %s AS SELECT name FROM %s", viewName, tableName) +func setUpPostgresViews(t *testing.T, ctx context.Context, pool *pgxpool.Pool, viewName string) func() { + createView := fmt.Sprintf("CREATE VIEW %s AS SELECT 1 AS col", viewName) _, err := pool.Exec(ctx, createView) if err != nil { t.Fatalf("failed to create view: %v", err) @@ -1274,9 +1275,10 @@ func setUpPostgresViews(t *testing.T, ctx context.Context, pool *pgxpool.Pool, v } } -func RunPostgresListViewsTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool, tableName string) { - viewName1 := "test_view_1" + strings.ReplaceAll(uuid.New().String(), "-", "") - dropViewfunc1 := setUpPostgresViews(t, ctx, pool, viewName1, tableName) +func RunPostgresListViewsTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) { + // use a unique view name so repeated test runs do not collide + viewName := "test_view_" + strings.ReplaceAll(uuid.New().String(), "-", "") + dropViewfunc1 := setUpPostgresViews(t, ctx, pool, viewName) defer dropViewfunc1() invokeTcs := []struct { name string requestBody io.Reader wantStatusCode int want string }{ { name: "invoke list_views with newly created view", - requestBody:
bytes.NewBuffer([]byte(fmt.Sprintf(`{"view_name": "%s"}`, viewName))), wantStatusCode: http.StatusOK, - want: fmt.Sprintf(`[{"schemaname":"public","viewname":"%s","viewowner":"postgres"}]`, viewName1), + want: fmt.Sprintf(`[{"schema_name":"public","view_name":"%s","owner_name":"postgres","definition":" SELECT 1 AS col;"}]`, viewName), }, { name: "invoke list_views with non-existent_view", - requestBody: bytes.NewBuffer([]byte(`{"viewname": "non_existent_view"}`)), + requestBody: bytes.NewBuffer([]byte(`{"view_name": "non_existent_view"}`)), wantStatusCode: http.StatusOK, want: `null`, }, @@ -1349,6 +1351,7 @@ func RunPostgresListSchemasTest(t *testing.T, ctx context.Context, pool *pgxpool requestBody io.Reader wantStatusCode int want []map[string]any + compareSubset bool }{ { name: "invoke list_schemas with schema_name", requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"schema_name": "%s"}`, schemaName))), wantStatusCode: http.StatusOK, want: []map[string]any{wantSchema}, }, + { + name: "invoke list_schemas with owner name", + requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"owner": "%s"}`, "postgres"))), + wantStatusCode: http.StatusOK, + want: []map[string]any{wantSchema}, + compareSubset: true, + }, + { + name: "invoke list_schemas with limit 1", + requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"schema_name": "%s","limit": 1}`, schemaName))), + wantStatusCode: http.StatusOK, + want: []map[string]any{wantSchema}, + }, { name: "invoke list_schemas with non-existent schema", requestBody: bytes.NewBuffer([]byte(`{"schema_name": "non_existent_schema"}`)), @@ -1391,8 +1407,25 @@ func RunPostgresListSchemasTest(t *testing.T, ctx context.Context, pool *pgxpool t.Fatalf("failed to unmarshal nested result string: %v", err) } - if diff := cmp.Diff(tc.want, got); diff != "" { - t.Errorf("Unexpected result (-want +got):\n%s", diff) + if tc.compareSubset { + // Assert that 'wantSchema' is present in the 'got' list. + found := false + for _, resultSchema := range got { + if resultSchema["schema_name"] == wantSchema["schema_name"] { + found = true + if diff := cmp.Diff(wantSchema, resultSchema); diff != "" { + t.Errorf("Mismatch in fields for the expected schema (-want +got):\n%s", diff) + } + break + } + } + if !found { + t.Errorf("Expected schema '%s' not found in the list of all schemas.", wantSchema["schema_name"]) + } + } else { + if diff := cmp.Diff(tc.want, got); diff != "" { + t.Errorf("Unexpected result (-want +got):\n%s", diff) + } } }) } @@ -1562,12 +1595,14 @@ func RunPostgresListTriggersTest(t *testing.T, ctx context.Context, pool *pgxpoo requestBody io.Reader wantStatusCode int want []map[string]any + compareSubset bool }{ { name: "list all triggers (expecting the one we created)", requestBody: bytes.NewBuffer([]byte(`{}`)), wantStatusCode: http.StatusOK, want: []map[string]any{wantTrigger}, + compareSubset: true, // compare as a subset to avoid flakiness when other triggers exist concurrently }, { name: "filter by trigger_name", @@ -1634,6 +1669,171 @@ func RunPostgresListTriggersTest(t *testing.T, ctx context.Context, pool *pgxpoo t.Fatalf("failed to unmarshal nested result string: %v, content: %s", err, resultString) } + if tc.compareSubset { + // Assert that the 'wantTrigger' is present in the 'got' list.
+ found := false + for _, resultTrigger := range got { + if resultTrigger["trigger_name"] == wantTrigger["trigger_name"] { + found = true + if diff := cmp.Diff(wantTrigger, resultTrigger); diff != "" { + t.Errorf("Mismatch in fields for the expected trigger (-want +got):\n%s", diff) + } + break + } + } + if !found { + t.Errorf("Expected trigger '%s' not found in the list of all triggers.", triggerName) + } + } else { + if diff := cmp.Diff(tc.want, got); diff != "" { + t.Errorf("Unexpected result (-want +got):\n%s", diff) + } + } + }) + } +} + +func setupPostgresPublicationTable(t *testing.T, ctx context.Context, pool *pgxpool.Pool, tableName string, pubName string) func(t *testing.T) { + t.Helper() + createTableStmt := fmt.Sprintf("CREATE TABLE %s (id SERIAL PRIMARY KEY, name TEXT);", tableName) + if _, err := pool.Exec(ctx, createTableStmt); err != nil { + t.Fatalf("unable to create table %s: %v", tableName, err) + } + + createPubStmt := fmt.Sprintf("CREATE PUBLICATION %s FOR TABLE %s;", pubName, tableName) + if _, err := pool.Exec(ctx, createPubStmt); err != nil { + if _, dropErr := pool.Exec(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s;", tableName)); dropErr != nil { + t.Errorf("unable to drop table after failing to create publication: %v", dropErr) + } + t.Fatalf("unable to create publication %s: %v", pubName, err) + } + + return func(t *testing.T) { + t.Helper() + if _, err := pool.Exec(ctx, fmt.Sprintf("DROP PUBLICATION IF EXISTS %s;", pubName)); err != nil { + t.Errorf("unable to drop publication %s: %v", pubName, err) + } + if _, err := pool.Exec(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s;", tableName)); err != nil { + t.Errorf("unable to drop table %s: %v", tableName, err) + } + } +} + +func RunPostgresListPublicationTablesTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) { + table1Name := "pub_table_1" + pub1Name := "pub_1" + + table2Name := "pub_table_2" + pub2Name := "pub_2" + + cleanup := setupPostgresPublicationTable(t, ctx, pool, table1Name, pub1Name) + defer cleanup(t) + cleanup2 := setupPostgresPublicationTable(t, ctx, pool, table2Name, pub2Name) + defer cleanup2(t) + + // Fetch the current user to match the publication_owner + var currentUser string + err := pool.QueryRow(ctx, "SELECT current_user;").Scan(¤tUser) + if err != nil { + t.Fatalf("unable to fetch current user: %v", err) + } + + wantTable1 := map[string]any{ + "publication_name": pub1Name, + "schema_name": "public", + "table_name": table1Name, + "publishes_all_tables": false, + "publishes_inserts": true, + "publishes_updates": true, + "publishes_deletes": true, + "publishes_truncates": true, + "publication_owner": currentUser, + } + + wantTable2 := map[string]any{ + "publication_name": pub2Name, + "schema_name": "public", + "table_name": table2Name, + "publishes_all_tables": false, + "publishes_inserts": true, + "publishes_updates": true, + "publishes_deletes": true, + "publishes_truncates": true, + "publication_owner": currentUser, + } + + invokeTcs := []struct { + name string + requestBody io.Reader + wantStatusCode int + want []map[string]any + }{ + { + name: "list all publication tables", + requestBody: bytes.NewBufferString(`{}`), + wantStatusCode: http.StatusOK, + want: []map[string]any{wantTable1, wantTable2}, + }, + { + name: "list all tables for the created publication", + requestBody: bytes.NewBufferString(fmt.Sprintf(`{"publication_names": "%s"}`, pub1Name)), + wantStatusCode: http.StatusOK, + want: []map[string]any{wantTable1}, + }, + { + name: "filter by table_name", + requestBody: 
bytes.NewBufferString(fmt.Sprintf(`{"table_names": "%s, %s"}`, table1Name, table2Name)), + wantStatusCode: http.StatusOK, + want: []map[string]any{wantTable1, wantTable2}, + }, + { + name: "filter by schema_name and table_name", + requestBody: bytes.NewBufferString(fmt.Sprintf(`{"schema_names": "public", "table_name": "%s , %s"}`, table1Name, table2Name)), + wantStatusCode: http.StatusOK, + want: []map[string]any{wantTable1, wantTable2}, + }, + { + name: "invoke list_publication_tables with non-existent table", + requestBody: bytes.NewBufferString(`{"table_names": "non_existent_table"}`), + wantStatusCode: http.StatusOK, + want: nil, + }, + { + name: "invoke list_publication_tables with non-existent publication", + requestBody: bytes.NewBufferString(`{"publication_names": "non_existent_pub"}`), + wantStatusCode: http.StatusOK, + want: nil, + }, + } + for _, tc := range invokeTcs { + t.Run(tc.name, func(t *testing.T) { + const api = "http://127.0.0.1:5000/api/tool/list_publication_tables/invoke" + + resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil) + if resp.StatusCode != tc.wantStatusCode { + t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody)) + } + if tc.wantStatusCode != http.StatusOK { + return + } + + var bodyWrapper struct { + Result json.RawMessage `json:"result"` + } + if err := json.Unmarshal(respBody, &bodyWrapper); err != nil { + t.Fatalf("error decoding response wrapper: %v", err) + } + + var resultString string + if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil { + resultString = string(bodyWrapper.Result) + } + + var got []map[string]any + if err := json.Unmarshal([]byte(resultString), &got); err != nil { + t.Fatalf("failed to unmarshal nested result string: %v, content: %s", err, resultString) + } + if diff := cmp.Diff(tc.want, got); diff != "" { t.Errorf("Unexpected result (-want +got):\n%s", diff) } @@ -2011,15 +2211,15 @@ func RunPostgresListSequencesTest(t *testing.T, ctx context.Context, pool *pgxpo defer teardown(t) wantSequence := map[string]any{ - "sequencename": sequenceName, - "schemaname": "public", - "sequenceowner": "postgres", - "data_type": "bigint", - "start_value": float64(1), - "min_value": float64(1), - "max_value": float64(9223372036854775807), - "increment_by": float64(1), - "last_value": nil, + "sequence_name": sequenceName, + "schema_name": "public", + "sequence_owner": "postgres", + "data_type": "bigint", + "start_value": float64(1), + "min_value": float64(1), + "max_value": float64(9223372036854775807), + "increment_by": float64(1), + "last_value": nil, } invokeTcs := []struct { @@ -2031,13 +2231,13 @@ func RunPostgresListSequencesTest(t *testing.T, ctx context.Context, pool *pgxpo }{ { name: "invoke list_sequences", - requestBody: bytes.NewBufferString(fmt.Sprintf(`{"sequencename": "%s"}`, sequenceName)), + requestBody: bytes.NewBufferString(fmt.Sprintf(`{"sequence_name": "%s"}`, sequenceName)), wantStatusCode: http.StatusOK, want: []map[string]any{wantSequence}, }, { name: "invoke list_sequences with non-existent sequence", - requestBody: bytes.NewBufferString(`{"sequencename": "non_existent_sequence"}`), + requestBody: bytes.NewBufferString(`{"sequence_name": "non_existent_sequence"}`), wantStatusCode: http.StatusOK, want: nil, }, @@ -2077,8 +2277,450 @@ func RunPostgresListSequencesTest(t *testing.T, ctx context.Context, pool *pgxpo } } +func RunPostgresListTableSpacesTest(t *testing.T) { + invokeTcs := []struct { + name string + api string + 
requestBody io.Reader + wantStatusCode int + }{ + { + name: "invoke list_tablespaces output", + api: "http://127.0.0.1:5000/api/tool/list_tablespaces/invoke", + wantStatusCode: http.StatusOK, + requestBody: bytes.NewBuffer([]byte(`{}`)), + }, + } + for _, tc := range invokeTcs { + t.Run(tc.name, func(t *testing.T) { + resp, respBody := RunRequest(t, http.MethodPost, tc.api, tc.requestBody, nil) + if resp.StatusCode != tc.wantStatusCode { + t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody)) + } + + // Intentionally not adding an output check, as the output depends on the postgres instance where the functional test runs. + // Adding the check would make the test flaky. + }) + } +} + +func RunPostgresListPgSettingsTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) { + targetSetting := "maintenance_work_mem" + var name, setting, unit, shortDesc, source, contextVal string + + // We query the raw pg_settings view to get the data needed to reconstruct the logic + // defined in listPgSettingQuery. + err := pool.QueryRow(ctx, ` + SELECT name, setting, unit, short_desc, source, context + FROM pg_settings + WHERE name = $1 + `, targetSetting).Scan(&name, &setting, &unit, &shortDesc, &source, &contextVal) + + if err != nil { + t.Fatalf("Setup failed: could not fetch postgres setting '%s': %v", targetSetting, err) + } + + // Replicate the SQL CASE logic for the 'requires_restart' field + requiresRestart := "No" + switch contextVal { + case "postmaster": + requiresRestart = "Yes" + case "sighup": + requiresRestart = "No (Reload sufficient)" + } + + expectedObject := map[string]interface{}{ + "name": name, + "current_value": setting, + "unit": unit, + "short_desc": shortDesc, + "source": source, + "requires_restart": requiresRestart, + } + expectedJSON, _ := json.Marshal([]interface{}{expectedObject}) + + invokeTcs := []struct { + name string + requestBody io.Reader + wantStatusCode int + want string + }{ + { + name: "invoke list_pg_settings with specific setting", + requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"setting_name": "%s"}`, targetSetting))), + wantStatusCode: http.StatusOK, + want: string(expectedJSON), + }, + { + name: "invoke list_pg_settings with non-existent setting", + requestBody: bytes.NewBuffer([]byte(`{"setting_name": "non_existent_config_xyz"}`)), + wantStatusCode: http.StatusOK, + want: `null`, + }, + } + + for _, tc := range invokeTcs { + t.Run(tc.name, func(t *testing.T) { + const api = "http://127.0.0.1:5000/api/tool/list_pg_settings/invoke" + resp, body := RunRequest(t, http.MethodPost, api, tc.requestBody, nil) + + if resp.StatusCode != tc.wantStatusCode { + t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(body)) + } + if tc.wantStatusCode != http.StatusOK { + return + } + + var bodyWrapper struct { + Result json.RawMessage `json:"result"` + } + if err := json.Unmarshal(body, &bodyWrapper); err != nil { + t.Fatalf("error decoding response wrapper: %v", err) + } + + var resultString string + if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil { + resultString = string(bodyWrapper.Result) + } + + var got, want any + if err := json.Unmarshal([]byte(resultString), &got); err != nil { + t.Fatalf("failed to unmarshal nested result string: %v", err) + } + if err := json.Unmarshal([]byte(tc.want), &want); err != nil { + t.Fatalf("failed to unmarshal want string: %v", err) + } + + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("Unexpected result (-want
+got):\n%s", diff) + } + }) + } +} + +// RunPostgresListDatabaseStatsTest tests the list_database_stats tool by comparing API results +// against databases created directly in the test. +func RunPostgresListDatabaseStatsTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) { + dbName1 := "test_db_stats_" + strings.ReplaceAll(uuid.NewString(), "-", "") + dbOwner1 := "test_user_" + strings.ReplaceAll(uuid.NewString(), "-", "") + dbName2 := "test_db_stats_" + strings.ReplaceAll(uuid.NewString(), "-", "") + dbOwner2 := "test_user_" + strings.ReplaceAll(uuid.NewString(), "-", "") + + cleanup1 := setUpDatabase(t, ctx, pool, dbName1, dbOwner1) + defer cleanup1() + cleanup2 := setUpDatabase(t, ctx, pool, dbName2, dbOwner2) + defer cleanup2() + + requiredKeys := map[string]bool{ + "database_name": true, + "database_owner": true, + "default_tablespace": true, + "is_connectable": true, + } + + db1Want := map[string]interface{}{ + "database_name": dbName1, + "database_owner": dbOwner1, + "default_tablespace": "pg_default", + "is_connectable": true, + } + + db2Want := map[string]interface{}{ + "database_name": dbName2, + "database_owner": dbOwner2, + "default_tablespace": "pg_default", + "is_connectable": true, + } + + invokeTcs := []struct { + name string + requestBody io.Reader + wantStatusCode int + want []map[string]interface{} + }{ + { + name: "invoke database_stats filtering by specific database name", + requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"database_name": "%s"}`, dbName1))), + wantStatusCode: http.StatusOK, + want: []map[string]interface{}{db1Want}, + }, + { + name: "invoke database_stats filtering by specific owner", + requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"database_owner": "%s"}`, dbOwner2))), + wantStatusCode: http.StatusOK, + want: []map[string]interface{}{db2Want}, + }, + { + name: "filter by tablespace", + requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"default_tablespace": "pg_default", "database_name": "%s"}`, dbName1))), + wantStatusCode: http.StatusOK, + want: []map[string]interface{}{db1Want}, + }, + { + name: "sort by size", + requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"sort_by": "size", "database_name": "%s"}`, dbName2))), + wantStatusCode: http.StatusOK, + want: []map[string]interface{}{db2Want}, + }, + } + + for _, tc := range invokeTcs { + t.Run(tc.name, func(t *testing.T) { + const api = "http://127.0.0.1:5000/api/tool/list_database_stats/invoke" + resp, body := RunRequest(t, http.MethodPost, api, tc.requestBody, nil) + + if resp.StatusCode != tc.wantStatusCode { + t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(body)) + } + var bodyWrapper struct { + Result json.RawMessage `json:"result"` + } + if err := json.Unmarshal(body, &bodyWrapper); err != nil { + t.Fatalf("error decoding response wrapper: %v", err) + } + + var resultString string + if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil { + resultString = string(bodyWrapper.Result) + } + + var got []map[string]interface{} + if err := json.Unmarshal([]byte(resultString), &got); err != nil { + t.Fatalf("failed to unmarshal nested result string: %v", err) + } + + // Configuration for comparison + opts := []cmp.Option{ + // Ensure consistent order based on name for comparison + cmpopts.SortSlices(func(a, b map[string]interface{}) bool { + return a["database_name"].(string) < b["database_name"].(string) + }), + + // Ignore volatile keys that change between runs and only compare the keys in 'requiredKeys' +
cmpopts.IgnoreMapEntries(func(key string, _ interface{}) bool { + return !requiredKeys[key] + }), + + // Ignore Irrelevant Databases + cmpopts.IgnoreSliceElements(func(v map[string]interface{}) bool { + name, ok := v["database_name"].(string) + if !ok { + return true + } + return name != dbName1 && name != dbName2 + }), + } + + if diff := cmp.Diff(tc.want, got, opts...); diff != "" { + t.Errorf("Unexpected result (-want +got):\n%s", diff) + } + }) + } +} + +func setUpDatabase(t *testing.T, ctx context.Context, pool *pgxpool.Pool, dbName, dbOwner string) func() { + _, err := pool.Exec(ctx, fmt.Sprintf("CREATE ROLE %s LOGIN PASSWORD 'password';", dbOwner)) + if err != nil { + _, _ = pool.Exec(ctx, fmt.Sprintf("DROP ROLE %s;", dbOwner)) + t.Fatalf("failed to create %s: %v", dbOwner, err) + } + _, err = pool.Exec(ctx, fmt.Sprintf("GRANT %s TO current_user;", dbOwner)) + if err != nil { + t.Fatalf("failed to grant %s to current_user: %v", dbOwner, err) + } + _, err = pool.Exec(ctx, fmt.Sprintf("CREATE DATABASE %s OWNER %s;", dbName, dbOwner)) + if err != nil { + t.Fatalf("failed to create %s: %v", dbName, err) + } + return func() { + _, _ = pool.Exec(ctx, fmt.Sprintf("DROP DATABASE IF EXISTS %s;", dbName)) + _, _ = pool.Exec(ctx, fmt.Sprintf("DROP ROLE IF EXISTS %s;", dbOwner)) + } +} + +func setupPostgresRoles(t *testing.T, ctx context.Context, pool *pgxpool.Pool) (string, string, string, func(t *testing.T)) { + t.Helper() + suffix := strings.ReplaceAll(uuid.New().String(), "-", "") + + adminUser := "test_role_admin_" + suffix + superUser := "test_role_super_" + suffix + normalUser := "test_role_normal_" + suffix + + createAdminStmt := fmt.Sprintf("CREATE ROLE %s NOLOGIN;", adminUser) + if _, err := pool.Exec(ctx, createAdminStmt); err != nil { + t.Fatalf("unable to create role %s: %v", adminUser, err) + } + + createSuperUserStmt := fmt.Sprintf("CREATE ROLE %s LOGIN CREATEDB;", superUser) + if _, err := pool.Exec(ctx, createSuperUserStmt); err != nil { + t.Fatalf("unable to create role %s: %v", superUser, err) + } + + createNormalUserStmt := fmt.Sprintf("CREATE ROLE %s LOGIN;", normalUser) + if _, err := pool.Exec(ctx, createNormalUserStmt); err != nil { + t.Fatalf("unable to create role %s: %v", normalUser, err) + } + + // Establish Relationships (Admin -> Superuser -> Normal) + if _, err := pool.Exec(ctx, fmt.Sprintf("GRANT %s TO %s;", adminUser, superUser)); err != nil { + t.Fatalf("unable to grant %s to %s: %v", adminUser, superUser, err) + } + if _, err := pool.Exec(ctx, fmt.Sprintf("GRANT %s TO %s;", superUser, normalUser)); err != nil { + t.Fatalf("unable to grant %s to %s: %v", superUser, normalUser, err) + } + + return adminUser, superUser, normalUser, func(t *testing.T) { + t.Helper() + _, _ = pool.Exec(ctx, fmt.Sprintf("DROP ROLE IF EXISTS %s;", normalUser)) + _, _ = pool.Exec(ctx, fmt.Sprintf("DROP ROLE IF EXISTS %s;", superUser)) + _, _ = pool.Exec(ctx, fmt.Sprintf("DROP ROLE IF EXISTS %s;", adminUser)) + } +} + +func RunPostgresListRolesTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) { + adminUser, superUser, normalUser, cleanup := setupPostgresRoles(t, ctx, pool) + defer cleanup(t) + + wantAdmin := map[string]any{ + "role_name": adminUser, + "connection_limit": float64(-1), + "is_superuser": false, + "inherits_privileges": true, + "can_create_roles": false, + "can_create_db": false, + "can_login": false, + "is_replication_role": false, + "bypass_rls": false, + "direct_members": []any{superUser}, + "member_of": []any{}, + } + + wantSuperUser := map[string]any{ + 
"role_name": superUser, + "connection_limit": float64(-1), + "is_superuser": false, + "inherits_privileges": true, + "can_create_roles": false, + "can_create_db": true, + "can_login": true, + "is_replication_role": false, + "bypass_rls": false, + "direct_members": []any{normalUser}, + "member_of": []any{adminUser}, + } + + wantNormalUser := map[string]any{ + "role_name": normalUser, + "connection_limit": float64(-1), + "is_superuser": false, + "inherits_privileges": true, + "can_create_roles": false, + "can_create_db": false, + "can_login": true, + "is_replication_role": false, + "bypass_rls": false, + "direct_members": []any{}, + "member_of": []any{superUser}, + } + + invokeTcs := []struct { + name string + requestBody io.Reader + wantStatusCode int + want []map[string]any + }{ + { + name: "list_roles with filter for created roles", + requestBody: bytes.NewBufferString(`{"role_name": "test_role_"}`), + wantStatusCode: http.StatusOK, + want: []map[string]any{wantAdmin, wantNormalUser, wantSuperUser}, + }, + { + name: "list_roles filter specific role", + requestBody: bytes.NewBufferString(fmt.Sprintf(`{"role_name": "%s"}`, superUser)), + wantStatusCode: http.StatusOK, + want: []map[string]any{wantSuperUser}, + }, + { + name: "list_roles non-existent role", + requestBody: bytes.NewBufferString(`{"role_name": "non_existent_role_xyz"}`), + wantStatusCode: http.StatusOK, + want: nil, + }, + } + + for _, tc := range invokeTcs { + t.Run(tc.name, func(t *testing.T) { + const api = "http://127.0.0.1:5000/api/tool/list_roles/invoke" + + resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil) + if resp.StatusCode != tc.wantStatusCode { + t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody)) + } + if tc.wantStatusCode != http.StatusOK { + return + } + + var bodyWrapper struct { + Result json.RawMessage `json:"result"` + } + if err := json.Unmarshal(respBody, &bodyWrapper); err != nil { + t.Fatalf("error decoding response wrapper: %v", err) + } + + var resultString string + if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil { + resultString = string(bodyWrapper.Result) + } + + var got []map[string]any + if err := json.Unmarshal([]byte(resultString), &got); err != nil { + t.Fatalf("failed to unmarshal nested result string: %v, resultString: %s", err, resultString) + } + + gotMap := make(map[string]map[string]any) + for _, role := range got { + // Remove fields that change every run + delete(role, "oid") + delete(role, "valid_until") + + if name, ok := role["role_name"].(string); ok { + gotMap[name] = role + } + } + + // Check that every role in 'want' exists in 'got' and matches + for _, wantRole := range tc.want { + roleName, _ := wantRole["role_name"].(string) + + gotRole, exists := gotMap[roleName] + if !exists { + t.Errorf("Expected role %q was not found in the response", roleName) + continue + } + + if diff := cmp.Diff(wantRole, gotRole); diff != "" { + t.Errorf("Role %q mismatch (-want +got):\n%s", roleName, diff) + } + } + + // Verify that if want is nil/empty, got is also empty + if len(tc.want) == 0 && len(got) != 0 { + t.Errorf("Expected empty result, but got %d roles", len(got)) + } + }) + } +} + // RunMySQLListTablesTest run tests against the mysql-list-tables tool -func RunMySQLListTablesTest(t *testing.T, databaseName, tableNameParam, tableNameAuth string) { +func RunMySQLListTablesTest(t *testing.T, databaseName, tableNameParam, tableNameAuth, expectedOwner string) { + var ownerWant any + if 
expectedOwner == "" { + ownerWant = nil + } else { + ownerWant = expectedOwner + } + type tableInfo struct { ObjectName string `json:"object_name"` SchemaName string `json:"schema_name"` @@ -2110,6 +2752,7 @@ func RunMySQLListTablesTest(t *testing.T, databaseName, tableNameParam, tableNam ObjectName: tableNameParam, SchemaName: databaseName, ObjectType: "TABLE", + Owner: ownerWant, Columns: []column{ {DataType: "int", ColumnName: "id", IsNotNullable: 1, OrdinalPosition: 1}, {DataType: "varchar(255)", ColumnName: "name", OrdinalPosition: 2}, @@ -2123,6 +2766,7 @@ func RunMySQLListTablesTest(t *testing.T, databaseName, tableNameParam, tableNam ObjectName: tableNameAuth, SchemaName: databaseName, ObjectType: "TABLE", + Owner: ownerWant, Columns: []column{ {DataType: "int", ColumnName: "id", IsNotNullable: 1, OrdinalPosition: 1}, {DataType: "varchar(255)", ColumnName: "name", OrdinalPosition: 2}, @@ -2177,7 +2821,7 @@ func RunMySQLListTablesTest(t *testing.T, databaseName, tableNameParam, tableNam name: "invoke list_tables with non-existent table", requestBody: bytes.NewBufferString(`{"table_names": "non_existent_table"}`), wantStatusCode: http.StatusOK, - want: nil, + want: []objectDetails{}, }, } for _, tc := range invokeTcs { @@ -2209,7 +2853,7 @@ func RunMySQLListTablesTest(t *testing.T, databaseName, tableNameParam, tableNam if err := json.Unmarshal([]byte(resultString), &tables); err != nil { t.Fatalf("failed to unmarshal outer JSON array into []tableInfo: %v", err) } - var details []map[string]any + details := []map[string]any{} for _, table := range tables { var d map[string]any if err := json.Unmarshal([]byte(table.ObjectDetails), &d); err != nil { @@ -2219,23 +2863,19 @@ func RunMySQLListTablesTest(t *testing.T, databaseName, tableNameParam, tableNam } got = details } else { - if resultString == "null" { - got = nil - } else { - var tables []tableInfo - if err := json.Unmarshal([]byte(resultString), &tables); err != nil { - t.Fatalf("failed to unmarshal outer JSON array into []tableInfo: %v", err) - } - var details []objectDetails - for _, table := range tables { - var d objectDetails - if err := json.Unmarshal([]byte(table.ObjectDetails), &d); err != nil { - t.Fatalf("failed to unmarshal nested ObjectDetails string: %v", err) - } - details = append(details, d) - } - got = details + var tables []tableInfo + if err := json.Unmarshal([]byte(resultString), &tables); err != nil { + t.Fatalf("failed to unmarshal outer JSON array into []tableInfo: %v", err) } + details := []objectDetails{} + for _, table := range tables { + var d objectDetails + if err := json.Unmarshal([]byte(table.ObjectDetails), &d); err != nil { + t.Fatalf("failed to unmarshal nested ObjectDetails string: %v", err) + } + details = append(details, d) + } + got = details } opts := []cmp.Option{ @@ -2246,7 +2886,7 @@ func RunMySQLListTablesTest(t *testing.T, databaseName, tableNameParam, tableNam // Checking only the current database where the test tables are created to avoid brittle tests. 
if tc.isAllTables { - var filteredGot []objectDetails + filteredGot := []objectDetails{} if got != nil { for _, item := range got.([]objectDetails) { if item.SchemaName == databaseName { @@ -2254,11 +2894,7 @@ func RunMySQLListTablesTest(t *testing.T, databaseName, tableNameParam, tableNam } } } - if len(filteredGot) == 0 { - got = nil - } else { - got = filteredGot - } + got = filteredGot } if diff := cmp.Diff(tc.want, got, opts...); diff != "" { @@ -2740,6 +3376,81 @@ func RunMySQLListTableFragmentationTest(t *testing.T, databaseName, tableNamePar } } +func RunMySQLGetQueryPlanTest(t *testing.T, ctx context.Context, pool *sql.DB, databaseName, tableNameParam string) { + // Create a simple query to explain + query := fmt.Sprintf("SELECT * FROM %s", tableNameParam) + + invokeTcs := []struct { + name string + requestBody io.Reader + wantStatusCode int + checkResult func(t *testing.T, result any) + }{ + { + name: "invoke get_query_plan with valid query", + requestBody: bytes.NewBufferString(fmt.Sprintf(`{"sql_statement": "%s"}`, query)), + wantStatusCode: http.StatusOK, + checkResult: func(t *testing.T, result any) { + resultMap, ok := result.(map[string]any) + if !ok { + t.Fatalf("result should be a map, got %T", result) + } + if _, ok := resultMap["query_block"]; !ok { + t.Errorf("result should contain 'query_block', got %v", resultMap) + } + }, + }, + { + name: "invoke get_query_plan with invalid query", + requestBody: bytes.NewBufferString(`{"sql_statement": "SELECT * FROM non_existent_table"}`), + wantStatusCode: http.StatusBadRequest, + checkResult: nil, + }, + } + + for _, tc := range invokeTcs { + t.Run(tc.name, func(t *testing.T) { + const api = "http://127.0.0.1:5000/api/tool/get_query_plan/invoke" + resp, respBytes := RunRequest(t, http.MethodPost, api, tc.requestBody, nil) + if resp.StatusCode != tc.wantStatusCode { + t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBytes)) + } + if tc.wantStatusCode != http.StatusOK { + return + } + + var bodyWrapper map[string]json.RawMessage + + if err := json.Unmarshal(respBytes, &bodyWrapper); err != nil { + t.Fatalf("error parsing response wrapper: %s, body: %s", err, string(respBytes)) + } + + resultJSON, ok := bodyWrapper["result"] + if !ok { + t.Fatal("unable to find 'result' in response body") + } + + var resultString string + if err := json.Unmarshal(resultJSON, &resultString); err != nil { + if string(resultJSON) == "null" { + resultString = "null" + } else { + t.Fatalf("'result' is not a JSON-encoded string: %s", err) + } + } + + var got map[string]any + if err := json.Unmarshal([]byte(resultString), &got); err != nil { + t.Fatalf("failed to unmarshal actual result string: %v", err) + } + + if tc.checkResult != nil { + tc.checkResult(t, got) + } + }) + } +} + // RunMSSQLListTablesTest run tests againsts the mssql-list-tables tools. func RunMSSQLListTablesTest(t *testing.T, tableNameParam, tableNameAuth string) { // TableNameParam columns to construct want. 
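The tests above all unwrap tool responses the same way: the invoke endpoint returns a JSON object whose "result" field is usually a JSON-encoded string that itself contains the tool's output, which then has to be decoded a second time. The following is a minimal sketch of that pattern, assuming the test package's existing encoding/json and testing imports; the helper name decodeToolResult is hypothetical and not part of this change.

func decodeToolResult(t *testing.T, respBody []byte, out any) {
	t.Helper()
	// The invoke endpoint wraps its payload as {"result": ...}.
	var wrapper struct {
		Result json.RawMessage `json:"result"`
	}
	if err := json.Unmarshal(respBody, &wrapper); err != nil {
		t.Fatalf("error decoding response wrapper: %v", err)
	}
	// result is usually a string holding JSON; fall back to the raw bytes otherwise.
	var nested string
	if err := json.Unmarshal(wrapper.Result, &nested); err != nil {
		nested = string(wrapper.Result)
	}
	if err := json.Unmarshal([]byte(nested), out); err != nil {
		t.Fatalf("failed to unmarshal nested result: %v", err)
	}
}

A call such as decodeToolResult(t, respBody, &got) would stand in for the wrapper/result/nested unmarshalling steps repeated in each of these tests.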
@@ -2846,7 +3557,7 @@ func RunMSSQLListTablesTest(t *testing.T, tableNameParam, tableNameAuth string) api: "http://127.0.0.1:5000/api/tool/list_tables/invoke", requestBody: `{"table_names": "non_existent_table"}`, wantStatusCode: http.StatusOK, - want: `null`, + want: `[]`, }, { name: "invoke list_tables with one existing and one non-existent table", @@ -3435,6 +4146,250 @@ func RunPostgresListQueryStatsTest(t *testing.T, ctx context.Context, pool *pgxp } } +// RunPostgresListTableStatsTest runs tests for the postgres list-table-stats tool +func RunPostgresListTableStatsTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) { + type tableStatsDetails struct { + SchemaName string `json:"schema_name"` + TableName string `json:"table_name"` + Owner string `json:"owner"` + TotalSizeBytes any `json:"total_size_bytes"` + SeqScan any `json:"seq_scan"` + IdxScan any `json:"idx_scan"` + IdxScanRatioPercent float64 `json:"idx_scan_ratio_percent"` + LiveRows any `json:"live_rows"` + DeadRows any `json:"dead_rows"` + DeadRowRatioPercent float64 `json:"dead_row_ratio_percent"` + NTupIns any `json:"n_tup_ins"` + NTupUpd any `json:"n_tup_upd"` + NTupDel any `json:"n_tup_del"` + LastVacuum any `json:"last_vacuum"` + LastAutovacuum any `json:"last_autovacuum"` + LastAutoanalyze any `json:"last_autoanalyze"` + } + + // Create a test table to generate statistics + testTableName := "test_list_table_stats_" + strings.ReplaceAll(uuid.New().String(), "-", "") + createTableStmt := fmt.Sprintf(` + CREATE TABLE %s ( + id SERIAL PRIMARY KEY, + name VARCHAR(100), + email VARCHAR(100) + ) + `, testTableName) + + if _, err := pool.Exec(ctx, createTableStmt); err != nil { + t.Fatalf("unable to create test table: %s", err) + } + defer func() { + dropTableStmt := fmt.Sprintf("DROP TABLE IF EXISTS %s", testTableName) + if _, err := pool.Exec(ctx, dropTableStmt); err != nil { + t.Logf("warning: unable to drop test table: %v", err) + } + }() + + // Insert some data to generate statistics + insertStmt := fmt.Sprintf(` + INSERT INTO %s (name, email) VALUES + ('Alice', 'alice@example.com'), + ('Bob', 'bob@example.com'), + ('Charlie', 'charlie@example.com'), + ('David', 'david@example.com'), + ('Eve', 'eve@example.com') + `, testTableName) + + if _, err := pool.Exec(ctx, insertStmt); err != nil { + t.Fatalf("unable to insert test data: %s", err) + } + + // Run some sequential scans to generate statistics + for i := 0; i < 3; i++ { + selectStmt := fmt.Sprintf("SELECT * FROM %s WHERE name = 'Alice'", testTableName) + if _, err := pool.Exec(ctx, selectStmt); err != nil { + t.Logf("warning: unable to execute select: %v", err) + } + } + + // Run ANALYZE to update statistics + analyzeStmt := fmt.Sprintf("ANALYZE %s", testTableName) + if _, err := pool.Exec(ctx, analyzeStmt); err != nil { + t.Logf("warning: unable to run ANALYZE: %v", err) + } + + invokeTcs := []struct { + name string + requestBody io.Reader + wantStatusCode int + shouldHaveData bool + filterTable bool + }{ + { + name: "list table stats with no arguments (default limit)", + requestBody: bytes.NewBufferString(`{}`), + wantStatusCode: http.StatusOK, + shouldHaveData: false, // may or may not have data depending on what's in the database + }, + { + name: "list table stats with default limit", + requestBody: bytes.NewBufferString(`{"schema_name": "public"}`), + wantStatusCode: http.StatusOK, + shouldHaveData: false, + }, + { + name: "list table stats filtering by specific table", + requestBody: bytes.NewBufferString(fmt.Sprintf(`{"table_name": "%s"}`, testTableName)), 
+			wantStatusCode: http.StatusOK,
+			shouldHaveData: true,
+			filterTable:    true,
+		},
+		{
+			name:           "list table stats with custom limit",
+			requestBody:    bytes.NewBufferString(`{"limit": 10}`),
+			wantStatusCode: http.StatusOK,
+			shouldHaveData: false,
+		},
+		{
+			name:           "list table stats sorted by size",
+			requestBody:    bytes.NewBufferString(`{"sort_by": "size", "limit": 5}`),
+			wantStatusCode: http.StatusOK,
+			shouldHaveData: false,
+		},
+		{
+			name:           "list table stats sorted by seq_scan",
+			requestBody:    bytes.NewBufferString(`{"sort_by": "seq_scan", "limit": 5}`),
+			wantStatusCode: http.StatusOK,
+			shouldHaveData: false,
+		},
+		{
+			name:           "list table stats sorted by idx_scan",
+			requestBody:    bytes.NewBufferString(`{"sort_by": "idx_scan", "limit": 5}`),
+			wantStatusCode: http.StatusOK,
+			shouldHaveData: false,
+		},
+		{
+			name:           "list table stats sorted by dead_rows",
+			requestBody:    bytes.NewBufferString(`{"sort_by": "dead_rows", "limit": 5}`),
+			wantStatusCode: http.StatusOK,
+			shouldHaveData: false,
+		},
+		{
+			name:           "list table stats with non-existent table filter",
+			requestBody:    bytes.NewBufferString(`{"table_name": "non_existent_table_xyz"}`),
+			wantStatusCode: http.StatusOK,
+			shouldHaveData: false,
+		},
+		{
+			name:           "list table stats with non-existent schema filter",
+			requestBody:    bytes.NewBufferString(`{"schema_name": "non_existent_schema_xyz"}`),
+			wantStatusCode: http.StatusOK,
+			shouldHaveData: false,
+		},
+		{
+			name:           "list table stats with owner filter",
+			requestBody:    bytes.NewBufferString(`{"owner": "postgres"}`),
+			wantStatusCode: http.StatusOK,
+			shouldHaveData: false,
+		},
+	}
+
+	for _, tc := range invokeTcs {
+		t.Run(tc.name, func(t *testing.T) {
+			const api = "http://127.0.0.1:5000/api/tool/list_table_stats/invoke"
+			resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil)
+			if resp.StatusCode != tc.wantStatusCode {
+				t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody))
+			}
+			if tc.wantStatusCode != http.StatusOK {
+				return
+			}
+
+			var bodyWrapper struct {
+				Result json.RawMessage `json:"result"`
+			}
+			if err := json.Unmarshal(respBody, &bodyWrapper); err != nil {
+				t.Fatalf("error decoding response wrapper: %v", err)
+			}
+
+			var resultString string
+			if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil {
+				resultString = string(bodyWrapper.Result)
+			}
+
+			var got []tableStatsDetails
+			if resultString != "null" {
+				if err := json.Unmarshal([]byte(resultString), &got); err != nil {
+					t.Fatalf("failed to unmarshal result: %v, result string: %s", err, resultString)
+				}
+			}
+
+			// Verify expected data presence
+			if tc.shouldHaveData {
+				if len(got) == 0 {
+					t.Fatalf("expected data but got empty result")
+				}
+
+				// Verify the test table is in results
+				found := false
+				for _, row := range got {
+					if row.TableName == testTableName {
+						found = true
+						// Verify expected fields are present
+						if row.SchemaName == "" {
+							t.Errorf("schema_name should not be empty")
+						}
+						if row.Owner == "" {
+							t.Errorf("owner should not be empty")
+						}
+						if row.TotalSizeBytes == nil {
+							t.Errorf("total_size_bytes should not be null")
+						}
+						if row.LiveRows == nil {
+							t.Errorf("live_rows should not be null")
+						}
+						break
+					}
+				}
+
+				if !found {
+					t.Errorf("test table %s not found in results", testTableName)
+				}
+			} else if tc.filterTable {
+				// For filtered queries that shouldn't find anything
+				if len(got) != 0 {
+					t.Logf("warning: expected no data but got: %v", len(got))
+				}
+			}
+
+			// Verify result structure and data types
+			for _, stat := range got {
+				// Verify schema_name is populated for every row that reports a table_name
+				if stat.SchemaName == "" && stat.TableName != "" {
+					t.Errorf("schema_name is empty for table %s", stat.TableName)
+				}
+
+				// Verify the ratio fields fall within the valid 0-100 percent range
+				if stat.IdxScanRatioPercent < 0 || stat.IdxScanRatioPercent > 100 {
+					t.Errorf("idx_scan_ratio_percent should be between 0 and 100, got %f", stat.IdxScanRatioPercent)
+				}
+
+				if stat.DeadRowRatioPercent < 0 || stat.DeadRowRatioPercent > 100 {
+					t.Errorf("dead_row_ratio_percent should be between 0 and 100, got %f", stat.DeadRowRatioPercent)
+				}
+			}
+
+			// Verify sorting for specific sort_by options
+			if tc.name == "list table stats sorted by size" && len(got) > 1 {
+				for i := 0; i < len(got)-1; i++ {
+					current, ok1 := got[i].TotalSizeBytes.(float64)
+					next, ok2 := got[i+1].TotalSizeBytes.(float64)
+					if ok1 && ok2 && current < next {
+						t.Logf("warning: results may not be sorted by total_size_bytes descending")
+					}
+				}
+			}
+		})
+	}
+}
+
 // RunRequest is a helper function to send HTTP requests and return the response
 func RunRequest(t *testing.T, method, url string, body io.Reader, headers map[string]string) (*http.Response, []byte) {
 	// Send request