diff --git a/.ci/integration.cloudbuild.yaml b/.ci/integration.cloudbuild.yaml index 14742514bc..4b78f6b3f0 100644 --- a/.ci/integration.cloudbuild.yaml +++ b/.ci/integration.cloudbuild.yaml @@ -875,8 +875,8 @@ steps: total_coverage=$(go tool cover -func=oracle_coverage.out | grep "total:" | awk '{print $3}') echo "Oracle total coverage: $total_coverage" coverage_numeric=$(echo "$total_coverage" | sed 's/%//') - if awk -v cov="$coverage_numeric" 'BEGIN {exit !(cov < 30)}'; then - echo "Coverage failure: $total_coverage is below 30%." + if awk -v cov="$coverage_numeric" 'BEGIN {exit !(cov < 20)}'; then + echo "Coverage failure: $total_coverage is below 20%." exit 1 fi diff --git a/.github/workflows/deploy_dev_docs.yaml b/.github/workflows/deploy_dev_docs.yaml index 0eee9a4330..add4a149cd 100644 --- a/.github/workflows/deploy_dev_docs.yaml +++ b/.github/workflows/deploy_dev_docs.yaml @@ -40,7 +40,7 @@ jobs: group: docs-deployment cancel-in-progress: false steps: - - uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6 with: fetch-depth: 0 # Fetch all history for .GitInfo and .Lastmod @@ -51,12 +51,12 @@ jobs: extended: true - name: Setup Node - uses: actions/setup-node@2028fbc5c25fe9cf00d9f06a71cc4710d4507903 # v6 + uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # v6 with: node-version: "22" - name: Cache dependencies - uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4 + uses: actions/cache@9255dc7a253b0ccc959486e2bca901246202afeb # v5 with: path: ~/.npm key: ${{ runner.os }}-node-${{ hashFiles('**/package-lock.json') }} diff --git a/.github/workflows/deploy_previous_version_docs.yaml b/.github/workflows/deploy_previous_version_docs.yaml index b11bf13138..b792d38daa 100644 --- a/.github/workflows/deploy_previous_version_docs.yaml +++ b/.github/workflows/deploy_previous_version_docs.yaml @@ -30,14 +30,14 @@ jobs: steps: - name: Checkout main 
branch (for latest templates and theme) - uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6 with: ref: 'main' submodules: 'recursive' fetch-depth: 0 - name: Checkout old content from tag into a temporary directory - uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6 with: ref: ${{ github.event.inputs.version_tag }} path: 'old_version_source' # Checkout into a temp subdir @@ -57,7 +57,7 @@ jobs: with: hugo-version: "0.145.0" extended: true - - uses: actions/setup-node@2028fbc5c25fe9cf00d9f06a71cc4710d4507903 # v6 + - uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # v6 with: node-version: "22" diff --git a/.github/workflows/deploy_versioned_docs.yaml b/.github/workflows/deploy_versioned_docs.yaml index 5c23b51994..a0e3416f08 100644 --- a/.github/workflows/deploy_versioned_docs.yaml +++ b/.github/workflows/deploy_versioned_docs.yaml @@ -30,7 +30,7 @@ jobs: cancel-in-progress: false steps: - name: Checkout Code at Tag - uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6 with: ref: ${{ github.event.release.tag_name }} @@ -44,7 +44,7 @@ jobs: extended: true - name: Setup Node - uses: actions/setup-node@2028fbc5c25fe9cf00d9f06a71cc4710d4507903 # v6 + uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # v6 with: node-version: "22" diff --git a/.github/workflows/docs_preview_clean.yaml b/.github/workflows/docs_preview_clean.yaml index 5dc6070aa7..ba44bfcc8b 100644 --- a/.github/workflows/docs_preview_clean.yaml +++ b/.github/workflows/docs_preview_clean.yaml @@ -34,7 +34,7 @@ jobs: group: "preview-${{ github.event.number }}" cancel-in-progress: true steps: - - uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5 + - uses: 
actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6 with: ref: versioned-gh-pages diff --git a/.github/workflows/docs_preview_deploy.yaml b/.github/workflows/docs_preview_deploy.yaml index 1e72e69a30..05721dc8a2 100644 --- a/.github/workflows/docs_preview_deploy.yaml +++ b/.github/workflows/docs_preview_deploy.yaml @@ -49,7 +49,7 @@ jobs: group: "preview-${{ github.event.number }}" cancel-in-progress: true steps: - - uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6 with: # Checkout the PR's HEAD commit (supports forks). ref: ${{ github.event.pull_request.head.sha }} @@ -62,12 +62,12 @@ jobs: extended: true - name: Setup Node - uses: actions/setup-node@2028fbc5c25fe9cf00d9f06a71cc4710d4507903 # v6 + uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # v6 with: node-version: "22" - name: Cache dependencies - uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4 + uses: actions/cache@9255dc7a253b0ccc959486e2bca901246202afeb # v5 with: path: ~/.npm key: ${{ runner.os }}-node-${{ hashFiles('**/package-lock.json') }} diff --git a/.github/workflows/link_checker_workflow.yaml b/.github/workflows/link_checker_workflow.yaml index 4296a122e7..4558f0fa42 100644 --- a/.github/workflows/link_checker_workflow.yaml +++ b/.github/workflows/link_checker_workflow.yaml @@ -22,17 +22,17 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout Repository - uses: actions/checkout@v5 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6 - name: Restore lychee cache - uses: actions/cache@v4 + uses: actions/cache@9255dc7a253b0ccc959486e2bca901246202afeb # v5 with: path: .lycheecache key: cache-lychee-${{ github.sha }} restore-keys: cache-lychee- - name: Link Checker - uses: lycheeverse/lychee-action@v2 + uses: lycheeverse/lychee-action@a8c4c7cb88f0c7386610c35eb25108e448569cb0 # v2 with: args: > --verbose diff --git a/.github/workflows/lint.yaml 
b/.github/workflows/lint.yaml index 59b32d432d..0870637648 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -55,7 +55,7 @@ jobs: with: go-version: "1.25" - name: Checkout code - uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5.0.1 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: ref: ${{ github.event.pull_request.head.sha }} repository: ${{ github.event.pull_request.head.repo.full_name }} @@ -66,7 +66,7 @@ jobs: run: | go mod tidy && git diff --exit-code - name: golangci-lint - uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0 + uses: golangci/golangci-lint-action@1e7e51e771db61008b38414a730f564565cf7c20 # v9.2.0 with: version: latest args: --timeout 10m diff --git a/.github/workflows/publish-mcp.yml b/.github/workflows/publish-mcp.yml index 34d29a0960..dc84fbb759 100644 --- a/.github/workflows/publish-mcp.yml +++ b/.github/workflows/publish-mcp.yml @@ -29,7 +29,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6 - name: Wait for image in Artifact Registry shell: bash diff --git a/.github/workflows/sync-labels.yaml b/.github/workflows/sync-labels.yaml index ef6842fcb2..2a0d392497 100644 --- a/.github/workflows/sync-labels.yaml +++ b/.github/workflows/sync-labels.yaml @@ -29,7 +29,7 @@ jobs: issues: 'write' pull-requests: 'write' steps: - - uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5.0.1 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 - uses: micnncim/action-label-syncer@3abd5ab72fda571e69fffd97bd4e0033dd5f495c # v1.3.0 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index a225afa266..c11f7f388c 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -62,7 +62,7 @@ jobs: 
go-version: "1.24" - name: Checkout code - uses: actions/checkout@93cb6efe18208431cddfb8368fd83d5badbf9bfd # v5.0.1 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: ref: ${{ github.event.pull_request.head.sha }} repository: ${{ github.event.pull_request.head.repo.full_name }} diff --git a/docs/en/getting-started/quickstart/go/adkgo/go.mod b/docs/en/getting-started/quickstart/go/adkgo/go.mod index 84bf3dad72..c56ff97cc5 100644 --- a/docs/en/getting-started/quickstart/go/adkgo/go.mod +++ b/docs/en/getting-started/quickstart/go/adkgo/go.mod @@ -28,11 +28,11 @@ require ( go.opentelemetry.io/otel/metric v1.38.0 // indirect go.opentelemetry.io/otel/sdk v1.38.0 // indirect go.opentelemetry.io/otel/trace v1.38.0 // indirect - golang.org/x/crypto v0.43.0 // indirect - golang.org/x/net v0.46.0 // indirect + golang.org/x/crypto v0.45.0 // indirect + golang.org/x/net v0.47.0 // indirect golang.org/x/oauth2 v0.32.0 // indirect - golang.org/x/sys v0.37.0 // indirect - golang.org/x/text v0.30.0 // indirect + golang.org/x/sys v0.38.0 // indirect + golang.org/x/text v0.31.0 // indirect google.golang.org/api v0.255.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20251029180050-ab9386a59fda // indirect google.golang.org/grpc v1.76.0 // indirect diff --git a/docs/en/getting-started/quickstart/go/adkgo/go.sum b/docs/en/getting-started/quickstart/go/adkgo/go.sum index 02284fbc2f..018bd7961a 100644 --- a/docs/en/getting-started/quickstart/go/adkgo/go.sum +++ b/docs/en/getting-started/quickstart/go/adkgo/go.sum @@ -88,18 +88,18 @@ go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJr go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= -golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04= 
-golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0= -golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4= -golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210= +golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= +golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= +golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= +golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= golang.org/x/oauth2 v0.32.0 h1:jsCblLleRMDrxMN29H3z/k1KliIvpLgCkE6R8FXXNgY= golang.org/x/oauth2 v0.32.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= -golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= -golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= -golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ= -golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k= -golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM= +golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= +golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= +golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= +golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= diff --git a/docs/en/getting-started/quickstart/go/genkit/go.mod 
b/docs/en/getting-started/quickstart/go/genkit/go.mod index 0e323f53ad..41300f5f89 100644 --- a/docs/en/getting-started/quickstart/go/genkit/go.mod +++ b/docs/en/getting-started/quickstart/go/genkit/go.mod @@ -39,11 +39,11 @@ require ( go.opentelemetry.io/otel/metric v1.38.0 // indirect go.opentelemetry.io/otel/sdk v1.38.0 // indirect go.opentelemetry.io/otel/trace v1.38.0 // indirect - golang.org/x/crypto v0.43.0 // indirect - golang.org/x/net v0.46.0 // indirect + golang.org/x/crypto v0.45.0 // indirect + golang.org/x/net v0.47.0 // indirect golang.org/x/oauth2 v0.32.0 // indirect - golang.org/x/sys v0.37.0 // indirect - golang.org/x/text v0.30.0 // indirect + golang.org/x/sys v0.38.0 // indirect + golang.org/x/text v0.31.0 // indirect google.golang.org/api v0.255.0 // indirect google.golang.org/genai v1.34.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20251029180050-ab9386a59fda // indirect diff --git a/docs/en/getting-started/quickstart/go/genkit/go.sum b/docs/en/getting-started/quickstart/go/genkit/go.sum index 4acd085d92..affe1b3a85 100644 --- a/docs/en/getting-started/quickstart/go/genkit/go.sum +++ b/docs/en/getting-started/quickstart/go/genkit/go.sum @@ -123,18 +123,18 @@ go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJr go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= -golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04= -golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0= -golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4= -golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210= +golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= +golang.org/x/crypto v0.45.0/go.mod 
h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= +golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= +golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= golang.org/x/oauth2 v0.32.0 h1:jsCblLleRMDrxMN29H3z/k1KliIvpLgCkE6R8FXXNgY= golang.org/x/oauth2 v0.32.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= -golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= -golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= -golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ= -golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k= -golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM= +golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= +golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= +golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= +golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= diff --git a/docs/en/getting-started/quickstart/go/openAI/go.mod b/docs/en/getting-started/quickstart/go/openAI/go.mod index 96e7ad01df..dddc82b303 100644 --- a/docs/en/getting-started/quickstart/go/openAI/go.mod +++ b/docs/en/getting-started/quickstart/go/openAI/go.mod @@ -26,11 +26,11 @@ require ( go.opentelemetry.io/otel v1.38.0 // indirect go.opentelemetry.io/otel/metric v1.38.0 // indirect go.opentelemetry.io/otel/trace v1.38.0 // indirect - 
golang.org/x/crypto v0.43.0 // indirect - golang.org/x/net v0.46.0 // indirect + golang.org/x/crypto v0.45.0 // indirect + golang.org/x/net v0.47.0 // indirect golang.org/x/oauth2 v0.32.0 // indirect - golang.org/x/sys v0.37.0 // indirect - golang.org/x/text v0.30.0 // indirect + golang.org/x/sys v0.38.0 // indirect + golang.org/x/text v0.31.0 // indirect google.golang.org/api v0.255.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20251029180050-ab9386a59fda // indirect google.golang.org/grpc v1.76.0 // indirect diff --git a/docs/en/getting-started/quickstart/go/openAI/go.sum b/docs/en/getting-started/quickstart/go/openAI/go.sum index 28ff351c85..633df25e63 100644 --- a/docs/en/getting-started/quickstart/go/openAI/go.sum +++ b/docs/en/getting-started/quickstart/go/openAI/go.sum @@ -94,18 +94,18 @@ go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6 go.opentelemetry.io/otel/sdk/metric v1.38.0/go.mod h1:dg9PBnW9XdQ1Hd6ZnRz689CbtrUp0wMMs9iPcgT9EZA= go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE= go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs= -golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04= -golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0= -golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4= -golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210= +golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= +golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= +golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= +golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= golang.org/x/oauth2 v0.32.0 h1:jsCblLleRMDrxMN29H3z/k1KliIvpLgCkE6R8FXXNgY= golang.org/x/oauth2 v0.32.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= -golang.org/x/sync 
v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= -golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= -golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ= -golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k= -golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM= +golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= +golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= +golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= +golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= diff --git a/docs/en/getting-started/quickstart/js/adk/package-lock.json b/docs/en/getting-started/quickstart/js/adk/package-lock.json index d921ceb367..84bc88e40a 100644 --- a/docs/en/getting-started/quickstart/js/adk/package-lock.json +++ b/docs/en/getting-started/quickstart/js/adk/package-lock.json @@ -18,7 +18,6 @@ "resolved": "https://registry.npmjs.org/@google-cloud/paginator/-/paginator-5.0.2.tgz", "integrity": "sha512-DJS3s0OVH4zFDB1PzjxAsHqJT6sKVbRwwML0ZBP9PbU7Yebtu/7SWMRzvO2J3nUi9pRNITCfu4LJeooM2w4pjg==", "license": "Apache-2.0", - "peer": true, "dependencies": { "arrify": "^2.0.0", "extend": "^3.0.2" @@ -32,7 +31,6 @@ "resolved": "https://registry.npmjs.org/@google-cloud/projectify/-/projectify-4.0.0.tgz", "integrity": "sha512-MmaX6HeSvyPbWGwFq7mXdo0uQZLGBYCwziiLIGq5JVX+/bdI3SAq6bP98trV5eTWfLuvsMcIC1YJOF2vfteLFA==", "license": 
"Apache-2.0", - "peer": true, "engines": { "node": ">=14.0.0" } @@ -42,7 +40,6 @@ "resolved": "https://registry.npmjs.org/@google-cloud/promisify/-/promisify-4.0.0.tgz", "integrity": "sha512-Orxzlfb9c67A15cq2JQEyVc7wEsmFBmHjZWZYQMUyJ1qivXyMwdyNOs9odi79hze+2zqdTtu1E19IM/FtqZ10g==", "license": "Apache-2.0", - "peer": true, "engines": { "node": ">=14" } @@ -52,7 +49,6 @@ "resolved": "https://registry.npmjs.org/@google-cloud/storage/-/storage-7.18.0.tgz", "integrity": "sha512-r3ZwDMiz4nwW6R922Z1pwpePxyRwE5GdevYX63hRmAQUkUQJcBH/79EnQPDv5cOv1mFBgevdNWQfi3tie3dHrQ==", "license": "Apache-2.0", - "peer": true, "dependencies": { "@google-cloud/paginator": "^5.0.0", "@google-cloud/projectify": "^4.0.0", @@ -79,7 +75,6 @@ "resolved": "https://registry.npmjs.org/uuid/-/uuid-8.3.2.tgz", "integrity": "sha512-+NYs2QeMWy+GWFOEm9xnn6HCDp0l7QBD7ml8zLUmJ+93Q5NF0NocErnwkTkXVFNiX3/fpC6afS8Dhb/gz7R7eg==", "license": "MIT", - "peer": true, "bin": { "uuid": "dist/bin/uuid" } @@ -102,6 +97,7 @@ "resolved": "https://registry.npmjs.org/@google/genai/-/genai-1.14.0.tgz", "integrity": "sha512-jirYprAAJU1svjwSDVCzyVq+FrJpJd5CSxR/g2Ga/gZ0ZYZpcWjMS75KJl9y71K1mDN+tcx6s21CzCbB2R840g==", "license": "Apache-2.0", + "peer": true, "dependencies": { "google-auth-library": "^9.14.2", "ws": "^8.18.0" @@ -140,6 +136,7 @@ "resolved": "https://registry.npmjs.org/@modelcontextprotocol/sdk/-/sdk-1.17.5.tgz", "integrity": "sha512-QakrKIGniGuRVfWBdMsDea/dx1PNE739QJ7gCM41s9q+qaCYTHCdsIBXQVVXry3mfWAiaM9kT22Hyz53Uw8mfg==", "license": "MIT", + "peer": true, "dependencies": { "ajv": "^6.12.6", "content-type": "^1.0.5", @@ -302,7 +299,6 @@ "resolved": "https://registry.npmjs.org/@tootallnate/once/-/once-2.0.0.tgz", "integrity": "sha512-XCuKFP5PS55gnMVu3dty8KPatLqUoy/ZYzDzAGCQ8JNFCkLXzmI7vNHCR+XpbZaMWQK/vQubr7PkYq8g470J/A==", "license": "MIT", - "peer": true, "engines": { "node": ">= 10" } @@ -311,15 +307,13 @@ "version": "0.12.5", "resolved": "https://registry.npmjs.org/@types/caseless/-/caseless-0.12.5.tgz", 
"integrity": "sha512-hWtVTC2q7hc7xZ/RLbxapMvDMgUnDvKvMOpKal4DrMyfGBUfB1oKaZlIRr6mJL+If3bAP6sV/QneGzF6tJjZDg==", - "license": "MIT", - "peer": true + "license": "MIT" }, "node_modules/@types/node": { "version": "24.10.1", "resolved": "https://registry.npmjs.org/@types/node/-/node-24.10.1.tgz", "integrity": "sha512-GNWcUTRBgIRJD5zj+Tq0fKOJ5XZajIiBroOF0yvj2bSU1WvNdYS/dn9UxwsujGW4JX06dnHyjV2y9rRaybH0iQ==", "license": "MIT", - "peer": true, "dependencies": { "undici-types": "~7.16.0" } @@ -329,7 +323,6 @@ "resolved": "https://registry.npmjs.org/@types/request/-/request-2.48.13.tgz", "integrity": "sha512-FGJ6udDNUCjd19pp0Q3iTiDkwhYup7J8hpMW9c4k53NrccQFFWKRho6hvtPPEhnXWKvukfwAlB6DbDz4yhH5Gg==", "license": "MIT", - "peer": true, "dependencies": { "@types/caseless": "*", "@types/node": "*", @@ -342,7 +335,6 @@ "resolved": "https://registry.npmjs.org/form-data/-/form-data-2.5.5.tgz", "integrity": "sha512-jqdObeR2rxZZbPSGL+3VckHMYtu+f9//KXBsVny6JSX/pa38Fy+bGjuG8eW/H6USNQWhLi8Num++cU2yOCNz4A==", "license": "MIT", - "peer": true, "dependencies": { "asynckit": "^0.4.0", "combined-stream": "^1.0.8", @@ -360,7 +352,6 @@ "resolved": "https://registry.npmjs.org/mime-db/-/mime-db-1.52.0.tgz", "integrity": "sha512-sPU4uV7dYlvtWJxwwxHD0PuihVNiE7TyAbQ5SWxDCB9mUYvOgroQOwYQQOKPJ8CIbE+1ETVlOoK1UC2nU3gYvg==", "license": "MIT", - "peer": true, "engines": { "node": ">= 0.6" } @@ -370,7 +361,6 @@ "resolved": "https://registry.npmjs.org/mime-types/-/mime-types-2.1.35.tgz", "integrity": "sha512-ZDY+bPm5zTTF+YpCrAU9nK0UgICYPT0QtT1NZWFv4s++TNkcgVaT0g6+4R2uI4MjQjzysHB1zxuWL50hzaeXiw==", "license": "MIT", - "peer": true, "dependencies": { "mime-db": "1.52.0" }, @@ -382,15 +372,13 @@ "version": "4.0.5", "resolved": "https://registry.npmjs.org/@types/tough-cookie/-/tough-cookie-4.0.5.tgz", "integrity": "sha512-/Ad8+nIOV7Rl++6f1BdKxFSMgmoqEoYbHRpPcx3JEfv8VRsQe9Z4mCXeJBzxs7mbHY/XOZZuXlRNfhpVPbs6ZA==", - "license": "MIT", - "peer": true + "license": "MIT" }, "node_modules/abort-controller": { "version": 
"3.0.0", "resolved": "https://registry.npmjs.org/abort-controller/-/abort-controller-3.0.0.tgz", "integrity": "sha512-h8lQ8tacZYnR3vNQTgibj+tODHI5/+l06Au2Pcriv/Gmet0eaj4TwWH41sO9wnHDiQsEj19q0drzdWdeAHtweg==", "license": "MIT", - "peer": true, "dependencies": { "event-target-shim": "^5.0.0" }, @@ -465,7 +453,6 @@ "resolved": "https://registry.npmjs.org/arrify/-/arrify-2.0.1.tgz", "integrity": "sha512-3duEwti880xqi4eAMN8AyR4a0ByT90zoYdLlevfrvU43vb0YZwZVfxOgxWrLXXXpyugL0hNZc9G6BiB5B3nUug==", "license": "MIT", - "peer": true, "engines": { "node": ">=8" } @@ -475,7 +462,6 @@ "resolved": "https://registry.npmjs.org/async-retry/-/async-retry-1.3.3.tgz", "integrity": "sha512-wfr/jstw9xNi/0teMHrRW7dsz3Lt5ARhYNZ2ewpadnhaIp5mbALhOAP+EAdsC7t4Z6wqsDVv9+W6gm1Dk9mEyw==", "license": "MIT", - "peer": true, "dependencies": { "retry": "0.13.1" } @@ -768,7 +754,6 @@ "resolved": "https://registry.npmjs.org/duplexify/-/duplexify-4.1.3.tgz", "integrity": "sha512-M3BmBhwJRZsSx38lZyhE53Csddgzl5R7xGJNk7CVddZD6CcmwMCH8J+7AprIrQKH7TonKxaCjcv27Qmf+sQ+oA==", "license": "MIT", - "peer": true, "dependencies": { "end-of-stream": "^1.4.1", "inherits": "^2.0.3", @@ -817,7 +802,6 @@ "resolved": "https://registry.npmjs.org/end-of-stream/-/end-of-stream-1.4.5.tgz", "integrity": "sha512-ooEGc6HP26xXq/N+GCGOT0JKCLDGrq2bQUZrQ7gyrJiZANJ/8YDTxTpQBXGMn+WbIQXNVpyWymm7KYVICQnyOg==", "license": "MIT", - "peer": true, "dependencies": { "once": "^1.4.0" } @@ -887,7 +871,6 @@ "resolved": "https://registry.npmjs.org/event-target-shim/-/event-target-shim-5.0.1.tgz", "integrity": "sha512-i/2XbnSz/uxRCU6+NdVJgKWDTM427+MqYbkQzD321DuCQJUqOuJKIA0IM2+W2xtYHdKOmZ4dR6fExsd4SXL+WQ==", "license": "MIT", - "peer": true, "engines": { "node": ">=6" } @@ -918,6 +901,7 @@ "resolved": "https://registry.npmjs.org/express/-/express-5.1.0.tgz", "integrity": "sha512-DT9ck5YIRU+8GYzzU5kT3eHGA5iL+1Zd0EutOmTE9Dtk+Tvuzd23VBU+ec7HPNSTxXYO55gPV/hq4pSBJDjFpA==", "license": "MIT", + "peer": true, "dependencies": { "accepts": "^2.0.0", 
"body-parser": "^2.2.0", @@ -999,7 +983,6 @@ } ], "license": "MIT", - "peer": true, "dependencies": { "strnum": "^1.1.1" }, @@ -1350,8 +1333,7 @@ "url": "https://patreon.com/mdevils" } ], - "license": "MIT", - "peer": true + "license": "MIT" }, "node_modules/http-errors": { "version": "2.0.0", @@ -1383,7 +1365,6 @@ "resolved": "https://registry.npmjs.org/http-proxy-agent/-/http-proxy-agent-5.0.0.tgz", "integrity": "sha512-n2hY8YdoRE1i7r6M0w9DIw5GgZN0G25P8zLCRQ8rjXtTU3vsNFBI/vWK/UIeE6g5MUUz6avwAPXmL6Fy9D/90w==", "license": "MIT", - "peer": true, "dependencies": { "@tootallnate/once": "2", "agent-base": "6", @@ -1398,7 +1379,6 @@ "resolved": "https://registry.npmjs.org/agent-base/-/agent-base-6.0.2.tgz", "integrity": "sha512-RZNwNclF7+MS/8bDg70amg32dyeZGZxiDuQmZxKLAlQjr3jGyLx+4Kkk58UO7D2QdgFIQCovuSuZESne6RG6XQ==", "license": "MIT", - "peer": true, "dependencies": { "debug": "4" }, @@ -1525,12 +1505,12 @@ } }, "node_modules/jws": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/jws/-/jws-4.0.0.tgz", - "integrity": "sha512-KDncfTmOZoOMTFG4mBlG0qUIOlc03fmzH+ru6RgYVZhPkyiy/92Owlt/8UEN+a4TXR1FQetfIpJE8ApdvdVxTg==", + "version": "4.0.1", + "resolved": "https://registry.npmjs.org/jws/-/jws-4.0.1.tgz", + "integrity": "sha512-EKI/M/yqPncGUUh44xz0PxSidXFr/+r0pA70+gIYhjv+et7yxM+s29Y+VGDkovRofQem0fs7Uvf4+YmAdyRduA==", "license": "MIT", "dependencies": { - "jwa": "^2.0.0", + "jwa": "^2.0.1", "safe-buffer": "^5.0.1" } }, @@ -1575,7 +1555,6 @@ "resolved": "https://registry.npmjs.org/mime/-/mime-3.0.0.tgz", "integrity": "sha512-jSCU7/VB1loIWBZe14aEYHU/+1UMEHoaO7qxCOVJOw9GgH72VAWppxNcjU+x9a2k3GSIBXNKxXQFqRvvZ7vr3A==", "license": "MIT", - "peer": true, "bin": { "mime": "cli.js" }, @@ -1736,7 +1715,6 @@ "resolved": "https://registry.npmjs.org/p-limit/-/p-limit-3.1.0.tgz", "integrity": "sha512-TYOanM3wGwNGsZN2cVTYPArw454xnXj5qmWF1bEoAc4+cU/ol7GVh7odevjp1FNHduHc3KZMcFduxU5Xc6uJRQ==", "license": "MIT", - "peer": true, "dependencies": { "yocto-queue": "^0.1.0" }, @@ 
-1835,9 +1813,9 @@ } }, "node_modules/qs": { - "version": "6.14.0", - "resolved": "https://registry.npmjs.org/qs/-/qs-6.14.0.tgz", - "integrity": "sha512-YWWTjgABSKcvs/nWBi9PycY/JiPJqOD4JA6o9Sej2AtvSGarXxKC3OQSk4pAarbdQlKAh5D4FCQkJNkW+GAn3w==", + "version": "6.14.1", + "resolved": "https://registry.npmjs.org/qs/-/qs-6.14.1.tgz", + "integrity": "sha512-4EK3+xJl8Ts67nLYNwqw/dsFVnCf+qR7RgXSK9jEEm9unao3njwMDdmsdvoKBKHzxd7tCYz5e5M+SnMjdtXGQQ==", "license": "BSD-3-Clause", "dependencies": { "side-channel": "^1.1.0" @@ -1878,7 +1856,6 @@ "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-3.6.2.tgz", "integrity": "sha512-9u/sniCrY3D5WdsERHzHE4G2YCXqoG5FTHUiCC4SIbr6XcLZBY05ya9EKjYek9O5xOAwjGq+1JdGBAS7Q9ScoA==", "license": "MIT", - "peer": true, "dependencies": { "inherits": "^2.0.3", "string_decoder": "^1.1.1", @@ -1893,7 +1870,6 @@ "resolved": "https://registry.npmjs.org/retry/-/retry-0.13.1.tgz", "integrity": "sha512-XQBQ3I8W1Cge0Seh+6gjj03LbmRFWuoszgK9ooCpwYIrhhoO80pfq4cUkU5DkknwfOfFteRwlZ56PYOGYyFWdg==", "license": "MIT", - "peer": true, "engines": { "node": ">= 4" } @@ -1903,7 +1879,6 @@ "resolved": "https://registry.npmjs.org/retry-request/-/retry-request-7.0.2.tgz", "integrity": "sha512-dUOvLMJ0/JJYEn8NrpOaGNE7X3vpI5XlZS/u0ANjqtcZVKnIxP7IgCFwrKTxENw29emmwug53awKtaMm4i9g5w==", "license": "MIT", - "peer": true, "dependencies": { "@types/request": "^2.48.8", "extend": "^3.0.2", @@ -2132,7 +2107,6 @@ "resolved": "https://registry.npmjs.org/stream-events/-/stream-events-1.0.5.tgz", "integrity": "sha512-E1GUzBSgvct8Jsb3v2X15pjzN1tYebtbLaMg+eBOUOAxgbLoSbT2NS91ckc5lJD1KfLjId+jXJRgo0qnV5Nerg==", "license": "MIT", - "peer": true, "dependencies": { "stubs": "^3.0.0" } @@ -2141,15 +2115,13 @@ "version": "1.0.3", "resolved": "https://registry.npmjs.org/stream-shift/-/stream-shift-1.0.3.tgz", "integrity": "sha512-76ORR0DO1o1hlKwTbi/DM3EXWGf3ZJYO8cXX5RJwnul2DEg2oyoZyjLNoQM8WsvZiFKCRfC1O0J7iCvie3RZmQ==", - "license": "MIT", - "peer": true + "license": "MIT" }, 
"node_modules/string_decoder": { "version": "1.3.0", "resolved": "https://registry.npmjs.org/string_decoder/-/string_decoder-1.3.0.tgz", "integrity": "sha512-hkRX8U1WjJFd8LsDJ2yQ/wWWxaopEsABU1XfkM8A+j0+85JAGppt16cr1Whg6KIbb4okU6Mql6BOj+uup/wKeA==", "license": "MIT", - "peer": true, "dependencies": { "safe-buffer": "~5.2.0" } @@ -2260,22 +2232,19 @@ "url": "https://github.com/sponsors/NaturalIntelligence" } ], - "license": "MIT", - "peer": true + "license": "MIT" }, "node_modules/stubs": { "version": "3.0.0", "resolved": "https://registry.npmjs.org/stubs/-/stubs-3.0.0.tgz", "integrity": "sha512-PdHt7hHUJKxvTCgbKX9C1V/ftOcjJQgz8BZwNfV5c4B6dcGqlpelTbJ999jBGZ2jYiPAwcX5dP6oBwVlBlUbxw==", - "license": "MIT", - "peer": true + "license": "MIT" }, "node_modules/teeny-request": { "version": "9.0.0", "resolved": "https://registry.npmjs.org/teeny-request/-/teeny-request-9.0.0.tgz", "integrity": "sha512-resvxdc6Mgb7YEThw6G6bExlXKkv6+YbuzGg9xuXxSgxJF7Ozs+o8Y9+2R3sArdWdW8nOokoQb1yrpFB0pQK2g==", "license": "Apache-2.0", - "peer": true, "dependencies": { "http-proxy-agent": "^5.0.0", "https-proxy-agent": "^5.0.0", @@ -2292,7 +2261,6 @@ "resolved": "https://registry.npmjs.org/agent-base/-/agent-base-6.0.2.tgz", "integrity": "sha512-RZNwNclF7+MS/8bDg70amg32dyeZGZxiDuQmZxKLAlQjr3jGyLx+4Kkk58UO7D2QdgFIQCovuSuZESne6RG6XQ==", "license": "MIT", - "peer": true, "dependencies": { "debug": "4" }, @@ -2305,7 +2273,6 @@ "resolved": "https://registry.npmjs.org/https-proxy-agent/-/https-proxy-agent-5.0.1.tgz", "integrity": "sha512-dFcAjpTQFgoLMzC2VwU+C/CbS7uRL0lWmxDITmqm7C+7F0Odmj6s9l6alZc6AELXhrnggM2CeWSXHGOdX2YtwA==", "license": "MIT", - "peer": true, "dependencies": { "agent-base": "6", "debug": "4" @@ -2347,8 +2314,7 @@ "version": "7.16.0", "resolved": "https://registry.npmjs.org/undici-types/-/undici-types-7.16.0.tgz", "integrity": "sha512-Zz+aZWSj8LE6zoxD+xrjh4VfkIG8Ya6LvYkZqtUQGJPZjYl53ypCaUwWqo7eI0x66KBGeRo+mlBEkMSeSZ38Nw==", - "license": "MIT", - "peer": true + "license": "MIT" }, 
"node_modules/unpipe": { "version": "1.0.0", @@ -2372,8 +2338,7 @@ "version": "1.0.2", "resolved": "https://registry.npmjs.org/util-deprecate/-/util-deprecate-1.0.2.tgz", "integrity": "sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw==", - "license": "MIT", - "peer": true + "license": "MIT" }, "node_modules/uuid": { "version": "9.0.1", @@ -2560,7 +2525,6 @@ "resolved": "https://registry.npmjs.org/yocto-queue/-/yocto-queue-0.1.0.tgz", "integrity": "sha512-rVksvsnNCdJ/ohGc6xgPwyN8eheCxsiLM8mxuE/t/mOVqJewPuO1miLpTHQiRgTKCLexL4MeAFVagts7HmNZ2Q==", "license": "MIT", - "peer": true, "engines": { "node": ">=10" }, @@ -2573,6 +2537,7 @@ "resolved": "https://registry.npmjs.org/zod/-/zod-3.25.76.tgz", "integrity": "sha512-gzUt/qt81nXsFGKIFcC3YnfEAx5NkunCfnDlvuBSSFS02bcXu4Lmea0AFIUwbLWxWPx3d9p8S5QoaujKcNQxcQ==", "license": "MIT", + "peer": true, "funding": { "url": "https://github.com/sponsors/colinhacks" } diff --git a/docs/en/getting-started/quickstart/js/langchain/package-lock.json b/docs/en/getting-started/quickstart/js/langchain/package-lock.json index 7c52d6e598..a52001ef13 100644 --- a/docs/en/getting-started/quickstart/js/langchain/package-lock.json +++ b/docs/en/getting-started/quickstart/js/langchain/package-lock.json @@ -45,9 +45,9 @@ } }, "node_modules/@langchain/core": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/@langchain/core/-/core-1.1.0.tgz", - "integrity": "sha512-yJ6JHcU9psjnQbzRFkXjIdNTA+3074dA+2pHdH8ewvQCSleSk6JcjkCMIb5+NASjeMoi1ZuntlLKVsNqF38YxA==", + "version": "1.1.8", + "resolved": "https://registry.npmjs.org/@langchain/core/-/core-1.1.8.tgz", + "integrity": "sha512-kIUidOgc0ZdyXo4Ahn9Zas+OayqOfk4ZoKPi7XaDipNSWSApc2+QK5BVcjvwtzxstsNOrmXJiJWEN6WPF/MvAw==", "license": "MIT", "peer": true, "dependencies": { @@ -56,10 +56,9 @@ "camelcase": "6", "decamelize": "1.2.0", "js-tiktoken": "^1.0.12", - "langsmith": "^0.3.64", + "langsmith": ">=0.4.0 <1.0.0", "mustache": "^4.2.0", "p-queue": 
"^6.6.2", - "p-retry": "^7.0.0", "uuid": "^10.0.0", "zod": "^3.25.76 || ^4" }, @@ -67,25 +66,10 @@ "node": ">=20" } }, - "node_modules/@langchain/core/node_modules/p-retry": { - "version": "7.1.0", - "resolved": "https://registry.npmjs.org/p-retry/-/p-retry-7.1.0.tgz", - "integrity": "sha512-xL4PiFRQa/f9L9ZvR4/gUCRNus4N8YX80ku8kv9Jqz+ZokkiZLM0bcvX0gm1F3PDi9SPRsww1BDsTWgE6Y1GLQ==", - "license": "MIT", - "dependencies": { - "is-network-error": "^1.1.0" - }, - "engines": { - "node": ">=20" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, "node_modules/@langchain/google-genai": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/@langchain/google-genai/-/google-genai-2.0.0.tgz", - "integrity": "sha512-PaAWkogQdF+Y2bhhXWXUrC2nO7sTgWLtobBbZl/0V8Aa1F/KG2wrMECie3S17bAdFu/6VmQOuFFrlgSMwQC5KA==", + "version": "2.1.3", + "resolved": "https://registry.npmjs.org/@langchain/google-genai/-/google-genai-2.1.3.tgz", + "integrity": "sha512-ZdlFK/N10GyU6ATzkM01Sk1rlHBoy36Q/MawGD1SyXdD2lQxZxuQZjFWewj6uzWQ2Nnjj70EvU/kmmHVPn6sfQ==", "license": "MIT", "dependencies": { "@google/generative-ai": "^0.24.0", @@ -95,7 +79,7 @@ "node": ">=20" }, "peerDependencies": { - "@langchain/core": "1.1.0" + "@langchain/core": "1.1.8" } }, "node_modules/@langchain/google-genai/node_modules/uuid": { @@ -814,18 +798,6 @@ "node": ">=8" } }, - "node_modules/is-network-error": { - "version": "1.3.0", - "resolved": "https://registry.npmjs.org/is-network-error/-/is-network-error-1.3.0.tgz", - "integrity": "sha512-6oIwpsgRfnDiyEDLMay/GqCl3HoAtH5+RUKW29gYkL0QA+ipzpDLA16yQs7/RHCSu+BwgbJaOUqa4A99qNVQVw==", - "license": "MIT", - "engines": { - "node": ">=16" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, "node_modules/isexe": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/isexe/-/isexe-2.0.0.tgz", @@ -882,13 +854,14 @@ } }, "node_modules/langchain": { - "version": "1.0.2", - "resolved": 
"https://registry.npmjs.org/langchain/-/langchain-1.0.2.tgz", - "integrity": "sha512-He/xvjVl8DHESvdaW6Dpyba72OaLCAfS2CyOm1aWrlJ4C38dKXyTIxphtld8hiii6MWX7qMSmu2EaUwWBx2STg==", + "version": "1.2.3", + "resolved": "https://registry.npmjs.org/langchain/-/langchain-1.2.3.tgz", + "integrity": "sha512-3k986xJuqg4az53JxV5LnGlOzIXF1d9Kq6Y9s7XjitvzhpsbFuTDV5/kiF4cx3pkNGyw0mUXC4tLz9RxucO0hw==", + "license": "MIT", "dependencies": { "@langchain/langgraph": "^1.0.0", "@langchain/langgraph-checkpoint": "^1.0.0", - "langsmith": "~0.3.74", + "langsmith": ">=0.4.0 <1.0.0", "uuid": "^10.0.0", "zod": "^3.25.76 || ^4" }, @@ -896,19 +869,19 @@ "node": ">=20" }, "peerDependencies": { - "@langchain/core": "^1.0.0" + "@langchain/core": "1.1.8" } }, "node_modules/langsmith": { - "version": "0.3.77", - "resolved": "https://registry.npmjs.org/langsmith/-/langsmith-0.3.77.tgz", - "integrity": "sha512-wbS/9IX/hOAsOEOtPj8kCS8H0tFHaelwQ97gTONRtIfoPPLd9MMUmhk0KQB5DdsGAI5abg966+f0dZ/B+YRRzg==", + "version": "0.4.3", + "resolved": "https://registry.npmjs.org/langsmith/-/langsmith-0.4.3.tgz", + "integrity": "sha512-vuBAagBZulXj0rpZhUTxmHhrYIBk53z8e2Q8ty4OHVkahN4ul7Im3OZxD9jsXZB0EuncK1xRYtY8J3BW4vj1zw==", + "license": "MIT", "dependencies": { "@types/uuid": "^10.0.0", "chalk": "^4.1.2", "console-table-printer": "^2.12.1", "p-queue": "^6.6.2", - "p-retry": "4", "semver": "^7.6.3", "uuid": "^10.0.0" }, diff --git a/docs/en/getting-started/quickstart/python/adk/requirements.txt b/docs/en/getting-started/quickstart/python/adk/requirements.txt index c4335ec83b..7fb84e8d67 100644 --- a/docs/en/getting-started/quickstart/python/adk/requirements.txt +++ b/docs/en/getting-started/quickstart/python/adk/requirements.txt @@ -1,3 +1,3 @@ -google-adk==1.19.0 -toolbox-core==0.5.3 -pytest==9.0.1 \ No newline at end of file +google-adk==1.21.0 +toolbox-core==0.5.4 +pytest==9.0.2 \ No newline at end of file diff --git a/docs/en/getting-started/quickstart/python/core/requirements.txt 
b/docs/en/getting-started/quickstart/python/core/requirements.txt index 62487afe19..2043b32be5 100644 --- a/docs/en/getting-started/quickstart/python/core/requirements.txt +++ b/docs/en/getting-started/quickstart/python/core/requirements.txt @@ -1,3 +1,3 @@ -google-genai==1.52.0 -toolbox-core==0.5.3 -pytest==9.0.1 +google-genai==1.56.0 +toolbox-core==0.5.4 +pytest==9.0.2 diff --git a/docs/en/getting-started/quickstart/python/langchain/requirements.txt b/docs/en/getting-started/quickstart/python/langchain/requirements.txt index e5d970bcb9..4090af465d 100644 --- a/docs/en/getting-started/quickstart/python/langchain/requirements.txt +++ b/docs/en/getting-started/quickstart/python/langchain/requirements.txt @@ -1,5 +1,5 @@ -langchain==1.1.0 -langchain-google-vertexai==3.1.0 -langgraph==1.0.4 -toolbox-langchain==0.5.3 -pytest==9.0.1 +langchain==1.2.0 +langchain-google-vertexai==3.2.0 +langgraph==1.0.5 +toolbox-langchain==0.5.4 +pytest==9.0.2 diff --git a/docs/en/getting-started/quickstart/python/llamaindex/requirements.txt b/docs/en/getting-started/quickstart/python/llamaindex/requirements.txt index c065d5dad7..bbdcc00f4c 100644 --- a/docs/en/getting-started/quickstart/python/llamaindex/requirements.txt +++ b/docs/en/getting-started/quickstart/python/llamaindex/requirements.txt @@ -1,4 +1,4 @@ -llama-index==0.14.10 -llama-index-llms-google-genai==0.7.3 -toolbox-llamaindex==0.5.3 -pytest==9.0.1 +llama-index==0.14.12 +llama-index-llms-google-genai==0.8.3 +toolbox-llamaindex==0.5.4 +pytest==9.0.2 diff --git a/docs/en/resources/sources/dataplex.md b/docs/en/resources/sources/dataplex.md index 24539b3837..828ee5b698 100644 --- a/docs/en/resources/sources/dataplex.md +++ b/docs/en/resources/sources/dataplex.md @@ -229,22 +229,38 @@ Finds resources that were created within, before, or after a given date or time. ### Aspect Search To search for entries based on their attached aspects, use the following query syntax. 
-aspect:x Matches x as a substring of the full path to the aspect type of an aspect that is attached to the entry, in the format projectid.location.ASPECT_TYPE_ID -aspect=x Matches x as the full path to the aspect type of an aspect that is attached to the entry, in the format projectid.location.ASPECT_TYPE_ID -aspect:xOPERATORvalue -Searches for aspect field values. Matches x as a substring of the full path to the aspect type and field name of an aspect that is attached to the entry, in the format projectid.location.ASPECT_TYPE_ID.FIELD_NAME +`has:x` +Matches `x` as a substring of the full path to the aspect type of an aspect that is attached to the entry, in the format `projectid.location.ASPECT_TYPE_ID` -The list of supported {OPERATOR}s depends on the type of field in the aspect, as follows: -- String: = (exact match) and : (substring) -- All number types: =, :, <, >, <=, >=, =>, =< -- Enum: = -- Datetime: same as for numbers, but the values to compare are treated as datetimes instead of numbers -- Boolean: = +`has=x` +Matches `x` as the full path to the aspect type of an aspect that is attached to the entry, in the format `projectid.location.ASPECT_TYPE_ID` -Only top-level fields of the aspect are searchable. For example, all of the following queries match entries where the value of the is-enrolled field in the employee-info aspect type is true. Other entries that match on the substring are also returned. -- aspect:example-project.us-central1.employee-info.is-enrolled=true -- aspect:example-project.us-central1.employee=true -- aspect:employee=true +`xOPERATORvalue` +Searches for aspect field values. 
Matches x as a substring of the full path to the aspect type and field name of an aspect that is attached to the entry, in the format `projectid.location.ASPECT_TYPE_ID.FIELD_NAME` + +The list of supported operators depends on the type of field in the aspect, as follows: +* **String**: `=` (exact match) +* **All number types**: `=`, `:`, `<`, `>`, `<=`, `>=`, `=>`, `=<` +* **Enum**: `=` (exact match only) +* **Datetime**: same as for numbers, but the values to compare are treated as datetimes instead of numbers +* **Boolean**: `=` + +Only top-level fields of the aspect are searchable. + +* Syntax for system aspect types: + * `ASPECT_TYPE_ID.FIELD_NAME` + * `dataplex-types.ASPECT_TYPE_ID.FIELD_NAME` + * `dataplex-types.LOCATION.ASPECT_TYPE_ID.FIELD_NAME` +For example, the following queries match entries where the value of the `type` field in the `bigquery-dataset` aspect is `default`: + * `bigquery-dataset.type=default` + * `dataplex-types.bigquery-dataset.type=default` + * `dataplex-types.global.bigquery-dataset.type=default` +* Syntax for custom aspect types: + * If the aspect is created in the global region: `PROJECT_ID.ASPECT_TYPE_ID.FIELD_NAME` + * If the aspect is created in a specific region: `PROJECT_ID.REGION.ASPECT_TYPE_ID.FIELD_NAME` +For example, the following queries match entries where the value of the `is-enrolled` field in the `employee-info` aspect is `true`. + * `example-project.us-central1.employee-info.is-enrolled=true` + * `example-project.employee-info.is-enrolled=true` Example:- You can use following filters @@ -258,6 +274,25 @@ Logical AND and logical OR are supported. For example, foo OR bar. You can negate a predicate with a - (hyphen) or NOT prefix. For example, -name:foo returns resources with names that don't match the predicate foo. Logical operators are case-sensitive. `OR` and `AND` are acceptable whereas `or` and `and` are not. 
+### Abbreviated syntax + +An abbreviated search syntax is also available, using `|` (vertical bar) for `OR` operators and `,` (comma) for `AND` operators. + +For example, to search for entries inside one of many projects using the `OR` operator, you can use the following abbreviated syntax: + +`projectid:(id1|id2|id3|id4)` + +The same search without using abbreviated syntax looks like the following: + +`projectid:id1 OR projectid:id2 OR projectid:id3 OR projectid:id4` + +To search for entries with matching column names, use the following: + +* **AND**: `column:(name1,name2,name3)` +* **OR**: `column:(name1|name2|name3)` + +This abbreviated syntax works for the qualified predicates except for `label` in keyword search. + ### Request 1. Always try to rewrite the prompt using search syntax. diff --git a/go.mod b/go.mod index e0ed921ac5..a2d598c060 100644 --- a/go.mod +++ b/go.mod @@ -37,7 +37,6 @@ require ( github.com/google/go-cmp v0.7.0 github.com/google/uuid v1.6.0 github.com/jackc/pgx/v5 v5.7.6 - github.com/json-iterator/go v1.1.12 github.com/looker-open-source/sdk-codegen/go v0.25.21 github.com/microsoft/go-mssqldb v1.9.3 github.com/nakagami/firebirdsql v0.9.15 @@ -138,6 +137,7 @@ require ( github.com/jcmturner/goidentity/v6 v6.0.1 // indirect github.com/jcmturner/gokrb5/v8 v8.4.4 // indirect github.com/jcmturner/rpc/v2 v2.0.3 // indirect + github.com/json-iterator/go v1.1.12 // indirect github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0 // indirect github.com/klauspost/compress v1.18.0 // indirect github.com/klauspost/cpuid/v2 v2.2.11 // indirect diff --git a/internal/sources/alloydbpg/alloydb_pg.go b/internal/sources/alloydbpg/alloydb_pg.go index 3adef5a051..58d8600ebd 100644 --- a/internal/sources/alloydbpg/alloydb_pg.go +++ b/internal/sources/alloydbpg/alloydb_pg.go @@ -24,6 +24,7 @@ import ( "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/util" + 
"github.com/googleapis/genai-toolbox/internal/util/orderedmap" "github.com/jackc/pgx/v5/pgxpool" "go.opentelemetry.io/otel/trace" ) @@ -104,22 +105,22 @@ func (s *Source) PostgresPool() *pgxpool.Pool { func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) { results, err := s.Pool.Query(ctx, statement, params...) if err != nil { - return nil, fmt.Errorf("unable to execute query: %w. Query: %v , Values: %v. Toolbox v0.19.0+ is only compatible with AlloyDB AI NL v1.0.3+. Please ensure that you are using the latest AlloyDB AI NL extension", err, statement, params) + return nil, fmt.Errorf("unable to execute query: %w", err) } + defer results.Close() fields := results.FieldDescriptions() - var out []any for results.Next() { v, err := results.Values() if err != nil { return nil, fmt.Errorf("unable to parse row: %w", err) } - vMap := make(map[string]any) + row := orderedmap.Row{} for i, f := range fields { - vMap[f.Name] = v[i] + row.Add(f.Name, v[i]) } - out = append(out, vMap) + out = append(out, row) } // this will catch actual query execution errors if err := results.Err(); err != nil { diff --git a/internal/sources/bigquery/bigquery.go b/internal/sources/bigquery/bigquery.go index 3b2d823dc1..a0e170e144 100644 --- a/internal/sources/bigquery/bigquery.go +++ b/internal/sources/bigquery/bigquery.go @@ -17,7 +17,9 @@ package bigquery import ( "context" "fmt" + "math/big" "net/http" + "reflect" "strings" "sync" "time" @@ -26,13 +28,16 @@ import ( dataplexapi "cloud.google.com/go/dataplex/apiv1" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util" + "github.com/googleapis/genai-toolbox/internal/util/orderedmap" "go.opentelemetry.io/otel/trace" "golang.org/x/oauth2" "golang.org/x/oauth2/google" bigqueryrestapi "google.golang.org/api/bigquery/v2" "google.golang.org/api/googleapi" 
"google.golang.org/api/impersonate" + "google.golang.org/api/iterator" "google.golang.org/api/option" ) @@ -483,6 +488,131 @@ func (s *Source) lazyInitDataplexClient(ctx context.Context, tracer trace.Tracer } } +func (s *Source) RetrieveClientAndService(accessToken tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) { + bqClient := s.BigQueryClient() + restService := s.BigQueryRestService() + + // Initialize new client if using user OAuth token + if s.UseClientAuthorization() { + tokenStr, err := accessToken.ParseBearerToken() + if err != nil { + return nil, nil, fmt.Errorf("error parsing access token: %w", err) + } + bqClient, restService, err = s.BigQueryClientCreator()(tokenStr, true) + if err != nil { + return nil, nil, fmt.Errorf("error creating client from OAuth access token: %w", err) + } + } + return bqClient, restService, nil +} + +func (s *Source) RunSQL(ctx context.Context, bqClient *bigqueryapi.Client, statement, statementType string, params []bigqueryapi.QueryParameter, connProps []*bigqueryapi.ConnectionProperty) (any, error) { + query := bqClient.Query(statement) + query.Location = bqClient.Location + if params != nil { + query.Parameters = params + } + if connProps != nil { + query.ConnectionProperties = connProps + } + + // This block handles SELECT statements, which return a row set. + // We iterate through the results, convert each row into a map of + // column names to values, and return the collection of rows. 
+ job, err := query.Run(ctx) + if err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + it, err := job.Read(ctx) + if err != nil { + return nil, fmt.Errorf("unable to read query results: %w", err) + } + + var out []any + for { + var val []bigqueryapi.Value + err = it.Next(&val) + if err == iterator.Done { + break + } + if err != nil { + return nil, fmt.Errorf("unable to iterate through query results: %w", err) + } + schema := it.Schema + row := orderedmap.Row{} + for i, field := range schema { + row.Add(field.Name, NormalizeValue(val[i])) + } + out = append(out, row) + } + // If the query returned any rows, return them directly. + if len(out) > 0 { + return out, nil + } + + // This handles the standard case for a SELECT query that successfully + // executes but returns zero rows. + if statementType == "SELECT" { + return "The query returned 0 rows.", nil + } + // This is the fallback for a successful query that doesn't return content. + // In most cases, this will be for DML/DDL statements like INSERT, UPDATE, CREATE, etc. + // However, it is also possible that this was a query that was expected to return rows + // but returned none, a case that we cannot distinguish here. + return "Query executed successfully and returned no content.", nil +} + +// NormalizeValue converts BigQuery specific types to standard JSON-compatible types. +// Specifically, it handles *big.Rat (used for NUMERIC/BIGNUMERIC) by converting +// them to decimal strings with up to 38 digits of precision, trimming trailing zeros. +// It recursively handles slices (arrays) and maps (structs) using reflection. +func NormalizeValue(v any) any { + if v == nil { + return nil + } + + // Handle *big.Rat specifically. + if rat, ok := v.(*big.Rat); ok { + // Convert big.Rat to a decimal string. + // Use a precision of 38 digits (enough for BIGNUMERIC and NUMERIC) + // and trim trailing zeros to match BigQuery's behavior. 
+ s := rat.FloatString(38) + if strings.Contains(s, ".") { + s = strings.TrimRight(s, "0") + s = strings.TrimRight(s, ".") + } + return s + } + + // Use reflection for slices and maps to handle various underlying types. + rv := reflect.ValueOf(v) + switch rv.Kind() { + case reflect.Slice, reflect.Array: + // Preserve []byte as is, so json.Marshal encodes it as Base64 string (BigQuery BYTES behavior). + if rv.Type().Elem().Kind() == reflect.Uint8 { + return v + } + newSlice := make([]any, rv.Len()) + for i := 0; i < rv.Len(); i++ { + newSlice[i] = NormalizeValue(rv.Index(i).Interface()) + } + return newSlice + case reflect.Map: + // Ensure keys are strings to produce a JSON-compatible map. + if rv.Type().Key().Kind() != reflect.String { + return v + } + newMap := make(map[string]any, rv.Len()) + iter := rv.MapRange() + for iter.Next() { + newMap[iter.Key().String()] = NormalizeValue(iter.Value().Interface()) + } + return newMap + } + + return v +} + func initBigQueryConnection( ctx context.Context, tracer trace.Tracer, diff --git a/internal/sources/bigquery/bigquery_test.go b/internal/sources/bigquery/bigquery_test.go index 58970a2ddb..bca167ff57 100644 --- a/internal/sources/bigquery/bigquery_test.go +++ b/internal/sources/bigquery/bigquery_test.go @@ -15,6 +15,8 @@ package bigquery_test import ( + "math/big" + "reflect" "testing" yaml "github.com/goccy/go-yaml" @@ -195,3 +197,105 @@ func TestFailParseFromYaml(t *testing.T) { }) } } + +func TestNormalizeValue(t *testing.T) { + tests := []struct { + name string + input any + expected any + }{ + { + name: "big.Rat 1/3 (NUMERIC scale 9)", + input: new(big.Rat).SetFrac64(1, 3), // 0.33333333333... 
+ expected: "0.33333333333333333333333333333333333333", // FloatString(38) + }, + { + name: "big.Rat 19/2 (9.5)", + input: new(big.Rat).SetFrac64(19, 2), + expected: "9.5", + }, + { + name: "big.Rat 12341/10 (1234.1)", + input: new(big.Rat).SetFrac64(12341, 10), + expected: "1234.1", + }, + { + name: "big.Rat 10/1 (10)", + input: new(big.Rat).SetFrac64(10, 1), + expected: "10", + }, + { + name: "string", + input: "hello", + expected: "hello", + }, + { + name: "int", + input: 123, + expected: 123, + }, + { + name: "nested slice of big.Rat", + input: []any{ + new(big.Rat).SetFrac64(19, 2), + new(big.Rat).SetFrac64(1, 4), + }, + expected: []any{"9.5", "0.25"}, + }, + { + name: "nested map of big.Rat", + input: map[string]any{ + "val1": new(big.Rat).SetFrac64(19, 2), + "val2": new(big.Rat).SetFrac64(1, 2), + }, + expected: map[string]any{ + "val1": "9.5", + "val2": "0.5", + }, + }, + { + name: "complex nested structure", + input: map[string]any{ + "list": []any{ + map[string]any{ + "rat": new(big.Rat).SetFrac64(3, 2), + }, + }, + }, + expected: map[string]any{ + "list": []any{ + map[string]any{ + "rat": "1.5", + }, + }, + }, + }, + { + name: "slice of *big.Rat", + input: []*big.Rat{ + new(big.Rat).SetFrac64(19, 2), + new(big.Rat).SetFrac64(1, 4), + }, + expected: []any{"9.5", "0.25"}, + }, + { + name: "slice of strings", + input: []string{"a", "b"}, + expected: []any{"a", "b"}, + }, + { + name: "byte slice (BYTES)", + input: []byte("hello"), + expected: []byte("hello"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := bigquery.NormalizeValue(tt.input) + if !reflect.DeepEqual(got, tt.expected) { + t.Errorf("NormalizeValue() = %v, want %v", got, tt.expected) + } + }) + } +} diff --git a/internal/sources/bigtable/bigtable.go b/internal/sources/bigtable/bigtable.go index 22a41ee441..22daf64c37 100644 --- a/internal/sources/bigtable/bigtable.go +++ b/internal/sources/bigtable/bigtable.go @@ -22,6 +22,7 @@ import ( 
"github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/util" + "github.com/googleapis/genai-toolbox/internal/util/parameters" "go.opentelemetry.io/otel/trace" "google.golang.org/api/option" ) @@ -88,6 +89,94 @@ func (s *Source) BigtableClient() *bigtable.Client { return s.Client } +func getBigtableType(paramType string) (bigtable.SQLType, error) { + switch paramType { + case "boolean": + return bigtable.BoolSQLType{}, nil + case "string": + return bigtable.StringSQLType{}, nil + case "integer": + return bigtable.Int64SQLType{}, nil + case "float": + return bigtable.Float64SQLType{}, nil + case "array": + return bigtable.ArraySQLType{}, nil + default: + return nil, fmt.Errorf("unknow param type %s", paramType) + } +} + +func getMapParamsType(tparams parameters.Parameters) (map[string]bigtable.SQLType, error) { + btParamTypes := make(map[string]bigtable.SQLType) + for _, p := range tparams { + if p.GetType() == "array" { + itemType, err := getBigtableType(p.Manifest().Items.Type) + if err != nil { + return nil, err + } + btParamTypes[p.GetName()] = bigtable.ArraySQLType{ + ElemType: itemType, + } + continue + } + paramType, err := getBigtableType(p.GetType()) + if err != nil { + return nil, err + } + btParamTypes[p.GetName()] = paramType + } + return btParamTypes, nil +} + +func (s *Source) RunSQL(ctx context.Context, statement string, configParam parameters.Parameters, params parameters.ParamValues) (any, error) { + mapParamsType, err := getMapParamsType(configParam) + if err != nil { + return nil, fmt.Errorf("fail to get map params: %w", err) + } + + ps, err := s.BigtableClient().PrepareStatement( + ctx, + statement, + mapParamsType, + ) + if err != nil { + return nil, fmt.Errorf("unable to prepare statement: %w", err) + } + + bs, err := ps.Bind(params.AsMap()) + if err != nil { + return nil, fmt.Errorf("unable to bind: %w", err) + } + + var out []any + var rowErr error + err = bs.Execute(ctx, 
func(resultRow bigtable.ResultRow) bool { + vMap := make(map[string]any) + cols := resultRow.Metadata.Columns + + for _, c := range cols { + var columValue any + if err = resultRow.GetByName(c.Name, &columValue); err != nil { + rowErr = err + return false + } + vMap[c.Name] = columValue + } + + out = append(out, vMap) + + return true + }) + if err != nil { + return nil, fmt.Errorf("unable to execute client: %w", err) + } + if rowErr != nil { + return nil, fmt.Errorf("error processing row: %w", rowErr) + } + + return out, nil +} + func initBigtableClient(ctx context.Context, tracer trace.Tracer, name, project, instance string) (*bigtable.Client, error) { //nolint:all // Reassigned ctx ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name) diff --git a/internal/sources/cassandra/cassandra.go b/internal/sources/cassandra/cassandra.go index 9d2bf38d13..49c070bf06 100644 --- a/internal/sources/cassandra/cassandra.go +++ b/internal/sources/cassandra/cassandra.go @@ -21,6 +21,7 @@ import ( gocql "github.com/apache/cassandra-gocql-driver/v2" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/util/parameters" "go.opentelemetry.io/otel/trace" ) @@ -89,10 +90,32 @@ func (s *Source) ToConfig() sources.SourceConfig { } // SourceKind implements sources.Source. 
-func (s Source) SourceKind() string { +func (s *Source) SourceKind() string { return SourceKind } +func (s *Source) RunSQL(ctx context.Context, statement string, params parameters.ParamValues) (any, error) { + sliceParams := params.AsSlice() + iter := s.CassandraSession().Query(statement, sliceParams...).IterContext(ctx) + + // Create a slice to store the out + var out []map[string]interface{} + + // Scan results into a map and append to the slice + for { + row := make(map[string]interface{}) // Create a new map for each row + if !iter.MapScan(row) { + break // No more rows + } + out = append(out, row) + } + + if err := iter.Close(); err != nil { + return nil, fmt.Errorf("unable to parse rows: %w", err) + } + return out, nil +} + var _ sources.Source = &Source{} func initCassandraSession(ctx context.Context, tracer trace.Tracer, c Config) (*gocql.Session, error) { diff --git a/internal/sources/clickhouse/clickhouse.go b/internal/sources/clickhouse/clickhouse.go index 391d9bb639..3f0b6f961b 100644 --- a/internal/sources/clickhouse/clickhouse.go +++ b/internal/sources/clickhouse/clickhouse.go @@ -24,6 +24,7 @@ import ( _ "github.com/ClickHouse/clickhouse-go/v2" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/util/parameters" "go.opentelemetry.io/otel/trace" ) @@ -99,6 +100,69 @@ func (s *Source) ClickHousePool() *sql.DB { return s.Pool } +func (s *Source) RunSQL(ctx context.Context, statement string, params parameters.ParamValues) (any, error) { + var sliceParams []any + if params != nil { + sliceParams = params.AsSlice() + } + results, err := s.ClickHousePool().QueryContext(ctx, statement, sliceParams...) 
+ if err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + defer results.Close() + + cols, err := results.Columns() + if err != nil { + return nil, fmt.Errorf("unable to retrieve rows column name: %w", err) + } + + // create an array of values for each column, which can be re-used to scan each row + rawValues := make([]any, len(cols)) + values := make([]any, len(cols)) + for i := range rawValues { + values[i] = &rawValues[i] + } + + colTypes, err := results.ColumnTypes() + if err != nil { + return nil, fmt.Errorf("unable to get column types: %w", err) + } + + var out []any + for results.Next() { + err := results.Scan(values...) + if err != nil { + return nil, fmt.Errorf("unable to parse row: %w", err) + } + vMap := make(map[string]any) + for i, name := range cols { + // ClickHouse driver may return specific types that need handling + switch colTypes[i].DatabaseTypeName() { + case "String", "FixedString": + if rawValues[i] != nil { + // Handle potential []byte to string conversion if needed + if b, ok := rawValues[i].([]byte); ok { + vMap[name] = string(b) + } else { + vMap[name] = rawValues[i] + } + } else { + vMap[name] = nil + } + default: + vMap[name] = rawValues[i] + } + } + out = append(out, vMap) + } + + if err := results.Err(); err != nil { + return nil, fmt.Errorf("errors encountered by results.Scan: %w", err) + } + + return out, nil +} + func validateConfig(protocol string) error { validProtocols := map[string]bool{"http": true, "https": true} diff --git a/internal/sources/cloudgda/cloud_gda.go b/internal/sources/cloudgda/cloud_gda.go index a87ff11c59..5743991647 100644 --- a/internal/sources/cloudgda/cloud_gda.go +++ b/internal/sources/cloudgda/cloud_gda.go @@ -14,8 +14,11 @@ package cloudgda import ( + "bytes" "context" + "encoding/json" "fmt" + "io" "net/http" "github.com/goccy/go-yaml" @@ -131,3 +134,43 @@ func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Clien func (s *Source) UseClientAuthorization() 
bool { return s.UseClientOAuth } + +func (s *Source) RunQuery(ctx context.Context, tokenStr string, bodyBytes []byte) (any, error) { + // The API endpoint itself always uses the "global" location. + apiLocation := "global" + apiParent := fmt.Sprintf("projects/%s/locations/%s", s.GetProjectID(), apiLocation) + apiURL := fmt.Sprintf("%s/v1beta/%s:queryData", s.GetBaseURL(), apiParent) + + client, err := s.GetClient(ctx, tokenStr) + if err != nil { + return nil, fmt.Errorf("failed to get HTTP client: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(bodyBytes)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to execute request: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(respBody)) + } + + var result map[string]any + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + return result, nil +} diff --git a/internal/sources/cloudmonitoring/cloud_monitoring.go b/internal/sources/cloudmonitoring/cloud_monitoring.go index d43468687d..eb478dce24 100644 --- a/internal/sources/cloudmonitoring/cloud_monitoring.go +++ b/internal/sources/cloudmonitoring/cloud_monitoring.go @@ -15,7 +15,9 @@ package cloudmonitoring import ( "context" + "encoding/json" "fmt" + "io" "net/http" "github.com/goccy/go-yaml" @@ -131,3 +133,44 @@ func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Clien func (s *Source) UseClientAuthorization() bool { return s.UseClientOAuth } + +func (s *Source) RunQuery(projectID, query string) 
(any, error) { + url := fmt.Sprintf("%s/v1/projects/%s/location/global/prometheus/api/v1/query", s.BaseURL(), projectID) + + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return nil, err + } + + q := req.URL.Query() + q.Add("query", query) + req.URL.RawQuery = q.Encode() + + req.Header.Set("User-Agent", s.UserAgent()) + + resp, err := s.Client().Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("request failed: %s, body: %s", resp.Status, string(body)) + } + + if len(body) == 0 { + return nil, nil + } + + var result map[string]any + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal json: %w, body: %s", err, string(body)) + } + + return result, nil +} diff --git a/internal/sources/cloudsqladmin/cloud_sql_admin.go b/internal/sources/cloudsqladmin/cloud_sql_admin.go index 3a3ff48caf..7d8929b782 100644 --- a/internal/sources/cloudsqladmin/cloud_sql_admin.go +++ b/internal/sources/cloudsqladmin/cloud_sql_admin.go @@ -15,10 +15,16 @@ package cloudsqladmin import ( "context" + "encoding/json" "fmt" "net/http" + "regexp" + "strings" + "text/template" + "time" "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/log" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/util" "go.opentelemetry.io/otel/trace" @@ -30,6 +36,8 @@ import ( const SourceKind string = "cloud-sql-admin" +var targetLinkRegex = regexp.MustCompile(`/projects/([^/]+)/instances/([^/]+)/databases/([^/]+)`) + // validate interface var _ sources.SourceConfig = Config{} @@ -130,3 +138,304 @@ func (s *Source) GetService(ctx context.Context, accessToken string) (*sqladmin. 
func (s *Source) UseClientAuthorization() bool { return s.UseClientOAuth } + +func (s *Source) CloneInstance(ctx context.Context, project, sourceInstanceName, destinationInstanceName, pointInTime, preferredZone, preferredSecondaryZone, accessToken string) (any, error) { + cloneContext := &sqladmin.CloneContext{ + DestinationInstanceName: destinationInstanceName, + } + + if pointInTime != "" { + cloneContext.PointInTime = pointInTime + } + if preferredZone != "" { + cloneContext.PreferredZone = preferredZone + } + if preferredSecondaryZone != "" { + cloneContext.PreferredSecondaryZone = preferredSecondaryZone + } + + rb := &sqladmin.InstancesCloneRequest{ + CloneContext: cloneContext, + } + service, err := s.GetService(ctx, accessToken) + if err != nil { + return nil, err + } + resp, err := service.Instances.Clone(project, sourceInstanceName, rb).Do() + if err != nil { + return nil, fmt.Errorf("error cloning instance: %w", err) + } + return resp, nil +} + +func (s *Source) CreateDatabase(ctx context.Context, name, project, instance, accessToken string) (any, error) { + database := sqladmin.Database{ + Name: name, + Project: project, + Instance: instance, + } + + service, err := s.GetService(ctx, accessToken) + if err != nil { + return nil, err + } + + resp, err := service.Databases.Insert(project, instance, &database).Do() + if err != nil { + return nil, fmt.Errorf("error creating database: %w", err) + } + return resp, nil +} + +func (s *Source) CreateUsers(ctx context.Context, project, instance, name, password string, iamUser bool, accessToken string) (any, error) { + service, err := s.GetService(ctx, accessToken) + if err != nil { + return nil, err + } + + user := sqladmin.User{ + Name: name, + } + + if iamUser { + user.Type = "CLOUD_IAM_USER" + } else { + user.Type = "BUILT_IN" + if password == "" { + return nil, fmt.Errorf("missing 'password' parameter for non-IAM user") + } + user.Password = password + } + + resp, err := service.Users.Insert(project, instance, 
&user).Do() + if err != nil { + return nil, fmt.Errorf("error creating user: %w", err) + } + + return resp, nil +} + +func (s *Source) GetInstance(ctx context.Context, projectId, instanceId, accessToken string) (any, error) { + service, err := s.GetService(ctx, accessToken) + if err != nil { + return nil, err + } + + resp, err := service.Instances.Get(projectId, instanceId).Do() + if err != nil { + return nil, fmt.Errorf("error getting instance: %w", err) + } + return resp, nil +} + +func (s *Source) ListDatabase(ctx context.Context, project, instance, accessToken string) (any, error) { + service, err := s.GetService(ctx, accessToken) + if err != nil { + return nil, err + } + + resp, err := service.Databases.List(project, instance).Do() + if err != nil { + return nil, fmt.Errorf("error listing databases: %w", err) + } + + if resp.Items == nil { + return []any{}, nil + } + + type databaseInfo struct { + Name string `json:"name"` + Charset string `json:"charset"` + Collation string `json:"collation"` + } + + var databases []databaseInfo + for _, item := range resp.Items { + databases = append(databases, databaseInfo{ + Name: item.Name, + Charset: item.Charset, + Collation: item.Collation, + }) + } + return databases, nil +} + +func (s *Source) ListInstance(ctx context.Context, project, accessToken string) (any, error) { + service, err := s.GetService(ctx, accessToken) + if err != nil { + return nil, err + } + + resp, err := service.Instances.List(project).Do() + if err != nil { + return nil, fmt.Errorf("error listing instances: %w", err) + } + + if resp.Items == nil { + return []any{}, nil + } + + type instanceInfo struct { + Name string `json:"name"` + InstanceType string `json:"instanceType"` + } + + var instances []instanceInfo + for _, item := range resp.Items { + instances = append(instances, instanceInfo{ + Name: item.Name, + InstanceType: item.InstanceType, + }) + } + return instances, nil +} + +func (s *Source) CreateInstance(ctx context.Context, project, 
name, dbVersion, rootPassword string, settings sqladmin.Settings, accessToken string) (any, error) { + instance := sqladmin.DatabaseInstance{ + Name: name, + DatabaseVersion: dbVersion, + RootPassword: rootPassword, + Settings: &settings, + Project: project, + } + + service, err := s.GetService(ctx, accessToken) + if err != nil { + return nil, err + } + + resp, err := service.Instances.Insert(project, &instance).Do() + if err != nil { + return nil, fmt.Errorf("error creating instance: %w", err) + } + + return resp, nil +} + +func (s *Source) GetWaitForOperations(ctx context.Context, service *sqladmin.Service, project, operation, connectionMessageTemplate string, delay time.Duration) (any, error) { + logger, err := util.LoggerFromContext(ctx) + if err != nil { + return nil, err + } + op, err := service.Operations.Get(project, operation).Do() + if err != nil { + logger.DebugContext(ctx, fmt.Sprintf("error getting operation: %s, retrying in %v", err, delay)) + } else { + if op.Status == "DONE" { + if op.Error != nil { + var errorBytes []byte + errorBytes, err = json.Marshal(op.Error) + if err != nil { + return nil, fmt.Errorf("operation finished with error but could not marshal error object: %w", err) + } + return nil, fmt.Errorf("operation finished with error: %s", string(errorBytes)) + } + + var opBytes []byte + opBytes, err = op.MarshalJSON() + if err != nil { + return nil, fmt.Errorf("could not marshal operation: %w", err) + } + + var data map[string]any + if err := json.Unmarshal(opBytes, &data); err != nil { + return nil, fmt.Errorf("could not unmarshal operation: %w", err) + } + + if msg, ok := generateCloudSQLConnectionMessage(ctx, s, logger, data, connectionMessageTemplate); ok { + return msg, nil + } + return string(opBytes), nil + } + logger.DebugContext(ctx, fmt.Sprintf("operation not complete, retrying in %v", delay)) + } + return nil, nil +} + +func generateCloudSQLConnectionMessage(ctx context.Context, source *Source, logger log.Logger, opResponse 
map[string]any, connectionMessageTemplate string) (string, bool) { + operationType, ok := opResponse["operationType"].(string) + if !ok || operationType != "CREATE_DATABASE" { + return "", false + } + + targetLink, ok := opResponse["targetLink"].(string) + if !ok { + return "", false + } + + matches := targetLinkRegex.FindStringSubmatch(targetLink) + if len(matches) < 4 { + return "", false + } + project := matches[1] + instance := matches[2] + database := matches[3] + + dbInstance, err := fetchInstanceData(ctx, source, project, instance) + if err != nil { + logger.DebugContext(ctx, fmt.Sprintf("error fetching instance data: %v", err)) + return "", false + } + + region := dbInstance.Region + if region == "" { + return "", false + } + + databaseVersion := dbInstance.DatabaseVersion + if databaseVersion == "" { + return "", false + } + + var dbType string + if strings.Contains(databaseVersion, "POSTGRES") { + dbType = "postgres" + } else if strings.Contains(databaseVersion, "MYSQL") { + dbType = "mysql" + } else if strings.Contains(databaseVersion, "SQLSERVER") { + dbType = "mssql" + } else { + return "", false + } + + tmpl, err := template.New("cloud-sql-connection").Parse(connectionMessageTemplate) + if err != nil { + return fmt.Sprintf("template parsing error: %v", err), false + } + + data := struct { + Project string + Region string + Instance string + DBType string + DBTypeUpper string + Database string + }{ + Project: project, + Region: region, + Instance: instance, + DBType: dbType, + DBTypeUpper: strings.ToUpper(dbType), + Database: database, + } + + var b strings.Builder + if err := tmpl.Execute(&b, data); err != nil { + return fmt.Sprintf("template execution error: %v", err), false + } + + return b.String(), true +} + +func fetchInstanceData(ctx context.Context, source *Source, project, instance string) (*sqladmin.DatabaseInstance, error) { + service, err := source.GetService(ctx, "") + if err != nil { + return nil, err + } + + resp, err := 
service.Instances.Get(project, instance).Do() + if err != nil { + return nil, fmt.Errorf("error getting instance: %w", err) + } + return resp, nil +} diff --git a/internal/sources/cloudsqlmssql/cloud_sql_mssql.go b/internal/sources/cloudsqlmssql/cloud_sql_mssql.go index 1435165fde..02480df326 100644 --- a/internal/sources/cloudsqlmssql/cloud_sql_mssql.go +++ b/internal/sources/cloudsqlmssql/cloud_sql_mssql.go @@ -25,6 +25,7 @@ import ( "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/util" + "github.com/googleapis/genai-toolbox/internal/util/orderedmap" "go.opentelemetry.io/otel/trace" ) @@ -107,6 +108,48 @@ func (s *Source) MSSQLDB() *sql.DB { return s.Db } +func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) { + results, err := s.MSSQLDB().QueryContext(ctx, statement, params...) + if err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + defer results.Close() + + cols, err := results.Columns() + // If Columns() errors, it might be a DDL/DML without an OUTPUT clause. + // We proceed, and results.Err() will catch actual query execution errors. + // 'out' will remain nil if cols is empty or err is not nil here. + var out []any + if err == nil && len(cols) > 0 { + // create an array of values for each column, which can be re-used to scan each row + rawValues := make([]any, len(cols)) + values := make([]any, len(cols)) + for i := range rawValues { + values[i] = &rawValues[i] + } + + for results.Next() { + scanErr := results.Scan(values...) + if scanErr != nil { + return nil, fmt.Errorf("unable to parse row: %w", scanErr) + } + row := orderedmap.Row{} + for i, name := range cols { + row.Add(name, rawValues[i]) + } + out = append(out, row) + } + } + + // Check for errors from iterating over rows or from the query execution itself. + // results.Close() is handled by defer. 
+ if err := results.Err(); err != nil { + return nil, fmt.Errorf("errors encountered during query execution or row processing: %w", err) + } + + return out, nil +} + func initCloudSQLMssqlConnection(ctx context.Context, tracer trace.Tracer, name, project, region, instance, ipType, user, pass, dbname string) (*sql.DB, error) { //nolint:all // Reassigned ctx ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name) diff --git a/internal/sources/cloudsqlmysql/cloud_sql_mysql.go b/internal/sources/cloudsqlmysql/cloud_sql_mysql.go index 797985454b..759f00af7d 100644 --- a/internal/sources/cloudsqlmysql/cloud_sql_mysql.go +++ b/internal/sources/cloudsqlmysql/cloud_sql_mysql.go @@ -24,7 +24,9 @@ import ( "cloud.google.com/go/cloudsqlconn/mysql/mysql" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon" "github.com/googleapis/genai-toolbox/internal/util" + "github.com/googleapis/genai-toolbox/internal/util/orderedmap" "go.opentelemetry.io/otel/trace" ) @@ -100,6 +102,60 @@ func (s *Source) MySQLPool() *sql.DB { return s.Pool } +func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) { + results, err := s.MySQLPool().QueryContext(ctx, statement, params...) + if err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + defer results.Close() + + cols, err := results.Columns() + if err != nil { + return nil, fmt.Errorf("unable to retrieve rows column name: %w", err) + } + + // create an array of values for each column, which can be re-used to scan each row + rawValues := make([]any, len(cols)) + values := make([]any, len(cols)) + for i := range rawValues { + values[i] = &rawValues[i] + } + + colTypes, err := results.ColumnTypes() + if err != nil { + return nil, fmt.Errorf("unable to get column types: %w", err) + } + + var out []any + for results.Next() { + err := results.Scan(values...) 
+ if err != nil { + return nil, fmt.Errorf("unable to parse row: %w", err) + } + row := orderedmap.Row{} + for i, name := range cols { + val := rawValues[i] + if val == nil { + row.Add(name, nil) + continue + } + + convertedValue, err := mysqlcommon.ConvertToType(colTypes[i], val) + if err != nil { + return nil, fmt.Errorf("errors encountered when converting values: %w", err) + } + row.Add(name, convertedValue) + } + out = append(out, row) + } + + if err := results.Err(); err != nil { + return nil, fmt.Errorf("errors encountered during row iteration: %w", err) + } + + return out, nil +} + func getConnectionConfig(ctx context.Context, user, pass string) (string, string, bool, error) { useIAM := true diff --git a/internal/sources/cloudsqlpg/cloud_sql_pg.go b/internal/sources/cloudsqlpg/cloud_sql_pg.go index 3de83993bb..dc7e59be3d 100644 --- a/internal/sources/cloudsqlpg/cloud_sql_pg.go +++ b/internal/sources/cloudsqlpg/cloud_sql_pg.go @@ -23,6 +23,7 @@ import ( "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/util" + "github.com/googleapis/genai-toolbox/internal/util/orderedmap" "github.com/jackc/pgx/v5/pgxpool" "go.opentelemetry.io/otel/trace" ) @@ -99,6 +100,33 @@ func (s *Source) PostgresPool() *pgxpool.Pool { return s.Pool } +func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) { + results, err := s.PostgresPool().Query(ctx, statement, params...) 
+ if err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + defer results.Close() + + fields := results.FieldDescriptions() + var out []any + for results.Next() { + values, err := results.Values() + if err != nil { + return nil, fmt.Errorf("unable to parse row: %w", err) + } + row := orderedmap.Row{} + for i, f := range fields { + row.Add(f.Name, values[i]) + } + out = append(out, row) + } + // this will catch actual query execution errors + if err := results.Err(); err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + return out, nil +} + func getConnectionConfig(ctx context.Context, user, pass, dbname string) (string, bool, error) { userAgent, err := util.UserAgentFromContext(ctx) if err != nil { diff --git a/internal/sources/couchbase/couchbase.go b/internal/sources/couchbase/couchbase.go index 422d9ab001..c273a47ec0 100644 --- a/internal/sources/couchbase/couchbase.go +++ b/internal/sources/couchbase/couchbase.go @@ -17,6 +17,7 @@ package couchbase import ( "context" "crypto/tls" + "encoding/json" "fmt" "os" @@ -24,6 +25,7 @@ import ( tlsutil "github.com/couchbase/tools-common/http/tls" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/util/parameters" "go.opentelemetry.io/otel/trace" ) @@ -110,6 +112,27 @@ func (s *Source) CouchbaseQueryScanConsistency() uint { return s.QueryScanConsistency } +func (s *Source) RunSQL(statement string, params parameters.ParamValues) (any, error) { + results, err := s.CouchbaseScope().Query(statement, &gocb.QueryOptions{ + ScanConsistency: gocb.QueryScanConsistency(s.CouchbaseQueryScanConsistency()), + NamedParameters: params.AsMap(), + }) + if err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + + var out []any + for results.Next() { + var result json.RawMessage + err := results.Row(&result) + if err != nil { + return nil, fmt.Errorf("error processing row: %w", err) + } + 
out = append(out, result) + } + return out, nil +} + func (r Config) createCouchbaseOptions() (gocb.ClusterOptions, error) { cbOpts := gocb.ClusterOptions{} diff --git a/internal/sources/dgraph/dgraph.go b/internal/sources/dgraph/dgraph.go index 24f8f8b20e..317779db38 100644 --- a/internal/sources/dgraph/dgraph.go +++ b/internal/sources/dgraph/dgraph.go @@ -26,6 +26,7 @@ import ( "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/util/parameters" "go.opentelemetry.io/otel/trace" ) @@ -114,6 +115,28 @@ func (s *Source) DgraphClient() *DgraphClient { return s.Client } +func (s *Source) RunSQL(statement string, params parameters.ParamValues, isQuery bool, timeout string) (any, error) { + paramsMap := params.AsMapWithDollarPrefix() + resp, err := s.DgraphClient().ExecuteQuery(statement, paramsMap, isQuery, timeout) + if err != nil { + return nil, err + } + + if err := checkError(resp); err != nil { + return nil, err + } + + var result struct { + Data map[string]interface{} `json:"data"` + } + + if err := json.Unmarshal(resp, &result); err != nil { + return nil, fmt.Errorf("error parsing JSON: %v", err) + } + + return result.Data, nil +} + func initDgraphHttpClient(ctx context.Context, tracer trace.Tracer, r Config) (*DgraphClient, error) { //nolint:all // Reassigned ctx ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, r.Name) @@ -285,7 +308,7 @@ func (hc *DgraphClient) doLogin(creds map[string]interface{}) error { return err } - if err := CheckError(resp); err != nil { + if err := checkError(resp); err != nil { return err } @@ -370,7 +393,7 @@ func getUrl(baseUrl, resource string, params url.Values) (string, error) { return u.String(), nil } -func CheckError(resp []byte) error { +func checkError(resp []byte) error { var errResp struct { Errors []struct { Message string `json:"message"` diff --git a/internal/sources/elasticsearch/elasticsearch.go 
b/internal/sources/elasticsearch/elasticsearch.go index 2d7b788407..b5ec915c18 100644 --- a/internal/sources/elasticsearch/elasticsearch.go +++ b/internal/sources/elasticsearch/elasticsearch.go @@ -15,7 +15,9 @@ package elasticsearch import ( + "bytes" "context" + "encoding/json" "fmt" "net/http" @@ -149,3 +151,80 @@ func (s *Source) ToConfig() sources.SourceConfig { func (s *Source) ElasticsearchClient() EsClient { return s.Client } + +type EsqlColumn struct { + Name string `json:"name"` + Type string `json:"type"` +} + +type EsqlResult struct { + Columns []EsqlColumn `json:"columns"` + Values [][]any `json:"values"` +} + +func (s *Source) RunSQL(ctx context.Context, format, query string, params []map[string]any) (any, error) { + bodyStruct := struct { + Query string `json:"query"` + Params []map[string]any `json:"params,omitempty"` + }{ + Query: query, + Params: params, + } + body, err := json.Marshal(bodyStruct) + if err != nil { + return nil, fmt.Errorf("failed to marshal query body: %w", err) + } + + res, err := esapi.EsqlQueryRequest{ + Body: bytes.NewReader(body), + Format: format, + FilterPath: []string{"columns", "values"}, + Instrument: s.ElasticsearchClient().InstrumentationEnabled(), + }.Do(ctx, s.ElasticsearchClient()) + + if err != nil { + return nil, err + } + defer res.Body.Close() + + if res.IsError() { + // Try to extract error message from response + var esErr json.RawMessage + err = util.DecodeJSON(res.Body, &esErr) + if err != nil { + return nil, fmt.Errorf("elasticsearch error: status %s", res.Status()) + } + return esErr, nil + } + + var result EsqlResult + err = util.DecodeJSON(res.Body, &result) + if err != nil { + return nil, fmt.Errorf("failed to decode response body: %w", err) + } + + output := EsqlToMap(result) + + return output, nil +} + +// EsqlToMap converts the esqlResult to a slice of maps. 
+func EsqlToMap(result EsqlResult) []map[string]any { + output := make([]map[string]any, 0, len(result.Values)) + for _, value := range result.Values { + row := make(map[string]any) + if value == nil { + output = append(output, row) + continue + } + for i, col := range result.Columns { + if i < len(value) { + row[col.Name] = value[i] + } else { + row[col.Name] = nil + } + } + output = append(output, row) + } + return output +} diff --git a/internal/sources/elasticsearch/elasticsearch_test.go b/internal/sources/elasticsearch/elasticsearch_test.go index 6ea9d33dce..95d941edc4 100644 --- a/internal/sources/elasticsearch/elasticsearch_test.go +++ b/internal/sources/elasticsearch/elasticsearch_test.go @@ -15,6 +15,7 @@ package elasticsearch_test import ( + "reflect" "testing" yaml "github.com/goccy/go-yaml" @@ -64,3 +65,155 @@ func TestParseFromYamlElasticsearch(t *testing.T) { }) } } + +func TestTool_esqlToMap(t1 *testing.T) { + tests := []struct { + name string + result elasticsearch.EsqlResult + want []map[string]any + }{ + { + name: "simple case with two rows", + result: elasticsearch.EsqlResult{ + Columns: []elasticsearch.EsqlColumn{ + {Name: "first_name", Type: "text"}, + {Name: "last_name", Type: "text"}, + }, + Values: [][]any{ + {"John", "Doe"}, + {"Jane", "Smith"}, + }, + }, + want: []map[string]any{ + {"first_name": "John", "last_name": "Doe"}, + {"first_name": "Jane", "last_name": "Smith"}, + }, + }, + { + name: "different data types", + result: elasticsearch.EsqlResult{ + Columns: []elasticsearch.EsqlColumn{ + {Name: "id", Type: "integer"}, + {Name: "active", Type: "boolean"}, + {Name: "score", Type: "float"}, + }, + Values: [][]any{ + {1, true, 95.5}, + {2, false, 88.0}, + }, + }, + want: []map[string]any{ + {"id": 1, "active": true, "score": 95.5}, + {"id": 2, "active": false, "score": 88.0}, + }, + }, + { + name: "no rows", + result: elasticsearch.EsqlResult{ + Columns: []elasticsearch.EsqlColumn{ + {Name: "id", Type: "integer"}, + {Name: "name", Type: 
"text"}, + }, + Values: [][]any{}, + }, + want: []map[string]any{}, + }, + { + name: "null values", + result: elasticsearch.EsqlResult{ + Columns: []elasticsearch.EsqlColumn{ + {Name: "id", Type: "integer"}, + {Name: "name", Type: "text"}, + }, + Values: [][]any{ + {1, nil}, + {2, "Alice"}, + }, + }, + want: []map[string]any{ + {"id": 1, "name": nil}, + {"id": 2, "name": "Alice"}, + }, + }, + { + name: "missing values in a row", + result: elasticsearch.EsqlResult{ + Columns: []elasticsearch.EsqlColumn{ + {Name: "id", Type: "integer"}, + {Name: "name", Type: "text"}, + {Name: "age", Type: "integer"}, + }, + Values: [][]any{ + {1, "Bob"}, + {2, "Charlie", 30}, + }, + }, + want: []map[string]any{ + {"id": 1, "name": "Bob", "age": nil}, + {"id": 2, "name": "Charlie", "age": 30}, + }, + }, + { + name: "all null row", + result: elasticsearch.EsqlResult{ + Columns: []elasticsearch.EsqlColumn{ + {Name: "id", Type: "integer"}, + {Name: "name", Type: "text"}, + }, + Values: [][]any{ + nil, + }, + }, + want: []map[string]any{ + {}, + }, + }, + { + name: "empty columns", + result: elasticsearch.EsqlResult{ + Columns: []elasticsearch.EsqlColumn{}, + Values: [][]any{ + {}, + {}, + }, + }, + want: []map[string]any{ + {}, + {}, + }, + }, + { + name: "more values than columns", + result: elasticsearch.EsqlResult{ + Columns: []elasticsearch.EsqlColumn{ + {Name: "id", Type: "integer"}, + }, + Values: [][]any{ + {1, "extra"}, + }, + }, + want: []map[string]any{ + {"id": 1}, + }, + }, + { + name: "no columns but with values", + result: elasticsearch.EsqlResult{ + Columns: []elasticsearch.EsqlColumn{}, + Values: [][]any{ + {1, "data"}, + }, + }, + want: []map[string]any{ + {}, + }, + }, + } + for _, tt := range tests { + t1.Run(tt.name, func(t1 *testing.T) { + if got := elasticsearch.EsqlToMap(tt.result); !reflect.DeepEqual(got, tt.want) { + t1.Errorf("esqlToMap() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/sources/firebird/firebird.go 
b/internal/sources/firebird/firebird.go index 43775be70c..4be3d20cac 100644 --- a/internal/sources/firebird/firebird.go +++ b/internal/sources/firebird/firebird.go @@ -96,6 +96,53 @@ func (s *Source) FirebirdDB() *sql.DB { return s.Db } +func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) { + rows, err := s.FirebirdDB().QueryContext(ctx, statement, params...) + if err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + defer rows.Close() + + cols, err := rows.Columns() + if err != nil { + return nil, fmt.Errorf("unable to get columns: %w", err) + } + + values := make([]any, len(cols)) + scanArgs := make([]any, len(values)) + for i := range values { + scanArgs[i] = &values[i] + } + + var out []any + for rows.Next() { + + err = rows.Scan(scanArgs...) + if err != nil { + return nil, fmt.Errorf("unable to parse row: %w", err) + } + + vMap := make(map[string]any) + for i, col := range cols { + if b, ok := values[i].([]byte); ok { + vMap[col] = string(b) + } else { + vMap[col] = values[i] + } + } + out = append(out, vMap) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating rows: %w", err) + } + + // In most cases, DML/DDL statements like INSERT, UPDATE, CREATE, etc. might return no rows + // However, it is also possible that this was a query that was expected to return rows + // but returned none, a case that we cannot distinguish here. 
+ return out, nil +} + func initFirebirdConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname string) (*sql.DB, error) { _, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name) defer span.End() diff --git a/internal/sources/mindsdb/mindsdb.go b/internal/sources/mindsdb/mindsdb.go index a4f56a7d8e..4bb5daac1c 100644 --- a/internal/sources/mindsdb/mindsdb.go +++ b/internal/sources/mindsdb/mindsdb.go @@ -23,6 +23,7 @@ import ( _ "github.com/go-sql-driver/mysql" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon" "go.opentelemetry.io/otel/trace" ) @@ -101,6 +102,61 @@ func (s *Source) MySQLPool() *sql.DB { return s.Pool } +func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) { + // MindsDB now supports MySQL prepared statements natively + results, err := s.MindsDBPool().QueryContext(ctx, statement, params...) + if err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + + cols, err := results.Columns() + if err != nil { + return nil, fmt.Errorf("unable to retrieve rows column name: %w", err) + } + + // create an array of values for each column, which can be re-used to scan each row + rawValues := make([]any, len(cols)) + values := make([]any, len(cols)) + for i := range rawValues { + values[i] = &rawValues[i] + } + defer results.Close() + + colTypes, err := results.ColumnTypes() + if err != nil { + return nil, fmt.Errorf("unable to get column types: %w", err) + } + + var out []any + for results.Next() { + err := results.Scan(values...) 
+ if err != nil { + return nil, fmt.Errorf("unable to parse row: %w", err) + } + vMap := make(map[string]any) + for i, name := range cols { + val := rawValues[i] + if val == nil { + vMap[name] = nil + continue + } + + // MindsDB uses mysql driver + vMap[name], err = mysqlcommon.ConvertToType(colTypes[i], val) + if err != nil { + return nil, fmt.Errorf("errors encountered when converting values: %w", err) + } + } + out = append(out, vMap) + } + + if err := results.Err(); err != nil { + return nil, fmt.Errorf("errors encountered during row iteration: %w", err) + } + + return out, nil +} + func initMindsDBConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname, queryTimeout string) (*sql.DB, error) { //nolint:all // Reassigned ctx ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name) diff --git a/internal/sources/mssql/mssql.go b/internal/sources/mssql/mssql.go index 39a37bf5a0..688ccf18c4 100644 --- a/internal/sources/mssql/mssql.go +++ b/internal/sources/mssql/mssql.go @@ -23,6 +23,7 @@ import ( "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/util" + "github.com/googleapis/genai-toolbox/internal/util/orderedmap" _ "github.com/microsoft/go-mssqldb" "go.opentelemetry.io/otel/trace" ) @@ -104,6 +105,48 @@ func (s *Source) MSSQLDB() *sql.DB { return s.Db } +func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) { + results, err := s.MSSQLDB().QueryContext(ctx, statement, params...) + if err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + defer results.Close() + + cols, err := results.Columns() + // If Columns() errors, it might be a DDL/DML without an OUTPUT clause. + // We proceed, and results.Err() will catch actual query execution errors. + // 'out' will remain nil if cols is empty or err is not nil here. 
+ var out []any + if err == nil && len(cols) > 0 { + // create an array of values for each column, which can be re-used to scan each row + rawValues := make([]any, len(cols)) + values := make([]any, len(cols)) + for i := range rawValues { + values[i] = &rawValues[i] + } + + for results.Next() { + scanErr := results.Scan(values...) + if scanErr != nil { + return nil, fmt.Errorf("unable to parse row: %w", scanErr) + } + row := orderedmap.Row{} + for i, name := range cols { + row.Add(name, rawValues[i]) + } + out = append(out, row) + } + } + + // Check for errors from iterating over rows or from the query execution itself. + // results.Close() is handled by defer. + if err := results.Err(); err != nil { + return nil, fmt.Errorf("errors encountered during query execution or row processing: %w", err) + } + + return out, nil +} + func initMssqlConnection( ctx context.Context, tracer trace.Tracer, diff --git a/internal/sources/mysql/mysql.go b/internal/sources/mysql/mysql.go index 13f4b2a3d9..b456ec9a3f 100644 --- a/internal/sources/mysql/mysql.go +++ b/internal/sources/mysql/mysql.go @@ -24,7 +24,9 @@ import ( _ "github.com/go-sql-driver/mysql" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon" "github.com/googleapis/genai-toolbox/internal/util" + "github.com/googleapis/genai-toolbox/internal/util/orderedmap" "go.opentelemetry.io/otel/trace" ) @@ -100,6 +102,60 @@ func (s *Source) MySQLPool() *sql.DB { return s.Pool } +func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) { + results, err := s.MySQLPool().QueryContext(ctx, statement, params...) 
+ if err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + defer results.Close() + + cols, err := results.Columns() + if err != nil { + return nil, fmt.Errorf("unable to retrieve rows column name: %w", err) + } + + // create an array of values for each column, which can be re-used to scan each row + rawValues := make([]any, len(cols)) + values := make([]any, len(cols)) + for i := range rawValues { + values[i] = &rawValues[i] + } + + colTypes, err := results.ColumnTypes() + if err != nil { + return nil, fmt.Errorf("unable to get column types: %w", err) + } + + var out []any + for results.Next() { + err := results.Scan(values...) + if err != nil { + return nil, fmt.Errorf("unable to parse row: %w", err) + } + row := orderedmap.Row{} + for i, name := range cols { + val := rawValues[i] + if val == nil { + row.Add(name, nil) + continue + } + + convertedValue, err := mysqlcommon.ConvertToType(colTypes[i], val) + if err != nil { + return nil, fmt.Errorf("errors encountered when converting values: %w", err) + } + row.Add(name, convertedValue) + } + out = append(out, row) + } + + if err := results.Err(); err != nil { + return nil, fmt.Errorf("errors encountered during row iteration: %w", err) + } + + return out, nil +} + func initMySQLConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname, queryTimeout string, queryParams map[string]string) (*sql.DB, error) { //nolint:all // Reassigned ctx ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name) diff --git a/internal/sources/neo4j/neo4j.go b/internal/sources/neo4j/neo4j.go index 7e25819035..70cc21ae14 100644 --- a/internal/sources/neo4j/neo4j.go +++ b/internal/sources/neo4j/neo4j.go @@ -20,14 +20,19 @@ import ( "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jexecutecypher/classifier" + 
"github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jschema/helpers" "github.com/googleapis/genai-toolbox/internal/util" "github.com/neo4j/neo4j-go-driver/v5/neo4j" neo4jconf "github.com/neo4j/neo4j-go-driver/v5/neo4j/config" + "go.opentelemetry.io/otel/trace" ) const SourceKind string = "neo4j" +var sourceClassifier *classifier.QueryClassifier = classifier.NewQueryClassifier() + // validate interface var _ sources.SourceConfig = Config{} @@ -102,6 +107,79 @@ func (s *Source) Neo4jDatabase() string { return s.Database } +func (s *Source) RunQuery(ctx context.Context, cypherStr string, params map[string]any, readOnly, dryRun bool) (any, error) { + // validate the cypher query before executing + cf := sourceClassifier.Classify(cypherStr) + if cf.Error != nil { + return nil, cf.Error + } + + if cf.Type == classifier.WriteQuery && readOnly { + return nil, fmt.Errorf("this tool is read-only and cannot execute write queries") + } + + if dryRun { + // Add EXPLAIN to the beginning of the query to validate it without executing + cypherStr = "EXPLAIN " + cypherStr + } + + config := neo4j.ExecuteQueryWithDatabase(s.Neo4jDatabase()) + results, err := neo4j.ExecuteQuery[*neo4j.EagerResult](ctx, s.Neo4jDriver(), cypherStr, params, + neo4j.EagerResultTransformer, config) + if err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + + // If dry run, return the summary information only + if dryRun { + summary := results.Summary + plan := summary.Plan() + execPlan := map[string]any{ + "queryType": cf.Type.String(), + "statementType": summary.StatementType(), + "operator": plan.Operator(), + "arguments": plan.Arguments(), + "identifiers": plan.Identifiers(), + "childrenCount": len(plan.Children()), + } + if len(plan.Children()) > 0 { + execPlan["children"] = addPlanChildren(plan) + } + return []map[string]any{execPlan}, nil + } + + var out []map[string]any + keys := results.Keys + records := results.Records + for _, record := range records { + vMap := 
make(map[string]any) + for col, value := range record.Values { + vMap[keys[col]] = helpers.ConvertValue(value) + } + out = append(out, vMap) + } + + return out, nil +} + +// Recursive function to add plan children +func addPlanChildren(p neo4j.Plan) []map[string]any { + var children []map[string]any + for _, child := range p.Children() { + childMap := map[string]any{ + "operator": child.Operator(), + "arguments": child.Arguments(), + "identifiers": child.Identifiers(), + "children_count": len(child.Children()), + } + if len(child.Children()) > 0 { + childMap["children"] = addPlanChildren(child) + } + children = append(children, childMap) + } + return children +} + func initNeo4jDriver(ctx context.Context, tracer trace.Tracer, uri, user, password, name string) (neo4j.DriverWithContext, error) { //nolint:all // Reassigned ctx ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name) diff --git a/internal/sources/oceanbase/oceanbase.go b/internal/sources/oceanbase/oceanbase.go index 59aaf72ee5..27a989ae3d 100644 --- a/internal/sources/oceanbase/oceanbase.go +++ b/internal/sources/oceanbase/oceanbase.go @@ -23,6 +23,7 @@ import ( _ "github.com/go-sql-driver/mysql" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon" "go.opentelemetry.io/otel/trace" ) @@ -97,6 +98,60 @@ func (s *Source) OceanBasePool() *sql.DB { return s.Pool } +func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) { + results, err := s.OceanBasePool().QueryContext(ctx, statement, params...) 
+ if err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + + cols, err := results.Columns() + if err != nil { + return nil, fmt.Errorf("unable to retrieve rows column name: %w", err) + } + + // create an array of values for each column, which can be re-used to scan each row + rawValues := make([]any, len(cols)) + values := make([]any, len(cols)) + for i := range rawValues { + values[i] = &rawValues[i] + } + defer results.Close() + + colTypes, err := results.ColumnTypes() + if err != nil { + return nil, fmt.Errorf("unable to get column types: %w", err) + } + + var out []any + for results.Next() { + err := results.Scan(values...) + if err != nil { + return nil, fmt.Errorf("unable to parse row: %w", err) + } + vMap := make(map[string]any) + for i, name := range cols { + val := rawValues[i] + if val == nil { + vMap[name] = nil + continue + } + + // oceanbase uses mysql driver + vMap[name], err = mysqlcommon.ConvertToType(colTypes[i], val) + if err != nil { + return nil, fmt.Errorf("errors encountered when converting values: %w", err) + } + } + out = append(out, vMap) + } + + if err := results.Err(); err != nil { + return nil, fmt.Errorf("errors encountered during row iteration: %w", err) + } + + return out, nil +} + func initOceanBaseConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname, queryTimeout string) (*sql.DB, error) { _, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name) defer span.End() diff --git a/internal/sources/oracle/oracle.go b/internal/sources/oracle/oracle.go index 4de64b402b..29d78cc706 100644 --- a/internal/sources/oracle/oracle.go +++ b/internal/sources/oracle/oracle.go @@ -4,6 +4,7 @@ package oracle import ( "context" "database/sql" + "encoding/json" "fmt" "os" "strings" @@ -135,6 +136,107 @@ func (s *Source) OracleDB() *sql.DB { return s.DB } +func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) { + rows, err := 
s.OracleDB().QueryContext(ctx, statement, params...) + if err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + defer rows.Close() + + // If Columns() errors, it might be a DDL/DML without an OUTPUT clause. + // We proceed, and results.Err() will catch actual query execution errors. + // 'out' will remain nil if cols is empty or err is not nil here. + cols, _ := rows.Columns() + + // Get Column types + colTypes, err := rows.ColumnTypes() + if err != nil { + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("query execution error: %w", err) + } + return []any{}, nil + } + + var out []any + for rows.Next() { + values := make([]any, len(cols)) + for i, colType := range colTypes { + switch strings.ToUpper(colType.DatabaseTypeName()) { + case "NUMBER", "FLOAT", "BINARY_FLOAT", "BINARY_DOUBLE": + if _, scale, ok := colType.DecimalSize(); ok && scale == 0 { + // Scale is 0, treat it as an integer. + values[i] = new(sql.NullInt64) + } else { + // Scale is non-zero or unknown, treat + // it as a float. 
+ values[i] = new(sql.NullFloat64) + } + case "DATE", "TIMESTAMP", "TIMESTAMP WITH TIME ZONE", "TIMESTAMP WITH LOCAL TIME ZONE": + values[i] = new(sql.NullTime) + case "JSON": + values[i] = new(sql.RawBytes) + default: + values[i] = new(sql.NullString) + } + } + + if err := rows.Scan(values...); err != nil { + return nil, fmt.Errorf("unable to scan row: %w", err) + } + + vMap := make(map[string]any) + for i, col := range cols { + receiver := values[i] + + switch v := receiver.(type) { + case *sql.NullInt64: + if v.Valid { + vMap[col] = v.Int64 + } else { + vMap[col] = nil + } + case *sql.NullFloat64: + if v.Valid { + vMap[col] = v.Float64 + } else { + vMap[col] = nil + } + case *sql.NullString: + if v.Valid { + vMap[col] = v.String + } else { + vMap[col] = nil + } + case *sql.NullTime: + if v.Valid { + vMap[col] = v.Time + } else { + vMap[col] = nil + } + case *sql.RawBytes: + if *v != nil { + var unmarshaledData any + if err := json.Unmarshal(*v, &unmarshaledData); err != nil { + return nil, fmt.Errorf("unable to unmarshal json data for column %s", col) + } + vMap[col] = unmarshaledData + } else { + vMap[col] = nil + } + default: + return nil, fmt.Errorf("unexpected receiver type: %T", v) + } + } + out = append(out, vMap) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("errors encountered during query execution or row processing: %w", err) + } + + return out, nil +} + func initOracleConnection(ctx context.Context, tracer trace.Tracer, config Config) (*sql.DB, error) { //nolint:all // Reassigned ctx ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, config.Name) diff --git a/internal/sources/postgres/postgres.go b/internal/sources/postgres/postgres.go index e3dfeb7c44..d23721fc06 100644 --- a/internal/sources/postgres/postgres.go +++ b/internal/sources/postgres/postgres.go @@ -23,6 +23,7 @@ import ( "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/util" 
+ "github.com/googleapis/genai-toolbox/internal/util/orderedmap" "github.com/jackc/pgx/v5/pgxpool" "go.opentelemetry.io/otel/trace" ) @@ -98,6 +99,33 @@ func (s *Source) PostgresPool() *pgxpool.Pool { return s.Pool } +func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) { + results, err := s.PostgresPool().Query(ctx, statement, params...) + if err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + defer results.Close() + + fields := results.FieldDescriptions() + var out []any + for results.Next() { + values, err := results.Values() + if err != nil { + return nil, fmt.Errorf("unable to parse row: %w", err) + } + row := orderedmap.Row{} + for i, f := range fields { + row.Add(f.Name, values[i]) + } + out = append(out, row) + } + // this will catch actual query execution errors + if err := results.Err(); err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + return out, nil +} + func initPostgresConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname string, queryParams map[string]string) (*pgxpool.Pool, error) { //nolint:all // Reassigned ctx ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name) diff --git a/internal/sources/redis/redis.go b/internal/sources/redis/redis.go index 7f38ce14f1..f8e4bfa40e 100644 --- a/internal/sources/redis/redis.go +++ b/internal/sources/redis/redis.go @@ -152,3 +152,50 @@ func (s *Source) ToConfig() sources.SourceConfig { func (s *Source) RedisClient() RedisClient { return s.Client } + +func (s *Source) RunCommand(ctx context.Context, cmds [][]any) (any, error) { + // Execute commands + responses := make([]*redis.Cmd, len(cmds)) + for i, cmd := range cmds { + responses[i] = s.RedisClient().Do(ctx, cmd...) 
+ } + // Parse responses + out := make([]any, len(cmds)) + for i, resp := range responses { + if err := resp.Err(); err != nil { + // Add error from each command to `errSum` + errString := fmt.Sprintf("error from executing command at index %d: %s", i, err) + out[i] = errString + continue + } + val, err := resp.Result() + if err != nil { + return nil, fmt.Errorf("error getting result: %s", err) + } + out[i] = convertRedisResult(val) + } + + return out, nil +} + +// convertRedisResult recursively converts redis results (map[any]any) to be +// JSON-marshallable (map[string]any). +// It converts map[any]any to map[string]any and handles nested structures. +func convertRedisResult(v any) any { + switch val := v.(type) { + case map[any]any: + m := make(map[string]any) + for k, v := range val { + m[fmt.Sprint(k)] = convertRedisResult(v) + } + return m + case []any: + s := make([]any, len(val)) + for i, v := range val { + s[i] = convertRedisResult(v) + } + return s + default: + return v + } +} diff --git a/internal/sources/singlestore/singlestore.go b/internal/sources/singlestore/singlestore.go index 9b4d816ca9..ebcede392e 100644 --- a/internal/sources/singlestore/singlestore.go +++ b/internal/sources/singlestore/singlestore.go @@ -25,6 +25,7 @@ import ( _ "github.com/go-sql-driver/mysql" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon" "go.opentelemetry.io/otel/trace" ) @@ -106,6 +107,59 @@ func (s *Source) SingleStorePool() *sql.DB { return s.Pool } +func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) { + results, err := s.SingleStorePool().QueryContext(ctx, statement, params...) 
+ if err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + + cols, err := results.Columns() + if err != nil { + return nil, fmt.Errorf("unable to retrieve rows column name: %w", err) + } + + // create an array of values for each column, which can be re-used to scan each row + rawValues := make([]any, len(cols)) + values := make([]any, len(cols)) + for i := range rawValues { + values[i] = &rawValues[i] + } + defer results.Close() + + colTypes, err := results.ColumnTypes() + if err != nil { + return nil, fmt.Errorf("unable to get column types: %w", err) + } + + var out []any + for results.Next() { + err := results.Scan(values...) + if err != nil { + return nil, fmt.Errorf("unable to parse row: %w", err) + } + vMap := make(map[string]any) + for i, name := range cols { + val := rawValues[i] + if val == nil { + vMap[name] = nil + continue + } + + vMap[name], err = mysqlcommon.ConvertToType(colTypes[i], val) + if err != nil { + return nil, fmt.Errorf("errors encountered when converting values: %w", err) + } + } + out = append(out, vMap) + } + + if err := results.Err(); err != nil { + return nil, fmt.Errorf("errors encountered during row iteration: %w", err) + } + + return out, nil +} + func initSingleStoreConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname, queryTimeout string) (*sql.DB, error) { //nolint:all // Reassigned ctx ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name) diff --git a/internal/sources/spanner/spanner.go b/internal/sources/spanner/spanner.go index 757921aafe..d6a6967e12 100644 --- a/internal/sources/spanner/spanner.go +++ b/internal/sources/spanner/spanner.go @@ -16,13 +16,16 @@ package spanner import ( "context" + "encoding/json" "fmt" "cloud.google.com/go/spanner" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/util" + "github.com/googleapis/genai-toolbox/internal/util/orderedmap" 
"go.opentelemetry.io/otel/trace" + "google.golang.org/api/iterator" ) const SourceKind string = "spanner" @@ -93,6 +96,79 @@ func (s *Source) DatabaseDialect() string { return s.Dialect.String() } +// processRows iterates over the spanner.RowIterator and converts each row to a map[string]any. +func processRows(iter *spanner.RowIterator) ([]any, error) { + var out []any + defer iter.Stop() + + for { + row, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + return nil, fmt.Errorf("unable to parse row: %w", err) + } + + rowMap := orderedmap.Row{} + cols := row.ColumnNames() + for i, c := range cols { + if c == "object_details" { // for list graphs or list tables + val := row.ColumnValue(i) + if val == nil { // ColumnValue returns the Cloud Spanner Value of column i, or nil for invalid column. + rowMap.Add(c, nil) + } else { + jsonString, ok := val.AsInterface().(string) + if !ok { + return nil, fmt.Errorf("column 'object_details' is not a string, but %T", val.AsInterface()) + } + var details map[string]any + if err := json.Unmarshal([]byte(jsonString), &details); err != nil { + return nil, fmt.Errorf("unable to unmarshal JSON: %w", err) + } + rowMap.Add(c, details) + } + } else { + rowMap.Add(c, row.ColumnValue(i)) + } + } + out = append(out, rowMap) + } + return out, nil +} + +func (s *Source) RunSQL(ctx context.Context, readOnly bool, statement string, params map[string]any) (any, error) { + var results []any + var err error + var opErr error + stmt := spanner.Statement{ + SQL: statement, + } + if params != nil { + stmt.Params = params + } + + if readOnly { + iter := s.SpannerClient().Single().Query(ctx, stmt) + results, opErr = processRows(iter) + } else { + _, opErr = s.SpannerClient().ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { + iter := txn.Query(ctx, stmt) + results, err = processRows(iter) + if err != nil { + return err + } + return nil + }) + } + + if opErr != nil { + return nil, 
fmt.Errorf("unable to execute client: %w", opErr) + } + + return results, nil +} + func initSpannerClient(ctx context.Context, tracer trace.Tracer, name, project, instance, dbname string) (*spanner.Client, error) { //nolint:all // Reassigned ctx ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name) diff --git a/internal/sources/sqlite/sqlite.go b/internal/sources/sqlite/sqlite.go index 28c5805e27..f2afc57f9c 100644 --- a/internal/sources/sqlite/sqlite.go +++ b/internal/sources/sqlite/sqlite.go @@ -17,10 +17,12 @@ package sqlite import ( "context" "database/sql" + "encoding/json" "fmt" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/util/orderedmap" "go.opentelemetry.io/otel/trace" _ "modernc.org/sqlite" // Pure Go SQLite driver ) @@ -91,6 +93,66 @@ func (s *Source) SQLiteDB() *sql.DB { return s.Db } +func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) { + // Execute the SQL query with parameters + rows, err := s.SQLiteDB().QueryContext(ctx, statement, params...) + if err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + defer rows.Close() + + // Get column names + cols, err := rows.Columns() + if err != nil { + return nil, fmt.Errorf("unable to get column names: %w", err) + } + + // The sqlite driver does not support ColumnTypes, so we can't get the + // underlying database type of the columns. We'll have to rely on the + // generic `any` type and then handle the JSON data separately. 
+ rawValues := make([]any, len(cols)) + values := make([]any, len(cols)) + for i := range rawValues { + values[i] = &rawValues[i] + } + + // Prepare the result slice + var out []any + for rows.Next() { + if err := rows.Scan(values...); err != nil { + return nil, fmt.Errorf("unable to scan row: %w", err) + } + + // Create a map for this row + row := orderedmap.Row{} + for i, name := range cols { + val := rawValues[i] + // Handle nil values + if val == nil { + row.Add(name, nil) + continue + } + // Handle JSON data + if jsonString, ok := val.(string); ok { + var unmarshaledData any + if json.Unmarshal([]byte(jsonString), &unmarshaledData) == nil { + row.Add(name, unmarshaledData) + continue + } + } + // Store the value in the map + row.Add(name, val) + } + out = append(out, row) + } + + if err = rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating rows: %w", err) + } + + return out, nil +} + func initSQLiteConnection(ctx context.Context, tracer trace.Tracer, name, dbPath string) (*sql.DB, error) { //nolint:all // Reassigned ctx ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name) diff --git a/internal/sources/tidb/tidb.go b/internal/sources/tidb/tidb.go index b722a27524..617da6969d 100644 --- a/internal/sources/tidb/tidb.go +++ b/internal/sources/tidb/tidb.go @@ -17,6 +17,7 @@ package tidb import ( "context" "database/sql" + "encoding/json" "fmt" "regexp" @@ -104,6 +105,79 @@ func (s *Source) TiDBPool() *sql.DB { return s.Pool } +func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) { + results, err := s.TiDBPool().QueryContext(ctx, statement, params...) 
+ if err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + + cols, err := results.Columns() + if err != nil { + return nil, fmt.Errorf("unable to retrieve rows column name: %w", err) + } + + // create an array of values for each column, which can be re-used to scan each row + rawValues := make([]any, len(cols)) + values := make([]any, len(cols)) + for i := range rawValues { + values[i] = &rawValues[i] + } + defer results.Close() + + colTypes, err := results.ColumnTypes() + if err != nil { + return nil, fmt.Errorf("unable to get column types: %w", err) + } + + var out []any + for results.Next() { + err := results.Scan(values...) + if err != nil { + return nil, fmt.Errorf("unable to parse row: %w", err) + } + vMap := make(map[string]any) + for i, name := range cols { + val := rawValues[i] + if val == nil { + vMap[name] = nil + continue + } + + // mysql driver return []uint8 type for "TEXT", "VARCHAR", and "NVARCHAR" + // we'll need to cast it back to string + switch colTypes[i].DatabaseTypeName() { + case "JSON": + // unmarshal JSON data before storing to prevent double + // marshaling + byteVal, ok := val.([]byte) + if !ok { + return nil, fmt.Errorf("expected []byte for JSON column, but got %T", val) + } + var unmarshaledData any + if err := json.Unmarshal(byteVal, &unmarshaledData); err != nil { + return nil, fmt.Errorf("unable to unmarshal json data %s", val) + } + vMap[name] = unmarshaledData + case "TEXT", "VARCHAR", "NVARCHAR": + byteVal, ok := val.([]byte) + if !ok { + return nil, fmt.Errorf("expected []byte for text-like column, but got %T", val) + } + vMap[name] = string(byteVal) + default: + vMap[name] = val + } + } + out = append(out, vMap) + } + + if err := results.Err(); err != nil { + return nil, fmt.Errorf("errors encountered during row iteration: %w", err) + } + + return out, nil +} + func IsTiDBCloudHost(host string) bool { pattern := `gateway\d{2}\.(.+)\.(prod|dev|staging)\.(.+)\.tidbcloud\.com` match, err := 
regexp.MatchString(pattern, host) diff --git a/internal/sources/trino/trino.go b/internal/sources/trino/trino.go index 98f7dbd1b6..ed99dc3c19 100644 --- a/internal/sources/trino/trino.go +++ b/internal/sources/trino/trino.go @@ -108,6 +108,56 @@ func (s *Source) TrinoDB() *sql.DB { return s.Pool } +func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) { + results, err := s.TrinoDB().QueryContext(ctx, statement, params...) + if err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + defer results.Close() + + cols, err := results.Columns() + if err != nil { + return nil, fmt.Errorf("unable to retrieve column names: %w", err) + } + + // create an array of values for each column, which can be re-used to scan each row + rawValues := make([]any, len(cols)) + values := make([]any, len(cols)) + for i := range rawValues { + values[i] = &rawValues[i] + } + + var out []any + for results.Next() { + err := results.Scan(values...) + if err != nil { + return nil, fmt.Errorf("unable to parse row: %w", err) + } + vMap := make(map[string]any) + for i, name := range cols { + val := rawValues[i] + if val == nil { + vMap[name] = nil + continue + } + + // Convert byte arrays to strings for text fields + if b, ok := val.([]byte); ok { + vMap[name] = string(b) + } else { + vMap[name] = val + } + } + out = append(out, vMap) + } + + if err := results.Err(); err != nil { + return nil, fmt.Errorf("errors encountered during row iteration: %w", err) + } + + return out, nil +} + func initTrinoConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, password, catalog, schema, queryTimeout, accessToken string, kerberosEnabled, sslEnabled bool, sslCertPath, sslCert string, disableSslVerification bool) (*sql.DB, error) { //nolint:all // Reassigned ctx ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name) diff --git a/internal/sources/valkey/valkey.go b/internal/sources/valkey/valkey.go index 
07d1819dce..58ed2356d3 100644 --- a/internal/sources/valkey/valkey.go +++ b/internal/sources/valkey/valkey.go @@ -125,3 +125,37 @@ func (s *Source) ToConfig() sources.SourceConfig { func (s *Source) ValkeyClient() valkey.Client { return s.Client } + +func (s *Source) RunCommand(ctx context.Context, cmds [][]string) (any, error) { + // Build commands + builtCmds := make(valkey.Commands, len(cmds)) + + for i, cmd := range cmds { + builtCmds[i] = s.ValkeyClient().B().Arbitrary(cmd...).Build() + } + + if len(builtCmds) == 0 { + return nil, fmt.Errorf("no valid commands were built to execute") + } + + // Execute commands + responses := s.ValkeyClient().DoMulti(ctx, builtCmds...) + + // Parse responses + out := make([]any, len(cmds)) + for i, resp := range responses { + if err := resp.Error(); err != nil { + // Store error message in the output for this command + out[i] = fmt.Sprintf("error from executing command at index %d: %s", i, err) + continue + } + val, err := resp.ToAny() + if err != nil { + out[i] = fmt.Sprintf("error parsing response: %s", err) + continue + } + out[i] = val + } + + return out, nil +} diff --git a/internal/sources/yugabytedb/yugabytedb.go b/internal/sources/yugabytedb/yugabytedb.go index 130e43168a..830e3ae7fe 100644 --- a/internal/sources/yugabytedb/yugabytedb.go +++ b/internal/sources/yugabytedb/yugabytedb.go @@ -99,6 +99,35 @@ func (s *Source) YugabyteDBPool() *pgxpool.Pool { return s.Pool } +func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) { + results, err := s.YugabyteDBPool().Query(ctx, statement, params...) 
+ if err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + + fields := results.FieldDescriptions() + + var out []any + for results.Next() { + v, err := results.Values() + if err != nil { + return nil, fmt.Errorf("unable to parse row: %w", err) + } + vMap := make(map[string]any) + for i, f := range fields { + vMap[f.Name] = v[i] + } + out = append(out, vMap) + } + + // this will catch actual query execution errors + if err := results.Err(); err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + + return out, nil +} + func initYugabyteDBConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname, loadBalance, topologyKeys, refreshInterval, explicitFallback, failedHostTTL string) (*pgxpool.Pool, error) { //nolint:all // Reassigned ctx ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name) diff --git a/internal/tools/alloydbainl/alloydbainl.go b/internal/tools/alloydbainl/alloydbainl.go index 8c3b468091..ba2ce8e14b 100644 --- a/internal/tools/alloydbainl/alloydbainl.go +++ b/internal/tools/alloydbainl/alloydbainl.go @@ -145,7 +145,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para allParamValues[i+2] = fmt.Sprintf("%s", param) } - return source.RunSQL(ctx, t.Statement, allParamValues) + resp, err := source.RunSQL(ctx, t.Statement, allParamValues) + if err != nil { + return nil, fmt.Errorf("%w. Query: %v , Values: %v. Toolbox v0.19.0+ is only compatible with AlloyDB AI NL v1.0.3+. 
Please ensure that you are using the latest AlloyDB AI NL extension", err, t.Statement, allParamValues) + } + return resp, nil } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution.go b/internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution.go index 61b90a1d11..f3312acfb3 100644 --- a/internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution.go +++ b/internal/tools/bigquery/bigqueryanalyzecontribution/bigqueryanalyzecontribution.go @@ -28,7 +28,6 @@ import ( bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon" "github.com/googleapis/genai-toolbox/internal/util/parameters" bigqueryrestapi "google.golang.org/api/bigquery/v2" - "google.golang.org/api/iterator" ) const kind string = "bigquery-analyze-contribution" @@ -49,12 +48,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { BigQueryClient() *bigqueryapi.Client - BigQueryRestService() *bigqueryrestapi.Service - BigQueryClientCreator() bigqueryds.BigqueryClientCreator UseClientAuthorization() bool IsDatasetAllowed(projectID, datasetID string) bool BigQueryAllowedDatasets() []string BigQuerySession() bigqueryds.BigQuerySessionProvider + RetrieveClientAndService(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) + RunSQL(context.Context, *bigqueryapi.Client, string, string, []bigqueryapi.QueryParameter, []*bigqueryapi.ConnectionProperty) (any, error) } type Config struct { @@ -166,19 +165,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("unable to cast input_data parameter %s", paramsMap["input_data"]) } - bqClient := source.BigQueryClient() - restService := source.BigQueryRestService() - - // Initialize new client if using user OAuth 
token - if source.UseClientAuthorization() { - tokenStr, err := accessToken.ParseBearerToken() - if err != nil { - return nil, fmt.Errorf("error parsing access token: %w", err) - } - bqClient, restService, err = source.BigQueryClientCreator()(tokenStr, true) - if err != nil { - return nil, fmt.Errorf("error creating client from OAuth access token: %w", err) - } + bqClient, restService, err := source.RetrieveClientAndService(accessToken) + if err != nil { + return nil, err } modelID := fmt.Sprintf("contribution_analysis_model_%s", strings.ReplaceAll(uuid.New().String(), "-", "")) @@ -314,43 +303,8 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } getInsightsSQL := fmt.Sprintf("SELECT * FROM ML.GET_INSIGHTS(MODEL %s)", modelID) - - getInsightsQuery := bqClient.Query(getInsightsSQL) - getInsightsQuery.ConnectionProperties = []*bigqueryapi.ConnectionProperty{{Key: "session_id", Value: sessionID}} - - job, err := getInsightsQuery.Run(ctx) - if err != nil { - return nil, fmt.Errorf("failed to execute get insights query: %w", err) - } - it, err := job.Read(ctx) - if err != nil { - return nil, fmt.Errorf("unable to read query results: %w", err) - } - - var out []any - for { - var row map[string]bigqueryapi.Value - err := it.Next(&row) - if err == iterator.Done { - break - } - if err != nil { - return nil, fmt.Errorf("failed to iterate through query results: %w", err) - } - vMap := make(map[string]any) - for key, value := range row { - vMap[key] = value - } - out = append(out, vMap) - } - - if len(out) > 0 { - return out, nil - } - - // This handles the standard case for a SELECT query that successfully - // executes but returns zero rows. 
- return "The query returned 0 rows.", nil + connProps := []*bigqueryapi.ConnectionProperty{{Key: "session_id", Value: sessionID}} + return source.RunSQL(ctx, bqClient, getInsightsSQL, "SELECT", nil, connProps) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/bigquery/bigquerycommon/conversion_test.go b/internal/tools/bigquery/bigquerycommon/conversion_test.go deleted file mode 100644 index c735d0ebe1..0000000000 --- a/internal/tools/bigquery/bigquerycommon/conversion_test.go +++ /dev/null @@ -1,123 +0,0 @@ -// Copyright 2025 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package bigquerycommon - -import ( - "math/big" - "reflect" - "testing" -) - -func TestNormalizeValue(t *testing.T) { - tests := []struct { - name string - input any - expected any - }{ - { - name: "big.Rat 1/3 (NUMERIC scale 9)", - input: new(big.Rat).SetFrac64(1, 3), // 0.33333333333... 
- expected: "0.33333333333333333333333333333333333333", // FloatString(38) - }, - { - name: "big.Rat 19/2 (9.5)", - input: new(big.Rat).SetFrac64(19, 2), - expected: "9.5", - }, - { - name: "big.Rat 12341/10 (1234.1)", - input: new(big.Rat).SetFrac64(12341, 10), - expected: "1234.1", - }, - { - name: "big.Rat 10/1 (10)", - input: new(big.Rat).SetFrac64(10, 1), - expected: "10", - }, - { - name: "string", - input: "hello", - expected: "hello", - }, - { - name: "int", - input: 123, - expected: 123, - }, - { - name: "nested slice of big.Rat", - input: []any{ - new(big.Rat).SetFrac64(19, 2), - new(big.Rat).SetFrac64(1, 4), - }, - expected: []any{"9.5", "0.25"}, - }, - { - name: "nested map of big.Rat", - input: map[string]any{ - "val1": new(big.Rat).SetFrac64(19, 2), - "val2": new(big.Rat).SetFrac64(1, 2), - }, - expected: map[string]any{ - "val1": "9.5", - "val2": "0.5", - }, - }, - { - name: "complex nested structure", - input: map[string]any{ - "list": []any{ - map[string]any{ - "rat": new(big.Rat).SetFrac64(3, 2), - }, - }, - }, - expected: map[string]any{ - "list": []any{ - map[string]any{ - "rat": "1.5", - }, - }, - }, - }, - { - name: "slice of *big.Rat", - input: []*big.Rat{ - new(big.Rat).SetFrac64(19, 2), - new(big.Rat).SetFrac64(1, 4), - }, - expected: []any{"9.5", "0.25"}, - }, - { - name: "slice of strings", - input: []string{"a", "b"}, - expected: []any{"a", "b"}, - }, - { - name: "byte slice (BYTES)", - input: []byte("hello"), - expected: []byte("hello"), - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := NormalizeValue(tt.input) - if !reflect.DeepEqual(got, tt.expected) { - t.Errorf("NormalizeValue() = %v, want %v", got, tt.expected) - } - }) - } -} diff --git a/internal/tools/bigquery/bigquerycommon/util.go b/internal/tools/bigquery/bigquerycommon/util.go index d9b6fd0283..5486ac36ed 100644 --- a/internal/tools/bigquery/bigquerycommon/util.go +++ b/internal/tools/bigquery/bigquerycommon/util.go @@ -17,8 +17,6 @@ 
package bigquerycommon import ( "context" "fmt" - "math/big" - "reflect" "sort" "strings" @@ -120,54 +118,3 @@ func InitializeDatasetParameters( return projectParam, datasetParam } - -// NormalizeValue converts BigQuery specific types to standard JSON-compatible types. -// Specifically, it handles *big.Rat (used for NUMERIC/BIGNUMERIC) by converting -// them to decimal strings with up to 38 digits of precision, trimming trailing zeros. -// It recursively handles slices (arrays) and maps (structs) using reflection. -func NormalizeValue(v any) any { - if v == nil { - return nil - } - - // Handle *big.Rat specifically. - if rat, ok := v.(*big.Rat); ok { - // Convert big.Rat to a decimal string. - // Use a precision of 38 digits (enough for BIGNUMERIC and NUMERIC) - // and trim trailing zeros to match BigQuery's behavior. - s := rat.FloatString(38) - if strings.Contains(s, ".") { - s = strings.TrimRight(s, "0") - s = strings.TrimRight(s, ".") - } - return s - } - - // Use reflection for slices and maps to handle various underlying types. - rv := reflect.ValueOf(v) - switch rv.Kind() { - case reflect.Slice, reflect.Array: - // Preserve []byte as is, so json.Marshal encodes it as Base64 string (BigQuery BYTES behavior). - if rv.Type().Elem().Kind() == reflect.Uint8 { - return v - } - newSlice := make([]any, rv.Len()) - for i := 0; i < rv.Len(); i++ { - newSlice[i] = NormalizeValue(rv.Index(i).Interface()) - } - return newSlice - case reflect.Map: - // Ensure keys are strings to produce a JSON-compatible map. 
- if rv.Type().Key().Kind() != reflect.String { - return v - } - newMap := make(map[string]any, rv.Len()) - iter := rv.MapRange() - for iter.Next() { - newMap[iter.Key().String()] = NormalizeValue(iter.Value().Interface()) - } - return newMap - } - - return v -} diff --git a/internal/tools/bigquery/bigqueryexecutesql/bigqueryexecutesql.go b/internal/tools/bigquery/bigqueryexecutesql/bigqueryexecutesql.go index a70d4d342d..3e248e1971 100644 --- a/internal/tools/bigquery/bigqueryexecutesql/bigqueryexecutesql.go +++ b/internal/tools/bigquery/bigqueryexecutesql/bigqueryexecutesql.go @@ -27,10 +27,8 @@ import ( "github.com/googleapis/genai-toolbox/internal/tools" bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon" "github.com/googleapis/genai-toolbox/internal/util" - "github.com/googleapis/genai-toolbox/internal/util/orderedmap" "github.com/googleapis/genai-toolbox/internal/util/parameters" bigqueryrestapi "google.golang.org/api/bigquery/v2" - "google.golang.org/api/iterator" ) const kind string = "bigquery-execute-sql" @@ -53,11 +51,11 @@ type compatibleSource interface { BigQueryClient() *bigqueryapi.Client BigQuerySession() bigqueryds.BigQuerySessionProvider BigQueryWriteMode() string - BigQueryRestService() *bigqueryrestapi.Service - BigQueryClientCreator() bigqueryds.BigqueryClientCreator UseClientAuthorization() bool IsDatasetAllowed(projectID, datasetID string) bool BigQueryAllowedDatasets() []string + RetrieveClientAndService(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) + RunSQL(context.Context, *bigqueryapi.Client, string, string, []bigqueryapi.QueryParameter, []*bigqueryapi.ConnectionProperty) (any, error) } type Config struct { @@ -169,19 +167,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("unable to cast dry_run parameter %s", paramsMap["dry_run"]) } - bqClient := source.BigQueryClient() - restService := source.BigQueryRestService() - - 
// Initialize new client if using user OAuth token - if source.UseClientAuthorization() { - tokenStr, err := accessToken.ParseBearerToken() - if err != nil { - return nil, fmt.Errorf("error parsing access token: %w", err) - } - bqClient, restService, err = source.BigQueryClientCreator()(tokenStr, true) - if err != nil { - return nil, fmt.Errorf("error creating client from OAuth access token: %w", err) - } + bqClient, restService, err := source.RetrieveClientAndService(accessToken) + if err != nil { + return nil, err } var connProps []*bigqueryapi.ConnectionProperty @@ -283,61 +271,13 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return "Dry run was requested, but no job information was returned.", nil } - query := bqClient.Query(sql) - query.Location = bqClient.Location - - query.ConnectionProperties = connProps - // Log the query executed for debugging. logger, err := util.LoggerFromContext(ctx) if err != nil { return nil, fmt.Errorf("error getting logger: %s", err) } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sql)) - - // This block handles SELECT statements, which return a row set. - // We iterate through the results, convert each row into a map of - // column names to values, and return the collection of rows. - var out []any - job, err := query.Run(ctx) - if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - it, err := job.Read(ctx) - if err != nil { - return nil, fmt.Errorf("unable to read query results: %w", err) - } - for { - var val []bigqueryapi.Value - err = it.Next(&val) - if err == iterator.Done { - break - } - if err != nil { - return nil, fmt.Errorf("unable to iterate through query results: %w", err) - } - schema := it.Schema - row := orderedmap.Row{} - for i, field := range schema { - row.Add(field.Name, bqutil.NormalizeValue(val[i])) - } - out = append(out, row) - } - // If the query returned any rows, return them directly. 
- if len(out) > 0 { - return out, nil - } - - // This handles the standard case for a SELECT query that successfully - // executes but returns zero rows. - if statementType == "SELECT" { - return "The query returned 0 rows.", nil - } - // This is the fallback for a successful query that doesn't return content. - // In most cases, this will be for DML/DDL statements like INSERT, UPDATE, CREATE, etc. - // However, it is also possible that this was a query that was expected to return rows - // but returned none, a case that we cannot distinguish here. - return "Query executed successfully and returned no content.", nil + return source.RunSQL(ctx, bqClient, sql, statementType, nil, connProps) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/bigquery/bigqueryforecast/bigqueryforecast.go b/internal/tools/bigquery/bigqueryforecast/bigqueryforecast.go index 034bce3501..b316aead3f 100644 --- a/internal/tools/bigquery/bigqueryforecast/bigqueryforecast.go +++ b/internal/tools/bigquery/bigqueryforecast/bigqueryforecast.go @@ -28,7 +28,6 @@ import ( "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" bigqueryrestapi "google.golang.org/api/bigquery/v2" - "google.golang.org/api/iterator" ) const kind string = "bigquery-forecast" @@ -49,12 +48,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { BigQueryClient() *bigqueryapi.Client - BigQueryRestService() *bigqueryrestapi.Service - BigQueryClientCreator() bigqueryds.BigqueryClientCreator UseClientAuthorization() bool IsDatasetAllowed(projectID, datasetID string) bool BigQueryAllowedDatasets() []string BigQuerySession() bigqueryds.BigQuerySessionProvider + RetrieveClientAndService(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) + RunSQL(context.Context, *bigqueryapi.Client, string, 
string, []bigqueryapi.QueryParameter, []*bigqueryapi.ConnectionProperty) (any, error) } type Config struct { @@ -173,19 +172,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } } - bqClient := source.BigQueryClient() - restService := source.BigQueryRestService() - - // Initialize new client if using user OAuth token - if source.UseClientAuthorization() { - tokenStr, err := accessToken.ParseBearerToken() - if err != nil { - return nil, fmt.Errorf("error parsing access token: %w", err) - } - bqClient, restService, err = source.BigQueryClientCreator()(tokenStr, false) - if err != nil { - return nil, fmt.Errorf("error creating client from OAuth access token: %w", err) - } + bqClient, restService, err := source.RetrieveClientAndService(accessToken) + if err != nil { + return nil, err } var historyDataSource string @@ -251,7 +240,6 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para idColsFormatted := fmt.Sprintf("['%s']", strings.Join(idCols, "', '")) idColsArg = fmt.Sprintf(", id_cols => %s", idColsFormatted) } - sql := fmt.Sprintf(`SELECT * FROM AI.FORECAST( %s, @@ -260,16 +248,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para horizon => %d%s)`, historyDataSource, dataCol, timestampCol, horizon, idColsArg) - // JobStatistics.QueryStatistics.StatementType - query := bqClient.Query(sql) - query.Location = bqClient.Location session, err := source.BigQuerySession()(ctx) if err != nil { return nil, fmt.Errorf("failed to get BigQuery session: %w", err) } + var connProps []*bigqueryapi.ConnectionProperty if session != nil { // Add session ID to the connection properties for subsequent calls. 
- query.ConnectionProperties = []*bigqueryapi.ConnectionProperty{ + connProps = []*bigqueryapi.ConnectionProperty{ {Key: "session_id", Value: session.ID}, } } @@ -281,40 +267,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sql)) - // This block handles SELECT statements, which return a row set. - // We iterate through the results, convert each row into a map of - // column names to values, and return the collection of rows. - var out []any - job, err := query.Run(ctx) - if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - it, err := job.Read(ctx) - if err != nil { - return nil, fmt.Errorf("unable to read query results: %w", err) - } - for { - var row map[string]bigqueryapi.Value - err = it.Next(&row) - if err == iterator.Done { - break - } - if err != nil { - return nil, fmt.Errorf("unable to iterate through query results: %w", err) - } - vMap := make(map[string]any) - for key, value := range row { - vMap[key] = value - } - out = append(out, vMap) - } - // If the query returned any rows, return them directly. 
- if len(out) > 0 { - return out, nil - } - - // This handles the standard case for a SELECT query that successfully - return "The query returned 0 rows.", nil + return source.RunSQL(ctx, bqClient, sql, "SELECT", nil, connProps) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/bigquery/bigquerygetdatasetinfo/bigquerygetdatasetinfo.go b/internal/tools/bigquery/bigquerygetdatasetinfo/bigquerygetdatasetinfo.go index b083c49e2c..545850066e 100644 --- a/internal/tools/bigquery/bigquerygetdatasetinfo/bigquerygetdatasetinfo.go +++ b/internal/tools/bigquery/bigquerygetdatasetinfo/bigquerygetdatasetinfo.go @@ -21,10 +21,10 @@ import ( bigqueryapi "cloud.google.com/go/bigquery" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery" "github.com/googleapis/genai-toolbox/internal/tools" bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon" "github.com/googleapis/genai-toolbox/internal/util/parameters" + bigqueryrestapi "google.golang.org/api/bigquery/v2" ) const kind string = "bigquery-get-dataset-info" @@ -47,11 +47,10 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { BigQueryProject() string - BigQueryClient() *bigqueryapi.Client - BigQueryClientCreator() bigqueryds.BigqueryClientCreator UseClientAuthorization() bool IsDatasetAllowed(projectID, datasetID string) bool BigQueryAllowedDatasets() []string + RetrieveClientAndService(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) } type Config struct { @@ -138,18 +137,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", datasetKey) } - bqClient := source.BigQueryClient() - - // Initialize new 
client if using user OAuth token - if source.UseClientAuthorization() { - tokenStr, err := accessToken.ParseBearerToken() - if err != nil { - return nil, fmt.Errorf("error parsing access token: %w", err) - } - bqClient, _, err = source.BigQueryClientCreator()(tokenStr, false) - if err != nil { - return nil, fmt.Errorf("error creating client from OAuth access token: %w", err) - } + bqClient, _, err := source.RetrieveClientAndService(accessToken) + if err != nil { + return nil, err } if !source.IsDatasetAllowed(projectId, datasetId) { diff --git a/internal/tools/bigquery/bigquerygettableinfo/bigquerygettableinfo.go b/internal/tools/bigquery/bigquerygettableinfo/bigquerygettableinfo.go index b896244ed0..4cfc91e55b 100644 --- a/internal/tools/bigquery/bigquerygettableinfo/bigquerygettableinfo.go +++ b/internal/tools/bigquery/bigquerygettableinfo/bigquerygettableinfo.go @@ -21,10 +21,10 @@ import ( bigqueryapi "cloud.google.com/go/bigquery" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery" "github.com/googleapis/genai-toolbox/internal/tools" bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon" "github.com/googleapis/genai-toolbox/internal/util/parameters" + bigqueryrestapi "google.golang.org/api/bigquery/v2" ) const kind string = "bigquery-get-table-info" @@ -48,11 +48,10 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { BigQueryProject() string - BigQueryClient() *bigqueryapi.Client - BigQueryClientCreator() bigqueryds.BigqueryClientCreator UseClientAuthorization() bool IsDatasetAllowed(projectID, datasetID string) bool BigQueryAllowedDatasets() []string + RetrieveClientAndService(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) } type Config struct { @@ -151,18 +150,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr 
tools.SourceProvider, para return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId) } - bqClient := source.BigQueryClient() - - // Initialize new client if using user OAuth token - if source.UseClientAuthorization() { - tokenStr, err := accessToken.ParseBearerToken() - if err != nil { - return nil, fmt.Errorf("error parsing access token: %w", err) - } - bqClient, _, err = source.BigQueryClientCreator()(tokenStr, false) - if err != nil { - return nil, fmt.Errorf("error creating client from OAuth access token: %w", err) - } + bqClient, _, err := source.RetrieveClientAndService(accessToken) + if err != nil { + return nil, err } dsHandle := bqClient.DatasetInProject(projectId, datasetId) diff --git a/internal/tools/bigquery/bigquerylistdatasetids/bigquerylistdatasetids.go b/internal/tools/bigquery/bigquerylistdatasetids/bigquerylistdatasetids.go index dafe9b2246..93663f4f45 100644 --- a/internal/tools/bigquery/bigquerylistdatasetids/bigquerylistdatasetids.go +++ b/internal/tools/bigquery/bigquerylistdatasetids/bigquerylistdatasetids.go @@ -21,9 +21,9 @@ import ( bigqueryapi "cloud.google.com/go/bigquery" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" + bigqueryrestapi "google.golang.org/api/bigquery/v2" "google.golang.org/api/iterator" ) @@ -46,10 +46,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { BigQueryProject() string - BigQueryClient() *bigqueryapi.Client - BigQueryClientCreator() bigqueryds.BigqueryClientCreator UseClientAuthorization() bool BigQueryAllowedDatasets() []string + RetrieveClientAndService(tools.AccessToken) (*bigqueryapi.Client, 
*bigqueryrestapi.Service, error) } type Config struct { @@ -135,17 +134,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", projectKey) } - bqClient := source.BigQueryClient() - // Initialize new client if using user OAuth token - if source.UseClientAuthorization() { - tokenStr, err := accessToken.ParseBearerToken() - if err != nil { - return nil, fmt.Errorf("error parsing access token: %w", err) - } - bqClient, _, err = source.BigQueryClientCreator()(tokenStr, false) - if err != nil { - return nil, fmt.Errorf("error creating client from OAuth access token: %w", err) - } + bqClient, _, err := source.RetrieveClientAndService(accessToken) + if err != nil { + return nil, err } datasetIterator := bqClient.Datasets(ctx) datasetIterator.ProjectID = projectId diff --git a/internal/tools/bigquery/bigquerylisttableids/bigquerylisttableids.go b/internal/tools/bigquery/bigquerylisttableids/bigquerylisttableids.go index 11987c6dac..e3f609f522 100644 --- a/internal/tools/bigquery/bigquerylisttableids/bigquerylisttableids.go +++ b/internal/tools/bigquery/bigquerylisttableids/bigquerylisttableids.go @@ -21,10 +21,10 @@ import ( bigqueryapi "cloud.google.com/go/bigquery" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery" "github.com/googleapis/genai-toolbox/internal/tools" bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon" "github.com/googleapis/genai-toolbox/internal/util/parameters" + bigqueryrestapi "google.golang.org/api/bigquery/v2" "google.golang.org/api/iterator" ) @@ -47,12 +47,11 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T } type compatibleSource interface { - BigQueryClient() *bigqueryapi.Client - BigQueryClientCreator() bigqueryds.BigqueryClientCreator BigQueryProject() string 
UseClientAuthorization() bool IsDatasetAllowed(projectID, datasetID string) bool BigQueryAllowedDatasets() []string + RetrieveClientAndService(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) } type Config struct { @@ -145,17 +144,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId) } - bqClient := source.BigQueryClient() - // Initialize new client if using user OAuth token - if source.UseClientAuthorization() { - tokenStr, err := accessToken.ParseBearerToken() - if err != nil { - return nil, fmt.Errorf("error parsing access token: %w", err) - } - bqClient, _, err = source.BigQueryClientCreator()(tokenStr, false) - if err != nil { - return nil, fmt.Errorf("error creating client from OAuth access token: %w", err) - } + bqClient, _, err := source.RetrieveClientAndService(accessToken) + if err != nil { + return nil, err } dsHandle := bqClient.DatasetInProject(projectId, datasetId) diff --git a/internal/tools/bigquery/bigquerysql/bigquerysql.go b/internal/tools/bigquery/bigquerysql/bigquerysql.go index fa02f658eb..ff433a9ed5 100644 --- a/internal/tools/bigquery/bigquerysql/bigquerysql.go +++ b/internal/tools/bigquery/bigquerysql/bigquerysql.go @@ -23,13 +23,11 @@ import ( bigqueryapi "cloud.google.com/go/bigquery" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" - "github.com/googleapis/genai-toolbox/internal/util/parameters" - bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery" "github.com/googleapis/genai-toolbox/internal/tools" bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon" + "github.com/googleapis/genai-toolbox/internal/util/parameters" bigqueryrestapi "google.golang.org/api/bigquery/v2" - "google.golang.org/api/iterator" ) const kind string = "bigquery-sql" 
@@ -49,12 +47,10 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T } type compatibleSource interface { - BigQueryClient() *bigqueryapi.Client BigQuerySession() bigqueryds.BigQuerySessionProvider - BigQueryWriteMode() string - BigQueryRestService() *bigqueryrestapi.Service - BigQueryClientCreator() bigqueryds.BigqueryClientCreator UseClientAuthorization() bool + RetrieveClientAndService(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) + RunSQL(context.Context, *bigqueryapi.Client, string, string, []bigqueryapi.QueryParameter, []*bigqueryapi.ConnectionProperty) (any, error) } type Config struct { @@ -189,25 +185,6 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para lowLevelParams = append(lowLevelParams, lowLevelParam) } - bqClient := source.BigQueryClient() - restService := source.BigQueryRestService() - - // Initialize new client if using user OAuth token - if source.UseClientAuthorization() { - tokenStr, err := accessToken.ParseBearerToken() - if err != nil { - return nil, fmt.Errorf("error parsing access token: %w", err) - } - bqClient, restService, err = source.BigQueryClientCreator()(tokenStr, true) - if err != nil { - return nil, fmt.Errorf("error creating client from OAuth access token: %w", err) - } - } - - query := bqClient.Query(newStatement) - query.Parameters = highLevelParams - query.Location = bqClient.Location - connProps := []*bigqueryapi.ConnectionProperty{} if source.BigQuerySession() != nil { session, err := source.BigQuerySession()(ctx) @@ -219,57 +196,20 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para connProps = append(connProps, &bigqueryapi.ConnectionProperty{Key: "session_id", Value: session.ID}) } } - query.ConnectionProperties = connProps - dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), query.Location, newStatement, lowLevelParams, connProps) + + bqClient, restService, err := 
source.RetrieveClientAndService(accessToken) + if err != nil { + return nil, err + } + + dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, newStatement, lowLevelParams, connProps) if err != nil { return nil, fmt.Errorf("query validation failed: %w", err) } statementType := dryRunJob.Statistics.Query.StatementType - // This block handles SELECT statements, which return a row set. - // We iterate through the results, convert each row into a map of - // column names to values, and return the collection of rows. - job, err := query.Run(ctx) - if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - it, err := job.Read(ctx) - if err != nil { - return nil, fmt.Errorf("unable to read query results: %w", err) - } - - var out []any - for { - var row map[string]bigqueryapi.Value - err = it.Next(&row) - if err == iterator.Done { - break - } - if err != nil { - return nil, fmt.Errorf("unable to iterate through query results: %w", err) - } - vMap := make(map[string]any) - for key, value := range row { - vMap[key] = bqutil.NormalizeValue(value) - } - out = append(out, vMap) - } - // If the query returned any rows, return them directly. - if len(out) > 0 { - return out, nil - } - - // This handles the standard case for a SELECT query that successfully - // executes but returns zero rows. - if statementType == "SELECT" { - return "The query returned 0 rows.", nil - } - // This is the fallback for a successful query that doesn't return content. - // In most cases, this will be for DML/DDL statements like INSERT, UPDATE, CREATE, etc. - // However, it is also possible that this was a query that was expected to return rows - // but returned none, a case that we cannot distinguish here. 
- return "Query executed successfully and returned no content.", nil + return source.RunSQL(ctx, bqClient, newStatement, statementType, highLevelParams, connProps) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/bigtable/bigtable.go b/internal/tools/bigtable/bigtable.go index fe93630f95..f8b576b381 100644 --- a/internal/tools/bigtable/bigtable.go +++ b/internal/tools/bigtable/bigtable.go @@ -43,6 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { BigtableClient() *bigtable.Client + RunSQL(context.Context, string, parameters.Parameters, parameters.ParamValues) (any, error) } type Config struct { @@ -95,45 +96,6 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -func getBigtableType(paramType string) (bigtable.SQLType, error) { - switch paramType { - case "boolean": - return bigtable.BoolSQLType{}, nil - case "string": - return bigtable.StringSQLType{}, nil - case "integer": - return bigtable.Int64SQLType{}, nil - case "float": - return bigtable.Float64SQLType{}, nil - case "array": - return bigtable.ArraySQLType{}, nil - default: - return nil, fmt.Errorf("unknow param type %s", paramType) - } -} - -func getMapParamsType(tparams parameters.Parameters, params parameters.ParamValues) (map[string]bigtable.SQLType, error) { - btParamTypes := make(map[string]bigtable.SQLType) - for _, p := range tparams { - if p.GetType() == "array" { - itemType, err := getBigtableType(p.Manifest().Items.Type) - if err != nil { - return nil, err - } - btParamTypes[p.GetName()] = bigtable.ArraySQLType{ - ElemType: itemType, - } - continue - } - paramType, err := getBigtableType(p.GetType()) - if err != nil { - return nil, err - } - btParamTypes[p.GetName()] = paramType - } - return btParamTypes, nil -} - func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params 
parameters.ParamValues, accessToken tools.AccessToken) (any, error) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) if err != nil { @@ -150,46 +112,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if err != nil { return nil, fmt.Errorf("unable to extract standard params %w", err) } - - mapParamsType, err := getMapParamsType(t.Parameters, newParams) - if err != nil { - return nil, fmt.Errorf("fail to get map params: %w", err) - } - - ps, err := source.BigtableClient().PrepareStatement( - ctx, - newStatement, - mapParamsType, - ) - if err != nil { - return nil, fmt.Errorf("unable to prepare statement: %w", err) - } - - bs, err := ps.Bind(newParams.AsMap()) - if err != nil { - return nil, fmt.Errorf("unable to bind: %w", err) - } - - var out []any - err = bs.Execute(ctx, func(resultRow bigtable.ResultRow) bool { - vMap := make(map[string]any) - cols := resultRow.Metadata.Columns - - for _, c := range cols { - var columValue any - err = resultRow.GetByName(c.Name, &columValue) - vMap[c.Name] = columValue - } - - out = append(out, vMap) - - return true - }) - if err != nil { - return nil, fmt.Errorf("unable to execute client: %w", err) - } - - return out, nil + return source.RunSQL(ctx, newStatement, t.Parameters, newParams) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/cassandra/cassandracql/cassandracql.go b/internal/tools/cassandra/cassandracql/cassandracql.go index a05d0815ba..b0a95c4db1 100644 --- a/internal/tools/cassandra/cassandracql/cassandracql.go +++ b/internal/tools/cassandra/cassandracql/cassandracql.go @@ -43,6 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { CassandraSession() *gocql.Session + RunSQL(context.Context, string, parameters.ParamValues) (any, error) } type Config struct { @@ -121,25 
+122,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if err != nil { return nil, fmt.Errorf("unable to extract standard params %w", err) } - sliceParams := newParams.AsSlice() - iter := source.CassandraSession().Query(newStatement, sliceParams...).IterContext(ctx) - - // Create a slice to store the out - var out []map[string]interface{} - - // Scan results into a map and append to the slice - for { - row := make(map[string]interface{}) // Create a new map for each row - if !iter.MapScan(row) { - break // No more rows - } - out = append(out, row) - } - - if err := iter.Close(); err != nil { - return nil, fmt.Errorf("unable to parse rows: %w", err) - } - return out, nil + return source.RunSQL(ctx, newStatement, newParams) } // Manifest implements tools.Tool. diff --git a/internal/tools/clickhouse/clickhouseexecutesql/clickhouseexecutesql.go b/internal/tools/clickhouse/clickhouseexecutesql/clickhouseexecutesql.go index 826d20d482..6ea0f89759 100644 --- a/internal/tools/clickhouse/clickhouseexecutesql/clickhouseexecutesql.go +++ b/internal/tools/clickhouse/clickhouseexecutesql/clickhouseexecutesql.go @@ -16,7 +16,6 @@ package clickhouse import ( "context" - "database/sql" "fmt" yaml "github.com/goccy/go-yaml" @@ -42,7 +41,7 @@ func newExecuteSQLConfig(ctx context.Context, name string, decoder *yaml.Decoder } type compatibleSource interface { - ClickHousePool() *sql.DB + RunSQL(context.Context, string, parameters.ParamValues) (any, error) } type Config struct { @@ -98,63 +97,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if !ok { return nil, fmt.Errorf("unable to cast sql parameter %s", paramsMap["sql"]) } - - results, err := source.ClickHousePool().QueryContext(ctx, sql) - if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - defer results.Close() - - cols, err := results.Columns() - if err != nil { - return nil, fmt.Errorf("unable to retrieve rows column name: 
%w", err) - } - - // create an array of values for each column, which can be re-used to scan each row - rawValues := make([]any, len(cols)) - values := make([]any, len(cols)) - for i := range rawValues { - values[i] = &rawValues[i] - } - - colTypes, err := results.ColumnTypes() - if err != nil { - return nil, fmt.Errorf("unable to get column types: %w", err) - } - - var out []any - for results.Next() { - err := results.Scan(values...) - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - vMap := make(map[string]any) - for i, name := range cols { - // ClickHouse driver may return specific types that need handling - switch colTypes[i].DatabaseTypeName() { - case "String", "FixedString": - if rawValues[i] != nil { - // Handle potential []byte to string conversion if needed - if b, ok := rawValues[i].([]byte); ok { - vMap[name] = string(b) - } else { - vMap[name] = rawValues[i] - } - } else { - vMap[name] = nil - } - default: - vMap[name] = rawValues[i] - } - } - out = append(out, vMap) - } - - if err := results.Err(); err != nil { - return nil, fmt.Errorf("errors encountered by results.Scan: %w", err) - } - - return out, nil + return source.RunSQL(ctx, sql, nil) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/clickhouse/clickhouselistdatabases/clickhouselistdatabases.go b/internal/tools/clickhouse/clickhouselistdatabases/clickhouselistdatabases.go index e6df548907..daeab033ed 100644 --- a/internal/tools/clickhouse/clickhouselistdatabases/clickhouselistdatabases.go +++ b/internal/tools/clickhouse/clickhouselistdatabases/clickhouselistdatabases.go @@ -16,7 +16,6 @@ package clickhouse import ( "context" - "database/sql" "fmt" yaml "github.com/goccy/go-yaml" @@ -42,7 +41,7 @@ func newListDatabasesConfig(ctx context.Context, name string, decoder *yaml.Deco } type compatibleSource interface { - ClickHousePool() *sql.DB + RunSQL(context.Context, string, 
parameters.ParamValues) (any, error) } type Config struct { @@ -95,29 +94,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para // Query to list all databases query := "SHOW DATABASES" - results, err := source.ClickHousePool().QueryContext(ctx, query) + out, err := source.RunSQL(ctx, query, nil) if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - defer results.Close() - - var databases []map[string]any - for results.Next() { - var dbName string - err := results.Scan(&dbName) - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - databases = append(databases, map[string]any{ - "name": dbName, - }) + return nil, err } - if err := results.Err(); err != nil { - return nil, fmt.Errorf("errors encountered by results.Scan: %w", err) - } - - return databases, nil + return out, nil } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/clickhouse/clickhouselisttables/clickhouselisttables.go b/internal/tools/clickhouse/clickhouselisttables/clickhouselisttables.go index e882a88ea5..2e2da1a02d 100644 --- a/internal/tools/clickhouse/clickhouselisttables/clickhouselisttables.go +++ b/internal/tools/clickhouse/clickhouselisttables/clickhouselisttables.go @@ -16,7 +16,6 @@ package clickhouse import ( "context" - "database/sql" "fmt" yaml "github.com/goccy/go-yaml" @@ -43,7 +42,7 @@ func newListTablesConfig(ctx context.Context, name string, decoder *yaml.Decoder } type compatibleSource interface { - ClickHousePool() *sql.DB + RunSQL(context.Context, string, parameters.ParamValues) (any, error) } type Config struct { @@ -101,33 +100,27 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if !ok { return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", databaseKey) } - // Query to list all tables in the specified database query := fmt.Sprintf("SHOW 
TABLES FROM %s", database) - results, err := source.ClickHousePool().QueryContext(ctx, query) + out, err := source.RunSQL(ctx, query, nil) if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) + return nil, err } - defer results.Close() - tables := []map[string]any{} - for results.Next() { - var tableName string - err := results.Scan(&tableName) - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) + res, ok := out.([]any) + if !ok { + return nil, fmt.Errorf("unable to convert result to list") + } + var tables []map[string]any + for _, item := range res { + tableMap, ok := item.(map[string]any) + if !ok { + return nil, fmt.Errorf("unexpected type in result: got %T, want map[string]any", item) } - tables = append(tables, map[string]any{ - "name": tableName, - "database": database, - }) + tableMap["database"] = database + tables = append(tables, tableMap) } - - if err := results.Err(); err != nil { - return nil, fmt.Errorf("errors encountered by results.Scan: %w", err) - } - return tables, nil } diff --git a/internal/tools/clickhouse/clickhousesql/clickhousesql.go b/internal/tools/clickhouse/clickhousesql/clickhousesql.go index 83a2f1ee9d..d48825439a 100644 --- a/internal/tools/clickhouse/clickhousesql/clickhousesql.go +++ b/internal/tools/clickhouse/clickhousesql/clickhousesql.go @@ -16,7 +16,6 @@ package clickhouse import ( "context" - "database/sql" "fmt" yaml "github.com/goccy/go-yaml" @@ -42,7 +41,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T } type compatibleSource interface { - ClickHousePool() *sql.DB + RunSQL(context.Context, string, parameters.ParamValues) (any, error) } type Config struct { @@ -105,65 +104,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("unable to extract standard params: %w", err) } - sliceParams := newParams.AsSlice() - results, err := source.ClickHousePool().QueryContext(ctx, newStatement, 
sliceParams...) - if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - - cols, err := results.Columns() - if err != nil { - return nil, fmt.Errorf("unable to retrieve rows column name: %w", err) - } - - rawValues := make([]any, len(cols)) - values := make([]any, len(cols)) - for i := range rawValues { - values[i] = &rawValues[i] - } - - colTypes, err := results.ColumnTypes() - if err != nil { - return nil, fmt.Errorf("unable to get column types: %w", err) - } - - var out []any - for results.Next() { - err := results.Scan(values...) - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - vMap := make(map[string]any) - for i, name := range cols { - switch colTypes[i].DatabaseTypeName() { - case "String", "FixedString": - if rawValues[i] != nil { - // Handle potential []byte to string conversion if needed - if b, ok := rawValues[i].([]byte); ok { - vMap[name] = string(b) - } else { - vMap[name] = rawValues[i] - } - } else { - vMap[name] = nil - } - default: - vMap[name] = rawValues[i] - } - } - out = append(out, vMap) - } - - err = results.Close() - if err != nil { - return nil, fmt.Errorf("unable to close rows: %w", err) - } - - if err := results.Err(); err != nil { - return nil, fmt.Errorf("errors encountered by results.Scan: %w", err) - } - - return out, nil + return source.RunSQL(ctx, newStatement, newParams) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/cloudgda/cloudgda.go b/internal/tools/cloudgda/cloudgda.go index bf54c26c3f..f8c9c1ea22 100644 --- a/internal/tools/cloudgda/cloudgda.go +++ b/internal/tools/cloudgda/cloudgda.go @@ -15,12 +15,9 @@ package cloudgda import ( - "bytes" "context" "encoding/json" "fmt" - "io" - "net/http" "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" @@ -46,9 +43,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T 
type compatibleSource interface { GetProjectID() string - GetBaseURL() string UseClientAuthorization() bool - GetClient(context.Context, string) (*http.Client, error) + RunQuery(context.Context, string, []byte) (any, error) } type Config struct { @@ -113,10 +109,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("prompt parameter not found or not a string") } - // The API endpoint itself always uses the "global" location. - apiLocation := "global" - apiParent := fmt.Sprintf("projects/%s/locations/%s", source.GetProjectID(), apiLocation) - apiURL := fmt.Sprintf("%s/v1beta/%s:queryData", source.GetBaseURL(), apiParent) + // Parse the access token if provided + var tokenStr string + if source.UseClientAuthorization() { + var err error + tokenStr, err = accessToken.ParseBearerToken() + if err != nil { + return nil, fmt.Errorf("error parsing access token: %w", err) + } + } // The parent in the request payload uses the tool's configured location. 
payloadParent := fmt.Sprintf("projects/%s/locations/%s", source.GetProjectID(), t.Location) @@ -132,49 +133,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if err != nil { return nil, fmt.Errorf("failed to marshal request payload: %w", err) } - - // Parse the access token if provided - var tokenStr string - if source.UseClientAuthorization() { - var err error - tokenStr, err = accessToken.ParseBearerToken() - if err != nil { - return nil, fmt.Errorf("error parsing access token: %w", err) - } - } - - client, err := source.GetClient(ctx, tokenStr) - if err != nil { - return nil, fmt.Errorf("failed to get HTTP client: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(bodyBytes)) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - req.Header.Set("Content-Type", "application/json") - - resp, err := client.Do(req) - if err != nil { - return nil, fmt.Errorf("failed to execute request: %w", err) - } - defer resp.Body.Close() - - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) - } - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(respBody)) - } - - var result map[string]any - if err := json.Unmarshal(respBody, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) - } - - return result, nil + return source.RunQuery(ctx, tokenStr, bodyBytes) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/cloudmonitoring/cloudmonitoring.go b/internal/tools/cloudmonitoring/cloudmonitoring.go index 54c19f6774..1f3a6127ea 100644 --- a/internal/tools/cloudmonitoring/cloudmonitoring.go +++ b/internal/tools/cloudmonitoring/cloudmonitoring.go @@ -16,9 +16,7 @@ package cloudmonitoring import ( 
"context" - "encoding/json" "fmt" - "io" "net/http" "github.com/goccy/go-yaml" @@ -44,9 +42,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T } type compatibleSource interface { - BaseURL() string Client() *http.Client - UserAgent() string + RunQuery(projectID, query string) (any, error) } type Config struct { @@ -110,45 +107,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if !ok { return nil, fmt.Errorf("query parameter not found or not a string") } - - url := fmt.Sprintf("%s/v1/projects/%s/location/global/prometheus/api/v1/query", source.BaseURL(), projectID) - - req, err := http.NewRequest(http.MethodGet, url, nil) - if err != nil { - return nil, err - } - - q := req.URL.Query() - q.Add("query", query) - req.URL.RawQuery = q.Encode() - - req.Header.Set("User-Agent", source.UserAgent()) - - resp, err := source.Client().Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) - } - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("request failed: %s, body: %s", resp.Status, string(body)) - } - - if len(body) == 0 { - return nil, nil - } - - var result map[string]any - if err := json.Unmarshal(body, &result); err != nil { - return nil, fmt.Errorf("failed to unmarshal json: %w, body: %s", err, string(body)) - } - - return result, nil + return source.RunQuery(projectID, query) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/cloudsql/cloudsqlcloneinstance/cloudsqlcloneinstance.go b/internal/tools/cloudsql/cloudsqlcloneinstance/cloudsqlcloneinstance.go index e8f7431f8b..29516fd8b1 100644 --- a/internal/tools/cloudsql/cloudsqlcloneinstance/cloudsqlcloneinstance.go +++ b/internal/tools/cloudsql/cloudsqlcloneinstance/cloudsqlcloneinstance.go @@ -45,6 
+45,7 @@ type compatibleSource interface { GetDefaultProject() string GetService(context.Context, string) (*sqladmin.Service, error) UseClientAuthorization() bool + CloneInstance(context.Context, string, string, string, string, string, string, string) (any, error) } // Config defines the configuration for the clone-instance tool. @@ -142,38 +143,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("error casting 'destinationInstanceName' parameter: %v", paramsMap["destinationInstanceName"]) } - cloneContext := &sqladmin.CloneContext{ - DestinationInstanceName: destinationInstanceName, - } + pointInTime, _ := paramsMap["pointInTime"].(string) + preferredZone, _ := paramsMap["preferredZone"].(string) + preferredSecondaryZone, _ := paramsMap["preferredSecondaryZone"].(string) - pointInTime, ok := paramsMap["pointInTime"].(string) - if ok { - cloneContext.PointInTime = pointInTime - } - preferredZone, ok := paramsMap["preferredZone"].(string) - if ok { - cloneContext.PreferredZone = preferredZone - } - preferredSecondaryZone, ok := paramsMap["preferredSecondaryZone"].(string) - if ok { - cloneContext.PreferredSecondaryZone = preferredSecondaryZone - } - - rb := &sqladmin.InstancesCloneRequest{ - CloneContext: cloneContext, - } - - service, err := source.GetService(ctx, string(accessToken)) - if err != nil { - return nil, err - } - - resp, err := service.Instances.Clone(project, sourceInstanceName, rb).Do() - if err != nil { - return nil, fmt.Errorf("error cloning instance: %w", err) - } - - return resp, nil + return source.CloneInstance(ctx, project, sourceInstanceName, destinationInstanceName, pointInTime, preferredZone, preferredSecondaryZone, string(accessToken)) } // ParseParams parses the parameters for the tool. 
diff --git a/internal/tools/cloudsql/cloudsqlcreatedatabase/cloudsqlcreatedatabase.go b/internal/tools/cloudsql/cloudsqlcreatedatabase/cloudsqlcreatedatabase.go index 57b4cc06d6..1cbc62db24 100644 --- a/internal/tools/cloudsql/cloudsqlcreatedatabase/cloudsqlcreatedatabase.go +++ b/internal/tools/cloudsql/cloudsqlcreatedatabase/cloudsqlcreatedatabase.go @@ -22,7 +22,6 @@ import ( "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" - sqladmin "google.golang.org/api/sqladmin/v1" ) const kind string = "cloud-sql-create-database" @@ -43,8 +42,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { GetDefaultProject() string - GetService(context.Context, string) (*sqladmin.Service, error) UseClientAuthorization() bool + CreateDatabase(context.Context, string, string, string, string) (any, error) } // Config defines the configuration for the create-database tool. @@ -137,24 +136,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if !ok { return nil, fmt.Errorf("missing 'name' parameter") } - - database := sqladmin.Database{ - Name: name, - Project: project, - Instance: instance, - } - - service, err := source.GetService(ctx, string(accessToken)) - if err != nil { - return nil, err - } - - resp, err := service.Databases.Insert(project, instance, &database).Do() - if err != nil { - return nil, fmt.Errorf("error creating database: %w", err) - } - - return resp, nil + return source.CreateDatabase(ctx, name, project, instance, string(accessToken)) } // ParseParams parses the parameters for the tool. 
diff --git a/internal/tools/cloudsql/cloudsqlcreateusers/cloudsqlcreateusers.go b/internal/tools/cloudsql/cloudsqlcreateusers/cloudsqlcreateusers.go index 148ccfeb6c..c07c116194 100644 --- a/internal/tools/cloudsql/cloudsqlcreateusers/cloudsqlcreateusers.go +++ b/internal/tools/cloudsql/cloudsqlcreateusers/cloudsqlcreateusers.go @@ -22,7 +22,6 @@ import ( "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" - sqladmin "google.golang.org/api/sqladmin/v1" ) const kind string = "cloud-sql-create-users" @@ -43,8 +42,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { GetDefaultProject() string - GetService(context.Context, string) (*sqladmin.Service, error) UseClientAuthorization() bool + CreateUsers(context.Context, string, string, string, string, bool, string) (any, error) } // Config defines the configuration for the create-user tool. @@ -141,33 +140,8 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } iamUser, _ := paramsMap["iamUser"].(bool) - - user := sqladmin.User{ - Name: name, - } - - if iamUser { - user.Type = "CLOUD_IAM_USER" - } else { - user.Type = "BUILT_IN" - password, ok := paramsMap["password"].(string) - if !ok || password == "" { - return nil, fmt.Errorf("missing 'password' parameter for non-IAM user") - } - user.Password = password - } - - service, err := source.GetService(ctx, string(accessToken)) - if err != nil { - return nil, err - } - - resp, err := service.Users.Insert(project, instance, &user).Do() - if err != nil { - return nil, fmt.Errorf("error creating user: %w", err) - } - - return resp, nil + password, _ := paramsMap["password"].(string) + return source.CreateUsers(ctx, project, instance, name, password, iamUser, string(accessToken)) } // ParseParams parses the parameters for the tool. 
diff --git a/internal/tools/cloudsql/cloudsqlgetinstances/cloudsqlgetinstances.go b/internal/tools/cloudsql/cloudsqlgetinstances/cloudsqlgetinstances.go index 1fb40b67bc..e41b52ed03 100644 --- a/internal/tools/cloudsql/cloudsqlgetinstances/cloudsqlgetinstances.go +++ b/internal/tools/cloudsql/cloudsqlgetinstances/cloudsqlgetinstances.go @@ -22,7 +22,6 @@ import ( "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" - "google.golang.org/api/sqladmin/v1" ) const kind string = "cloud-sql-get-instance" @@ -43,8 +42,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { GetDefaultProject() string - GetService(context.Context, string) (*sqladmin.Service, error) UseClientAuthorization() bool + GetInstance(context.Context, string, string, string) (any, error) } // Config defines the configuration for the get-instances tool. @@ -133,18 +132,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if !ok { return nil, fmt.Errorf("missing 'instanceId' parameter") } - - service, err := source.GetService(ctx, string(accessToken)) - if err != nil { - return nil, err - } - - resp, err := service.Instances.Get(projectId, instanceId).Do() - if err != nil { - return nil, fmt.Errorf("error getting instance: %w", err) - } - - return resp, nil + return source.GetInstance(ctx, projectId, instanceId, string(accessToken)) } // ParseParams parses the parameters for the tool. 
diff --git a/internal/tools/cloudsql/cloudsqllistdatabases/cloudsqllistdatabases.go b/internal/tools/cloudsql/cloudsqllistdatabases/cloudsqllistdatabases.go index ba54380631..a04da5dce5 100644 --- a/internal/tools/cloudsql/cloudsqllistdatabases/cloudsqllistdatabases.go +++ b/internal/tools/cloudsql/cloudsqllistdatabases/cloudsqllistdatabases.go @@ -22,7 +22,6 @@ import ( "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" - "google.golang.org/api/sqladmin/v1" ) const kind string = "cloud-sql-list-databases" @@ -43,8 +42,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { GetDefaultProject() string - GetService(context.Context, string) (*sqladmin.Service, error) UseClientAuthorization() bool + ListDatabase(context.Context, string, string, string) (any, error) } // Config defines the configuration for the list-databases tool. @@ -132,37 +131,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if !ok { return nil, fmt.Errorf("missing 'instance' parameter") } - - service, err := source.GetService(ctx, string(accessToken)) - if err != nil { - return nil, err - } - - resp, err := service.Databases.List(project, instance).Do() - if err != nil { - return nil, fmt.Errorf("error listing databases: %w", err) - } - - if resp.Items == nil { - return []any{}, nil - } - - type databaseInfo struct { - Name string `json:"name"` - Charset string `json:"charset"` - Collation string `json:"collation"` - } - - var databases []databaseInfo - for _, item := range resp.Items { - databases = append(databases, databaseInfo{ - Name: item.Name, - Charset: item.Charset, - Collation: item.Collation, - }) - } - - return databases, nil + return source.ListDatabase(ctx, project, instance, string(accessToken)) } // ParseParams parses the parameters for the tool. 
diff --git a/internal/tools/cloudsql/cloudsqllistinstances/cloudsqllistinstances.go b/internal/tools/cloudsql/cloudsqllistinstances/cloudsqllistinstances.go index 11ccd91bad..dc2cc5b8af 100644 --- a/internal/tools/cloudsql/cloudsqllistinstances/cloudsqllistinstances.go +++ b/internal/tools/cloudsql/cloudsqllistinstances/cloudsqllistinstances.go @@ -22,7 +22,6 @@ import ( "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" - "google.golang.org/api/sqladmin/v1" ) const kind string = "cloud-sql-list-instances" @@ -43,8 +42,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { GetDefaultProject() string - GetService(context.Context, string) (*sqladmin.Service, error) UseClientAuthorization() bool + ListInstance(context.Context, string, string) (any, error) } // Config defines the configuration for the list-instance tool. @@ -127,35 +126,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if !ok { return nil, fmt.Errorf("missing 'project' parameter") } - - service, err := source.GetService(ctx, string(accessToken)) - if err != nil { - return nil, err - } - - resp, err := service.Instances.List(project).Do() - if err != nil { - return nil, fmt.Errorf("error listing instances: %w", err) - } - - if resp.Items == nil { - return []any{}, nil - } - - type instanceInfo struct { - Name string `json:"name"` - InstanceType string `json:"instanceType"` - } - - var instances []instanceInfo - for _, item := range resp.Items { - instances = append(instances, instanceInfo{ - Name: item.Name, - InstanceType: item.InstanceType, - }) - } - - return instances, nil + return source.ListInstance(ctx, project, string(accessToken)) } // ParseParams parses the parameters for the tool. 
diff --git a/internal/tools/cloudsql/cloudsqlwaitforoperation/cloudsqlwaitforoperation.go b/internal/tools/cloudsql/cloudsqlwaitforoperation/cloudsqlwaitforoperation.go index 672f999282..2a7472cf93 100644 --- a/internal/tools/cloudsql/cloudsqlwaitforoperation/cloudsqlwaitforoperation.go +++ b/internal/tools/cloudsql/cloudsqlwaitforoperation/cloudsqlwaitforoperation.go @@ -16,11 +16,7 @@ package cloudsqlwaitforoperation import ( "context" - "encoding/json" "fmt" - "regexp" - "strings" - "text/template" "time" yaml "github.com/goccy/go-yaml" @@ -91,6 +87,7 @@ type compatibleSource interface { GetDefaultProject() string GetService(context.Context, string) (*sqladmin.Service, error) UseClientAuthorization() bool + GetWaitForOperations(context.Context, *sqladmin.Service, string, string, string, time.Duration) (any, error) } // Config defines the configuration for the wait-for-operation tool. @@ -229,14 +226,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("missing 'operation' parameter") } + ctx, cancel := context.WithTimeout(ctx, 30*time.Minute) + defer cancel() + service, err := source.GetService(ctx, string(accessToken)) if err != nil { return nil, err } - ctx, cancel := context.WithTimeout(ctx, 30*time.Minute) - defer cancel() - delay := t.Delay maxDelay := t.MaxDelay multiplier := t.Multiplier @@ -250,37 +247,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para default: } - op, err := service.Operations.Get(project, operationID).Do() + op, err := source.GetWaitForOperations(ctx, service, project, operationID, cloudSQLConnectionMessageTemplate, delay) if err != nil { - fmt.Printf("error getting operation: %s, retrying in %v\n", err, delay) - } else { - if op.Status == "DONE" { - if op.Error != nil { - var errorBytes []byte - errorBytes, err = json.Marshal(op.Error) - if err != nil { - return nil, fmt.Errorf("operation finished with error but could not marshal error object: 
%w", err) - } - return nil, fmt.Errorf("operation finished with error: %s", string(errorBytes)) - } - - var opBytes []byte - opBytes, err = op.MarshalJSON() - if err != nil { - return nil, fmt.Errorf("could not marshal operation: %w", err) - } - - var data map[string]any - if err := json.Unmarshal(opBytes, &data); err != nil { - return nil, fmt.Errorf("could not unmarshal operation: %w", err) - } - - if msg, ok := t.generateCloudSQLConnectionMessage(source, data); ok { - return msg, nil - } - return string(opBytes), nil - } - fmt.Printf("Operation not complete, retrying in %v\n", delay) + return nil, err + } else if op != nil { + return op, nil } time.Sleep(delay) @@ -321,105 +292,6 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo return source.UseClientAuthorization(), nil } -func (t Tool) generateCloudSQLConnectionMessage(source compatibleSource, opResponse map[string]any) (string, bool) { - operationType, ok := opResponse["operationType"].(string) - if !ok || operationType != "CREATE_DATABASE" { - return "", false - } - - targetLink, ok := opResponse["targetLink"].(string) - if !ok { - return "", false - } - - r := regexp.MustCompile(`/projects/([^/]+)/instances/([^/]+)/databases/([^/]+)`) - matches := r.FindStringSubmatch(targetLink) - if len(matches) < 4 { - return "", false - } - project := matches[1] - instance := matches[2] - database := matches[3] - - instanceData, err := t.fetchInstanceData(context.Background(), source, project, instance) - if err != nil { - fmt.Printf("error fetching instance data: %v\n", err) - return "", false - } - - region, ok := instanceData["region"].(string) - if !ok { - return "", false - } - - databaseVersion, ok := instanceData["databaseVersion"].(string) - if !ok { - return "", false - } - - var dbType string - if strings.Contains(databaseVersion, "POSTGRES") { - dbType = "postgres" - } else if strings.Contains(databaseVersion, "MYSQL") { - dbType = "mysql" - } else if 
strings.Contains(databaseVersion, "SQLSERVER") { - dbType = "mssql" - } else { - return "", false - } - - tmpl, err := template.New("cloud-sql-connection").Parse(cloudSQLConnectionMessageTemplate) - if err != nil { - return fmt.Sprintf("template parsing error: %v", err), false - } - - data := struct { - Project string - Region string - Instance string - DBType string - DBTypeUpper string - Database string - }{ - Project: project, - Region: region, - Instance: instance, - DBType: dbType, - DBTypeUpper: strings.ToUpper(dbType), - Database: database, - } - - var b strings.Builder - if err := tmpl.Execute(&b, data); err != nil { - return fmt.Sprintf("template execution error: %v", err), false - } - - return b.String(), true -} - -func (t Tool) fetchInstanceData(ctx context.Context, source compatibleSource, project, instance string) (map[string]any, error) { - service, err := source.GetService(ctx, "") - if err != nil { - return nil, err - } - - resp, err := service.Instances.Get(project, instance).Do() - if err != nil { - return nil, fmt.Errorf("error getting instance: %w", err) - } - - var data map[string]any - var b []byte - b, err = resp.MarshalJSON() - if err != nil { - return nil, fmt.Errorf("error marshalling response: %w", err) - } - if err := json.Unmarshal(b, &data); err != nil { - return nil, fmt.Errorf("error unmarshalling response body: %w", err) - } - return data, nil -} - func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { return "Authorization", nil } diff --git a/internal/tools/cloudsqlmssql/cloudsqlmssqlcreateinstance/cloudsqlmssqlcreateinstance.go b/internal/tools/cloudsqlmssql/cloudsqlmssqlcreateinstance/cloudsqlmssqlcreateinstance.go index 78bc77d6fa..24ac142dd1 100644 --- a/internal/tools/cloudsqlmssql/cloudsqlmssqlcreateinstance/cloudsqlmssqlcreateinstance.go +++ b/internal/tools/cloudsqlmssql/cloudsqlmssqlcreateinstance/cloudsqlmssqlcreateinstance.go @@ -23,7 +23,7 @@ import ( 
"github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" - sqladmin "google.golang.org/api/sqladmin/v1" + "google.golang.org/api/sqladmin/v1" ) const kind string = "cloud-sql-mssql-create-instance" @@ -44,8 +44,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { GetDefaultProject() string - GetService(context.Context, string) (*sqladmin.Service, error) UseClientAuthorization() bool + CreateInstance(context.Context, string, string, string, string, sqladmin.Settings, string) (any, error) } // Config defines the configuration for the create-instances tool. @@ -148,7 +148,6 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if !ok { return nil, fmt.Errorf("error casting 'editionPreset' parameter: %s", paramsMap["editionPreset"]) } - settings := sqladmin.Settings{} switch strings.ToLower(editionPreset) { case "production": @@ -166,26 +165,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para default: return nil, fmt.Errorf("invalid 'editionPreset': %q. Must be either 'Production' or 'Development'", editionPreset) } - - instance := sqladmin.DatabaseInstance{ - Name: name, - DatabaseVersion: dbVersion, - RootPassword: rootPassword, - Settings: &settings, - Project: project, - } - - service, err := source.GetService(ctx, string(accessToken)) - if err != nil { - return nil, err - } - - resp, err := service.Instances.Insert(project, &instance).Do() - if err != nil { - return nil, fmt.Errorf("error creating instance: %w", err) - } - - return resp, nil + return source.CreateInstance(ctx, project, name, dbVersion, rootPassword, settings, string(accessToken)) } // ParseParams parses the parameters for the tool. 
diff --git a/internal/tools/cloudsqlmysql/cloudsqlmysqlcreateinstance/cloudsqlmysqlcreateinstance.go b/internal/tools/cloudsqlmysql/cloudsqlmysqlcreateinstance/cloudsqlmysqlcreateinstance.go index 165a057c35..c23926229e 100644 --- a/internal/tools/cloudsqlmysql/cloudsqlmysqlcreateinstance/cloudsqlmysqlcreateinstance.go +++ b/internal/tools/cloudsqlmysql/cloudsqlmysqlcreateinstance/cloudsqlmysqlcreateinstance.go @@ -44,8 +44,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { GetDefaultProject() string - GetService(context.Context, string) (*sqladmin.Service, error) UseClientAuthorization() bool + CreateInstance(context.Context, string, string, string, string, sqladmin.Settings, string) (any, error) } // Config defines the configuration for the create-instances tool. @@ -167,25 +167,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("invalid 'editionPreset': %q. Must be either 'Production' or 'Development'", editionPreset) } - instance := sqladmin.DatabaseInstance{ - Name: name, - DatabaseVersion: dbVersion, - RootPassword: rootPassword, - Settings: &settings, - Project: project, - } - - service, err := source.GetService(ctx, string(accessToken)) - if err != nil { - return nil, err - } - - resp, err := service.Instances.Insert(project, &instance).Do() - if err != nil { - return nil, fmt.Errorf("error creating instance: %w", err) - } - - return resp, nil + return source.CreateInstance(ctx, project, name, dbVersion, rootPassword, settings, string(accessToken)) } // ParseParams parses the parameters for the tool. 
diff --git a/internal/tools/cloudsqlpg/cloudsqlpgcreateinstances/cloudsqlpgcreateinstances.go b/internal/tools/cloudsqlpg/cloudsqlpgcreateinstances/cloudsqlpgcreateinstances.go index 224cc3700c..0248a2e6c9 100644 --- a/internal/tools/cloudsqlpg/cloudsqlpgcreateinstances/cloudsqlpgcreateinstances.go +++ b/internal/tools/cloudsqlpg/cloudsqlpgcreateinstances/cloudsqlpgcreateinstances.go @@ -44,8 +44,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { GetDefaultProject() string - GetService(context.Context, string) (*sqladmin.Service, error) UseClientAuthorization() bool + CreateInstance(context.Context, string, string, string, string, sqladmin.Settings, string) (any, error) } // Config defines the configuration for the create-instances tool. @@ -166,26 +166,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para default: return nil, fmt.Errorf("invalid 'editionPreset': %q. Must be either 'Production' or 'Development'", editionPreset) } - - instance := sqladmin.DatabaseInstance{ - Name: name, - DatabaseVersion: dbVersion, - RootPassword: rootPassword, - Settings: &settings, - Project: project, - } - - service, err := source.GetService(ctx, string(accessToken)) - if err != nil { - return nil, err - } - - resp, err := service.Instances.Insert(project, &instance).Do() - if err != nil { - return nil, fmt.Errorf("error creating instance: %w", err) - } - - return resp, nil + return source.CreateInstance(ctx, project, name, dbVersion, rootPassword, settings, string(accessToken)) } // ParseParams parses the parameters for the tool. 
diff --git a/internal/tools/couchbase/couchbase.go b/internal/tools/couchbase/couchbase.go index 481c9f6b22..a04a3b0aee 100644 --- a/internal/tools/couchbase/couchbase.go +++ b/internal/tools/couchbase/couchbase.go @@ -16,7 +16,6 @@ package couchbase import ( "context" - "encoding/json" "fmt" "github.com/couchbase/gocb/v2" @@ -44,7 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { CouchbaseScope() *gocb.Scope - CouchbaseQueryScanConsistency() uint + RunSQL(string, parameters.ParamValues) (any, error) } type Config struct { @@ -112,24 +111,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if err != nil { return nil, fmt.Errorf("unable to extract standard params %w", err) } - results, err := source.CouchbaseScope().Query(newStatement, &gocb.QueryOptions{ - ScanConsistency: gocb.QueryScanConsistency(source.CouchbaseQueryScanConsistency()), - NamedParameters: newParams.AsMap(), - }) - if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - - var out []any - for results.Next() { - var result json.RawMessage - err := results.Row(&result) - if err != nil { - return nil, fmt.Errorf("error processing row: %w", err) - } - out = append(out, result) - } - return out, nil + return source.RunSQL(newStatement, newParams) } func (t Tool) ParseParams(data map[string]any, claimsMap map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/dgraph/dgraph.go b/internal/tools/dgraph/dgraph.go index beef9f86a5..2ed68fe209 100644 --- a/internal/tools/dgraph/dgraph.go +++ b/internal/tools/dgraph/dgraph.go @@ -16,7 +16,6 @@ package dgraph import ( "context" - "encoding/json" "fmt" yaml "github.com/goccy/go-yaml" @@ -44,6 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { DgraphClient() *dgraph.DgraphClient + RunSQL(string, parameters.ParamValues, bool, 
string) (any, error) } type Config struct { @@ -95,27 +95,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if err != nil { return nil, err } - - paramsMap := params.AsMapWithDollarPrefix() - - resp, err := source.DgraphClient().ExecuteQuery(t.Statement, paramsMap, t.IsQuery, t.Timeout) - if err != nil { - return nil, err - } - - if err := dgraph.CheckError(resp); err != nil { - return nil, err - } - - var result struct { - Data map[string]interface{} `json:"data"` - } - - if err := json.Unmarshal(resp, &result); err != nil { - return nil, fmt.Errorf("error parsing JSON: %v", err) - } - - return result.Data, nil + return source.RunSQL(t.Statement, params, t.IsQuery, t.Timeout) } func (t Tool) ParseParams(data map[string]any, claimsMap map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/elasticsearch/elasticsearchesql/elasticsearchesql.go b/internal/tools/elasticsearch/elasticsearchesql/elasticsearchesql.go index d7cbb35722..57f78a4403 100644 --- a/internal/tools/elasticsearch/elasticsearchesql/elasticsearchesql.go +++ b/internal/tools/elasticsearch/elasticsearchesql/elasticsearchesql.go @@ -15,14 +15,10 @@ package elasticsearchesql import ( - "bytes" "context" - "encoding/json" "fmt" "time" - "github.com/elastic/go-elasticsearch/v9/esapi" - "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/goccy/go-yaml" @@ -41,6 +37,7 @@ func init() { type compatibleSource interface { ElasticsearchClient() es.EsClient + RunSQL(ctx context.Context, format, query string, params []map[string]any) (any, error) } type Config struct { @@ -91,16 +88,6 @@ func (t Tool) ToConfig() tools.ToolConfig { return t.Config } -type esqlColumn struct { - Name string `json:"name"` - Type string `json:"type"` -} - -type esqlResult struct { - Columns []esqlColumn `json:"columns"` - Values [][]any `json:"values"` -} - func (t Tool) Invoke(ctx 
context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) if err != nil { @@ -116,20 +103,13 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para defer cancel() } - bodyStruct := struct { - Query string `json:"query"` - Params []map[string]any `json:"params,omitempty"` - }{ - Query: t.Query, - Params: make([]map[string]any, 0, len(params)), - } - + query := t.Query + sqlParams := make([]map[string]any, 0, len(params)) paramMap := params.AsMap() - // If a query is provided in the params and not already set in the tool, use it. - if query, ok := paramMap["query"]; ok { - if str, ok := query.(string); ok && bodyStruct.Query == "" { - bodyStruct.Query = str + if queryVal, ok := paramMap["query"]; ok { + if str, ok := queryVal.(string); ok && t.Query == "" { + query = str } // Drop the query param if not a string or if the tool already has a query. 
@@ -140,65 +120,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if param.GetType() == "array" { return nil, fmt.Errorf("array parameters are not supported yet") } - bodyStruct.Params = append(bodyStruct.Params, map[string]any{param.GetName(): paramMap[param.GetName()]}) + sqlParams = append(sqlParams, map[string]any{param.GetName(): paramMap[param.GetName()]}) } - - body, err := json.Marshal(bodyStruct) - if err != nil { - return nil, fmt.Errorf("failed to marshal query body: %w", err) - } - res, err := esapi.EsqlQueryRequest{ - Body: bytes.NewReader(body), - Format: t.Format, - FilterPath: []string{"columns", "values"}, - Instrument: source.ElasticsearchClient().InstrumentationEnabled(), - }.Do(ctx, source.ElasticsearchClient()) - - if err != nil { - return nil, err - } - defer res.Body.Close() - - if res.IsError() { - // Try to extract error message from response - var esErr json.RawMessage - err = util.DecodeJSON(res.Body, &esErr) - if err != nil { - return nil, fmt.Errorf("elasticsearch error: status %s", res.Status()) - } - return esErr, nil - } - - var result esqlResult - err = util.DecodeJSON(res.Body, &result) - if err != nil { - return nil, fmt.Errorf("failed to decode response body: %w", err) - } - - output := t.esqlToMap(result) - - return output, nil -} - -// esqlToMap converts the esqlResult to a slice of maps. 
-func (t Tool) esqlToMap(result esqlResult) []map[string]any { - output := make([]map[string]any, 0, len(result.Values)) - for _, value := range result.Values { - row := make(map[string]any) - if value == nil { - output = append(output, row) - continue - } - for i, col := range result.Columns { - if i < len(value) { - row[col.Name] = value[i] - } else { - row[col.Name] = nil - } - } - output = append(output, row) - } - return output + return source.RunSQL(ctx, t.Format, query, sqlParams) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/elasticsearch/elasticsearchesql/elasticsearchesql_test.go b/internal/tools/elasticsearch/elasticsearchesql/elasticsearchesql_test.go index 2382483429..ec65d2842a 100644 --- a/internal/tools/elasticsearch/elasticsearchesql/elasticsearchesql_test.go +++ b/internal/tools/elasticsearch/elasticsearchesql/elasticsearchesql_test.go @@ -15,7 +15,6 @@ package elasticsearchesql import ( - "reflect" "testing" "github.com/goccy/go-yaml" @@ -106,156 +105,3 @@ func TestParseFromYamlElasticsearchEsql(t *testing.T) { }) } } - -func TestTool_esqlToMap(t1 *testing.T) { - tests := []struct { - name string - result esqlResult - want []map[string]any - }{ - { - name: "simple case with two rows", - result: esqlResult{ - Columns: []esqlColumn{ - {Name: "first_name", Type: "text"}, - {Name: "last_name", Type: "text"}, - }, - Values: [][]any{ - {"John", "Doe"}, - {"Jane", "Smith"}, - }, - }, - want: []map[string]any{ - {"first_name": "John", "last_name": "Doe"}, - {"first_name": "Jane", "last_name": "Smith"}, - }, - }, - { - name: "different data types", - result: esqlResult{ - Columns: []esqlColumn{ - {Name: "id", Type: "integer"}, - {Name: "active", Type: "boolean"}, - {Name: "score", Type: "float"}, - }, - Values: [][]any{ - {1, true, 95.5}, - {2, false, 88.0}, - }, - }, - want: []map[string]any{ - {"id": 1, "active": true, "score": 95.5}, - {"id": 2, "active": 
false, "score": 88.0}, - }, - }, - { - name: "no rows", - result: esqlResult{ - Columns: []esqlColumn{ - {Name: "id", Type: "integer"}, - {Name: "name", Type: "text"}, - }, - Values: [][]any{}, - }, - want: []map[string]any{}, - }, - { - name: "null values", - result: esqlResult{ - Columns: []esqlColumn{ - {Name: "id", Type: "integer"}, - {Name: "name", Type: "text"}, - }, - Values: [][]any{ - {1, nil}, - {2, "Alice"}, - }, - }, - want: []map[string]any{ - {"id": 1, "name": nil}, - {"id": 2, "name": "Alice"}, - }, - }, - { - name: "missing values in a row", - result: esqlResult{ - Columns: []esqlColumn{ - {Name: "id", Type: "integer"}, - {Name: "name", Type: "text"}, - {Name: "age", Type: "integer"}, - }, - Values: [][]any{ - {1, "Bob"}, - {2, "Charlie", 30}, - }, - }, - want: []map[string]any{ - {"id": 1, "name": "Bob", "age": nil}, - {"id": 2, "name": "Charlie", "age": 30}, - }, - }, - { - name: "all null row", - result: esqlResult{ - Columns: []esqlColumn{ - {Name: "id", Type: "integer"}, - {Name: "name", Type: "text"}, - }, - Values: [][]any{ - nil, - }, - }, - want: []map[string]any{ - {}, - }, - }, - { - name: "empty columns", - result: esqlResult{ - Columns: []esqlColumn{}, - Values: [][]any{ - {}, - {}, - }, - }, - want: []map[string]any{ - {}, - {}, - }, - }, - { - name: "more values than columns", - result: esqlResult{ - Columns: []esqlColumn{ - {Name: "id", Type: "integer"}, - }, - Values: [][]any{ - {1, "extra"}, - }, - }, - want: []map[string]any{ - {"id": 1}, - }, - }, - { - name: "no columns but with values", - result: esqlResult{ - Columns: []esqlColumn{}, - Values: [][]any{ - {1, "data"}, - }, - }, - want: []map[string]any{ - {}, - }, - }, - } - for _, tt := range tests { - t1.Run(tt.name, func(t1 *testing.T) { - t := Tool{} - if got := t.esqlToMap(tt.result); !reflect.DeepEqual(got, tt.want) { - t1.Errorf("esqlToMap() = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/internal/tools/firebird/firebirdexecutesql/firebirdexecutesql.go 
b/internal/tools/firebird/firebirdexecutesql/firebirdexecutesql.go index 28c8d0fb63..a6f6c01979 100644 --- a/internal/tools/firebird/firebirdexecutesql/firebirdexecutesql.go +++ b/internal/tools/firebird/firebirdexecutesql/firebirdexecutesql.go @@ -44,6 +44,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { FirebirdDB() *sql.DB + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -106,49 +107,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("error getting logger: %s", err) } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sql)) - - rows, err := source.FirebirdDB().QueryContext(ctx, sql) - if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - defer rows.Close() - - cols, err := rows.Columns() - - var out []any - if err == nil && len(cols) > 0 { - values := make([]any, len(cols)) - scanArgs := make([]any, len(values)) - for i := range values { - scanArgs[i] = &values[i] - } - - for rows.Next() { - err = rows.Scan(scanArgs...) - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - - vMap := make(map[string]any) - for i, colName := range cols { - if b, ok := values[i].([]byte); ok { - vMap[colName] = string(b) - } else { - vMap[colName] = values[i] - } - } - out = append(out, vMap) - } - } - - if err := rows.Err(); err != nil { - return nil, fmt.Errorf("error iterating rows: %w", err) - } - - // In most cases, DML/DDL statements like INSERT, UPDATE, CREATE, etc. might return no rows - // However, it is also possible that this was a query that was expected to return rows - // but returned none, a case that we cannot distinguish here. 
- return out, nil + return source.RunSQL(ctx, sql, nil) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/firebird/firebirdsql/firebirdsql.go b/internal/tools/firebird/firebirdsql/firebirdsql.go index 9dd040dcd7..74912714a7 100644 --- a/internal/tools/firebird/firebirdsql/firebirdsql.go +++ b/internal/tools/firebird/firebirdsql/firebirdsql.go @@ -44,6 +44,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { FirebirdDB() *sql.DB + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -125,51 +126,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para namedArgs = append(namedArgs, value) } } - - rows, err := source.FirebirdDB().QueryContext(ctx, statement, namedArgs...) - if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - defer rows.Close() - - cols, err := rows.Columns() - if err != nil { - return nil, fmt.Errorf("unable to get columns: %w", err) - } - - values := make([]any, len(cols)) - scanArgs := make([]any, len(values)) - for i := range values { - scanArgs[i] = &values[i] - } - - var out []any - for rows.Next() { - - err = rows.Scan(scanArgs...) - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - - vMap := make(map[string]any) - for i, col := range cols { - if b, ok := values[i].([]byte); ok { - vMap[col] = string(b) - } else { - vMap[col] = values[i] - } - } - out = append(out, vMap) - } - - if err := rows.Err(); err != nil { - return nil, fmt.Errorf("error iterating rows: %w", err) - } - - // In most cases, DML/DDL statements like INSERT, UPDATE, CREATE, etc. might return no rows - // However, it is also possible that this was a query that was expected to return rows - // but returned none, a case that we cannot distinguish here. 
- return out, nil + return source.RunSQL(ctx, statement, namedArgs) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/mindsdb/mindsdbexecutesql/mindsdbexecutesql.go b/internal/tools/mindsdb/mindsdbexecutesql/mindsdbexecutesql.go index 51f2952177..7a017b2f98 100644 --- a/internal/tools/mindsdb/mindsdbexecutesql/mindsdbexecutesql.go +++ b/internal/tools/mindsdb/mindsdbexecutesql/mindsdbexecutesql.go @@ -22,7 +22,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" - "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -44,6 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { MindsDBPool() *sql.DB + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -108,57 +108,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("unable to get cast %s", paramsMap["sql"]) } - results, err := source.MindsDBPool().QueryContext(ctx, sql) - if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - defer results.Close() - - cols, err := results.Columns() - if err != nil { - return nil, fmt.Errorf("unable to retrieve rows column name: %w", err) - } - - // create an array of values for each column, which can be re-used to scan each row - rawValues := make([]any, len(cols)) - values := make([]any, len(cols)) - for i := range rawValues { - values[i] = &rawValues[i] - } - - colTypes, err := results.ColumnTypes() - if err != nil { - return nil, fmt.Errorf("unable to get column types: %w", err) - } - - var out []any - for results.Next() { - err := results.Scan(values...) 
- if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - vMap := make(map[string]any) - for i, name := range cols { - val := rawValues[i] - if val == nil { - vMap[name] = nil - continue - } - - // MindsDB uses mysql driver - vMap[name], err = mysqlcommon.ConvertToType(colTypes[i], val) - if err != nil { - return nil, fmt.Errorf("errors encountered when converting values: %w", err) - } - } - out = append(out, vMap) - } - - if err := results.Err(); err != nil { - return nil, fmt.Errorf("errors encountered during row iteration: %w", err) - } - - return out, nil + return source.RunSQL(ctx, sql, nil) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/mindsdb/mindsdbsql/mindsdbsql.go b/internal/tools/mindsdb/mindsdbsql/mindsdbsql.go index c247f4d4dc..4b8ce4c045 100644 --- a/internal/tools/mindsdb/mindsdbsql/mindsdbsql.go +++ b/internal/tools/mindsdb/mindsdbsql/mindsdbsql.go @@ -22,7 +22,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" - "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -44,6 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { MindsDBPool() *sql.DB + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -116,59 +116,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } sliceParams := newParams.AsSlice() - - // MindsDB now supports MySQL prepared statements natively - results, err := source.MindsDBPool().QueryContext(ctx, newStatement, sliceParams...) 
- if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - - cols, err := results.Columns() - if err != nil { - return nil, fmt.Errorf("unable to retrieve rows column name: %w", err) - } - - // create an array of values for each column, which can be re-used to scan each row - rawValues := make([]any, len(cols)) - values := make([]any, len(cols)) - for i := range rawValues { - values[i] = &rawValues[i] - } - defer results.Close() - - colTypes, err := results.ColumnTypes() - if err != nil { - return nil, fmt.Errorf("unable to get column types: %w", err) - } - - var out []any - for results.Next() { - err := results.Scan(values...) - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - vMap := make(map[string]any) - for i, name := range cols { - val := rawValues[i] - if val == nil { - vMap[name] = nil - continue - } - - // MindsDB uses mysql driver - vMap[name], err = mysqlcommon.ConvertToType(colTypes[i], val) - if err != nil { - return nil, fmt.Errorf("errors encountered when converting values: %w", err) - } - } - out = append(out, vMap) - } - - if err := results.Err(); err != nil { - return nil, fmt.Errorf("errors encountered during row iteration: %w", err) - } - - return out, nil + return source.RunSQL(ctx, newStatement, sliceParams) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/mssql/mssqlexecutesql/mssqlexecutesql.go b/internal/tools/mssql/mssqlexecutesql/mssqlexecutesql.go index ddfbdb089e..8963544d41 100644 --- a/internal/tools/mssql/mssqlexecutesql/mssqlexecutesql.go +++ b/internal/tools/mssql/mssqlexecutesql/mssqlexecutesql.go @@ -23,7 +23,6 @@ import ( "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util" - "github.com/googleapis/genai-toolbox/internal/util/orderedmap" 
"github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -45,6 +44,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { MSSQLDB() *sql.DB + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -106,47 +106,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("error getting logger: %s", err) } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sql)) - - results, err := source.MSSQLDB().QueryContext(ctx, sql) - if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - defer results.Close() - - cols, err := results.Columns() - // If Columns() errors, it might be a DDL/DML without an OUTPUT clause. - // We proceed, and results.Err() will catch actual query execution errors. - // 'out' will remain nil if cols is empty or err is not nil here. - - var out []any - if err == nil && len(cols) > 0 { - // create an array of values for each column, which can be re-used to scan each row - rawValues := make([]any, len(cols)) - values := make([]any, len(cols)) - for i := range rawValues { - values[i] = &rawValues[i] - } - - for results.Next() { - scanErr := results.Scan(values...) - if scanErr != nil { - return nil, fmt.Errorf("unable to parse row: %w", scanErr) - } - row := orderedmap.Row{} - for i, name := range cols { - row.Add(name, rawValues[i]) - } - out = append(out, row) - } - } - - // Check for errors from iterating over rows or from the query execution itself. - // results.Close() is handled by defer. 
- if err := results.Err(); err != nil { - return nil, fmt.Errorf("errors encountered during query execution or row processing: %w", err) - } - - return out, nil + return source.RunSQL(ctx, sql, nil) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/mssql/mssqllisttables/mssqllisttables.go b/internal/tools/mssql/mssqllisttables/mssqllisttables.go index 29fbea4498..633f43dee7 100644 --- a/internal/tools/mssql/mssqllisttables/mssqllisttables.go +++ b/internal/tools/mssql/mssqllisttables/mssqllisttables.go @@ -292,6 +292,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { MSSQLDB() *sql.DB + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -354,44 +355,16 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para sql.Named("table_names", paramsMap["table_names"]), sql.Named("output_format", outputFormat), } - - rows, err := source.MSSQLDB().QueryContext(ctx, listTablesStatement, namedArgs...) + resp, err := source.RunSQL(ctx, listTablesStatement, namedArgs) if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) + return nil, err } - defer rows.Close() - - cols, err := rows.Columns() - if err != nil { - return nil, fmt.Errorf("unable to fetch column names: %w", err) + // if there's no results, return empty list instead of null + resSlice, ok := resp.([]any) + if !ok || len(resSlice) == 0 { + return []any{}, nil } - - // create an array of values for each column, which can be re-used to scan each row - rawValues := make([]any, len(cols)) - values := make([]any, len(cols)) - for i := range rawValues { - values[i] = &rawValues[i] - } - - out := []any{} - for rows.Next() { - err = rows.Scan(values...) 
- if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - vMap := make(map[string]any) - for i, name := range cols { - vMap[name] = rawValues[i] - } - out = append(out, vMap) - } - - // Check if error occurred during iteration - if err := rows.Err(); err != nil { - return nil, fmt.Errorf("errors encountered during row iteration: %w", err) - } - - return out, nil + return resp, err } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/mssql/mssqlsql/mssqlsql.go b/internal/tools/mssql/mssqlsql/mssqlsql.go index 0e621b7417..1b97a889bf 100644 --- a/internal/tools/mssql/mssqlsql/mssqlsql.go +++ b/internal/tools/mssql/mssqlsql/mssqlsql.go @@ -44,6 +44,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { MSSQLDB() *sql.DB + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -121,47 +122,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para namedArgs = append(namedArgs, value) } } - - rows, err := source.MSSQLDB().QueryContext(ctx, newStatement, namedArgs...) - if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - - cols, err := rows.Columns() - if err != nil { - return nil, fmt.Errorf("unable to fetch column types: %w", err) - } - - // create an array of values for each column, which can be re-used to scan each row - rawValues := make([]any, len(cols)) - values := make([]any, len(cols)) - for i := range rawValues { - values[i] = &rawValues[i] - } - - var out []any - for rows.Next() { - err = rows.Scan(values...) 
- if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - vMap := make(map[string]any) - for i, name := range cols { - vMap[name] = rawValues[i] - } - out = append(out, vMap) - } - err = rows.Close() - if err != nil { - return nil, fmt.Errorf("unable to close rows: %w", err) - } - - // Check if error occurred during iteration - if err := rows.Err(); err != nil { - return nil, err - } - - return out, nil + return source.RunSQL(ctx, newStatement, namedArgs) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/mysql/mysqlexecutesql/mysqlexecutesql.go b/internal/tools/mysql/mysqlexecutesql/mysqlexecutesql.go index 5198602d70..1f2a5bdee3 100644 --- a/internal/tools/mysql/mysqlexecutesql/mysqlexecutesql.go +++ b/internal/tools/mysql/mysqlexecutesql/mysqlexecutesql.go @@ -22,9 +22,7 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" - "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon" "github.com/googleapis/genai-toolbox/internal/util" - "github.com/googleapis/genai-toolbox/internal/util/orderedmap" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -46,6 +44,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { MySQLPool() *sql.DB + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -107,58 +106,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("error getting logger: %s", err) } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sql)) - - results, err := source.MySQLPool().QueryContext(ctx, sql) - if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - defer results.Close() - - cols, err := results.Columns() - if err != 
nil { - return nil, fmt.Errorf("unable to retrieve rows column name: %w", err) - } - - // create an array of values for each column, which can be re-used to scan each row - rawValues := make([]any, len(cols)) - values := make([]any, len(cols)) - for i := range rawValues { - values[i] = &rawValues[i] - } - - colTypes, err := results.ColumnTypes() - if err != nil { - return nil, fmt.Errorf("unable to get column types: %w", err) - } - - var out []any - for results.Next() { - err := results.Scan(values...) - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - row := orderedmap.Row{} - for i, name := range cols { - val := rawValues[i] - if val == nil { - row.Add(name, nil) - continue - } - - convertedValue, err := mysqlcommon.ConvertToType(colTypes[i], val) - if err != nil { - return nil, fmt.Errorf("errors encountered when converting values: %w", err) - } - row.Add(name, convertedValue) - } - out = append(out, row) - } - - if err := results.Err(); err != nil { - return nil, fmt.Errorf("errors encountered during row iteration: %w", err) - } - - return out, nil + return source.RunSQL(ctx, sql, nil) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/mysql/mysqlgetqueryplan/mysqlgetqueryplan.go b/internal/tools/mysql/mysqlgetqueryplan/mysqlgetqueryplan.go index 3458a6ed83..04ab5c23e0 100644 --- a/internal/tools/mysql/mysqlgetqueryplan/mysqlgetqueryplan.go +++ b/internal/tools/mysql/mysqlgetqueryplan/mysqlgetqueryplan.go @@ -24,6 +24,7 @@ import ( "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util" + "github.com/googleapis/genai-toolbox/internal/util/orderedmap" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -45,6 +46,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource 
interface { MySQLPool() *sql.DB + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -108,30 +110,27 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sql)) query := fmt.Sprintf("EXPLAIN FORMAT=JSON %s", sql) - results, err := source.MySQLPool().QueryContext(ctx, query) + result, err := source.RunSQL(ctx, query, nil) if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) + return nil, err } - defer results.Close() - - var plan string - if results.Next() { - if err := results.Scan(&plan); err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - } else { + // extract and return only the query plan object + resSlice, ok := result.([]any) + if !ok || len(resSlice) == 0 { return nil, fmt.Errorf("no query plan returned") } - - if err := results.Err(); err != nil { - return nil, fmt.Errorf("errors encountered during row iteration: %w", err) + row, ok := resSlice[0].(orderedmap.Row) + if !ok || len(row.Columns) == 0 { + return nil, fmt.Errorf("no query plan returned in row") } - - var out any + plan, ok := row.Columns[0].Value.(string) + if !ok { + return nil, fmt.Errorf("unable to convert plan object to string") + } + var out map[string]any if err := json.Unmarshal([]byte(plan), &out); err != nil { return nil, fmt.Errorf("failed to unmarshal query plan json: %w", err) } - return out, nil } diff --git a/internal/tools/mysql/mysqllistactivequeries/mysqllistactivequeries.go b/internal/tools/mysql/mysqllistactivequeries/mysqllistactivequeries.go index 323d582d32..6124115c78 100644 --- a/internal/tools/mysql/mysqllistactivequeries/mysqllistactivequeries.go +++ b/internal/tools/mysql/mysqllistactivequeries/mysqllistactivequeries.go @@ -24,7 +24,6 @@ import ( "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmysql" "github.com/googleapis/genai-toolbox/internal/sources/mysql" 
"github.com/googleapis/genai-toolbox/internal/tools" - "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon" "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -109,6 +108,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { MySQLPool() *sql.DB + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -146,14 +146,13 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) var statement string sourceKind := rawS.SourceKind() - switch sourceKind { case mysql.SourceKind: statement = listActiveQueriesStatementMySQL case cloudsqlmysql.SourceKind: statement = listActiveQueriesStatementCloudSQLMySQL default: - return nil, fmt.Errorf("unsupported source kind kind: %q", sourceKind) + return nil, fmt.Errorf("unsupported source kind: %s", cfg.Source) } // finish tool setup t := Tool{ @@ -200,57 +199,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("error getting logger: %s", err) } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, t.statement)) - - results, err := source.MySQLPool().QueryContext(ctx, t.statement, duration, duration, limit) - if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - defer results.Close() - - cols, err := results.Columns() - if err != nil { - return nil, fmt.Errorf("unable to retrieve rows column name: %w", err) - } - - // create an array of values for each column, which can be re-used to scan each row - rawValues := make([]any, len(cols)) - values := make([]any, len(cols)) - for i := range rawValues { - values[i] = &rawValues[i] - } - - colTypes, err := results.ColumnTypes() - if err != nil { - return nil, fmt.Errorf("unable to get column types: %w", err) - } - - var out []any - for results.Next() { - err := results.Scan(values...) 
- if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - vMap := make(map[string]any) - for i, name := range cols { - val := rawValues[i] - if val == nil { - vMap[name] = nil - continue - } - - vMap[name], err = mysqlcommon.ConvertToType(colTypes[i], val) - if err != nil { - return nil, fmt.Errorf("errors encountered when converting values: %w", err) - } - } - out = append(out, vMap) - } - - if err := results.Err(); err != nil { - return nil, fmt.Errorf("errors encountered during row iteration: %w", err) - } - - return out, nil + return source.RunSQL(ctx, t.statement, []any{duration, duration, limit}) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/mysql/mysqllisttablefragmentation/mysqllisttablefragmentation.go b/internal/tools/mysql/mysqllisttablefragmentation/mysqllisttablefragmentation.go index a0bc1b8f66..28cce1bc54 100644 --- a/internal/tools/mysql/mysqllisttablefragmentation/mysqllisttablefragmentation.go +++ b/internal/tools/mysql/mysqllisttablefragmentation/mysqllisttablefragmentation.go @@ -22,7 +22,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" - "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon" "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -67,6 +66,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { MySQLPool() *sql.DB + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -144,57 +144,8 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("error getting logger: %s", err) } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, listTableFragmentationStatement)) - - 
results, err := source.MySQLPool().QueryContext(ctx, listTableFragmentationStatement, table_schema, table_schema, table_name, table_name, data_free_threshold_bytes, limit) - if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - defer results.Close() - - cols, err := results.Columns() - if err != nil { - return nil, fmt.Errorf("unable to retrieve rows column name: %w", err) - } - - // create an array of values for each column, which can be re-used to scan each row - rawValues := make([]any, len(cols)) - values := make([]any, len(cols)) - for i := range rawValues { - values[i] = &rawValues[i] - } - - colTypes, err := results.ColumnTypes() - if err != nil { - return nil, fmt.Errorf("unable to get column types: %w", err) - } - - var out []any - for results.Next() { - err := results.Scan(values...) - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - vMap := make(map[string]any) - for i, name := range cols { - val := rawValues[i] - if val == nil { - vMap[name] = nil - continue - } - - vMap[name], err = mysqlcommon.ConvertToType(colTypes[i], val) - if err != nil { - return nil, fmt.Errorf("errors encountered when converting values: %w", err) - } - } - out = append(out, vMap) - } - - if err := results.Err(); err != nil { - return nil, fmt.Errorf("errors encountered during row iteration: %w", err) - } - - return out, nil + sliceParams := []any{table_schema, table_schema, table_name, table_name, data_free_threshold_bytes, limit} + return source.RunSQL(ctx, listTableFragmentationStatement, sliceParams) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/mysql/mysqllisttables/mysqllisttables.go b/internal/tools/mysql/mysqllisttables/mysqllisttables.go index 66928b75fa..f8e0c1dced 100644 --- a/internal/tools/mysql/mysqllisttables/mysqllisttables.go +++ b/internal/tools/mysql/mysqllisttables/mysqllisttables.go @@ -22,7 +22,6 @@ 
import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" - "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -197,6 +196,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { MySQLPool() *sql.DB + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -259,57 +259,16 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if outputFormat != "simple" && outputFormat != "detailed" { return nil, fmt.Errorf("invalid value for output_format: must be 'simple' or 'detailed', but got %q", outputFormat) } - - results, err := source.MySQLPool().QueryContext(ctx, listTablesStatement, tableNames, outputFormat) + resp, err := source.RunSQL(ctx, listTablesStatement, []any{tableNames, outputFormat}) if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) + return nil, err } - - cols, err := results.Columns() - if err != nil { - return nil, fmt.Errorf("unable to retrieve rows column name: %w", err) + // if there's no results, return empty list instead of null + resSlice, ok := resp.([]any) + if !ok || len(resSlice) == 0 { + return []any{}, nil } - - // create an array of values for each column, which can be re-used to scan each row - rawValues := make([]any, len(cols)) - values := make([]any, len(cols)) - for i := range rawValues { - values[i] = &rawValues[i] - } - defer results.Close() - - colTypes, err := results.ColumnTypes() - if err != nil { - return nil, fmt.Errorf("unable to get column types: %w", err) - } - - out := []any{} - for results.Next() { - err := results.Scan(values...) 
- if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - vMap := make(map[string]any) - for i, name := range cols { - val := rawValues[i] - if val == nil { - vMap[name] = nil - continue - } - - vMap[name], err = mysqlcommon.ConvertToType(colTypes[i], val) - if err != nil { - return nil, fmt.Errorf("errors encountered when converting values: %w", err) - } - } - out = append(out, vMap) - } - - if err := results.Err(); err != nil { - return nil, fmt.Errorf("errors encountered during row iteration: %w", err) - } - - return out, nil + return resp, err } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/mysql/mysqllisttablesmissinguniqueindexes/mysqllisttablesmissinguniqueindexes.go b/internal/tools/mysql/mysqllisttablesmissinguniqueindexes/mysqllisttablesmissinguniqueindexes.go index 522b180acd..e19e14a33d 100644 --- a/internal/tools/mysql/mysqllisttablesmissinguniqueindexes/mysqllisttablesmissinguniqueindexes.go +++ b/internal/tools/mysql/mysqllisttablesmissinguniqueindexes/mysqllisttablesmissinguniqueindexes.go @@ -22,7 +22,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" - "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon" "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -68,6 +67,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { MySQLPool() *sql.DB + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -135,57 +135,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("error getting logger: %s", err) } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, 
listTablesMissingUniqueIndexesStatement)) - - results, err := source.MySQLPool().QueryContext(ctx, listTablesMissingUniqueIndexesStatement, table_schema, table_schema, limit) - if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - defer results.Close() - - cols, err := results.Columns() - if err != nil { - return nil, fmt.Errorf("unable to retrieve rows column name: %w", err) - } - - // create an array of values for each column, which can be re-used to scan each row - rawValues := make([]any, len(cols)) - values := make([]any, len(cols)) - for i := range rawValues { - values[i] = &rawValues[i] - } - - colTypes, err := results.ColumnTypes() - if err != nil { - return nil, fmt.Errorf("unable to get column types: %w", err) - } - - var out []any - for results.Next() { - err := results.Scan(values...) - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - vMap := make(map[string]any) - for i, name := range cols { - val := rawValues[i] - if val == nil { - vMap[name] = nil - continue - } - - vMap[name], err = mysqlcommon.ConvertToType(colTypes[i], val) - if err != nil { - return nil, fmt.Errorf("errors encountered when converting values: %w", err) - } - } - out = append(out, vMap) - } - - if err := results.Err(); err != nil { - return nil, fmt.Errorf("errors encountered during row iteration: %w", err) - } - - return out, nil + return source.RunSQL(ctx, listTablesMissingUniqueIndexesStatement, []any{table_schema, table_schema, limit}) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/mysql/mysqlsql/mysqlsql.go b/internal/tools/mysql/mysqlsql/mysqlsql.go index edf5f65db1..f89dde648b 100644 --- a/internal/tools/mysql/mysqlsql/mysqlsql.go +++ b/internal/tools/mysql/mysqlsql/mysqlsql.go @@ -22,7 +22,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" 
"github.com/googleapis/genai-toolbox/internal/tools" - "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -44,6 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { MySQLPool() *sql.DB + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -110,56 +110,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } sliceParams := newParams.AsSlice() - results, err := source.MySQLPool().QueryContext(ctx, newStatement, sliceParams...) - if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - - cols, err := results.Columns() - if err != nil { - return nil, fmt.Errorf("unable to retrieve rows column name: %w", err) - } - - // create an array of values for each column, which can be re-used to scan each row - rawValues := make([]any, len(cols)) - values := make([]any, len(cols)) - for i := range rawValues { - values[i] = &rawValues[i] - } - defer results.Close() - - colTypes, err := results.ColumnTypes() - if err != nil { - return nil, fmt.Errorf("unable to get column types: %w", err) - } - - var out []any - for results.Next() { - err := results.Scan(values...) 
- if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - vMap := make(map[string]any) - for i, name := range cols { - val := rawValues[i] - if val == nil { - vMap[name] = nil - continue - } - - vMap[name], err = mysqlcommon.ConvertToType(colTypes[i], val) - if err != nil { - return nil, fmt.Errorf("errors encountered when converting values: %w", err) - } - } - out = append(out, vMap) - } - - if err := results.Err(); err != nil { - return nil, fmt.Errorf("errors encountered during row iteration: %w", err) - } - - return out, nil + return source.RunSQL(ctx, newStatement, sliceParams) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/neo4j/neo4jcypher/neo4jcypher.go b/internal/tools/neo4j/neo4jcypher/neo4jcypher.go index 5f5c4ce05b..0073f90644 100644 --- a/internal/tools/neo4j/neo4jcypher/neo4jcypher.go +++ b/internal/tools/neo4j/neo4jcypher/neo4jcypher.go @@ -19,8 +19,6 @@ import ( "fmt" "github.com/goccy/go-yaml" - "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jschema/helpers" - "github.com/neo4j/neo4j-go-driver/v5/neo4j" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" @@ -44,8 +42,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T } type compatibleSource interface { - Neo4jDriver() neo4j.DriverWithContext - Neo4jDatabase() string + Neo4jDatabase() string // kept to ensure neo4j source + RunQuery(context.Context, string, map[string]any, bool, bool) (any, error) } type Config struct { @@ -93,26 +91,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } paramsMap := params.AsMap() - - config := neo4j.ExecuteQueryWithDatabase(source.Neo4jDatabase()) - results, err := neo4j.ExecuteQuery[*neo4j.EagerResult](ctx, source.Neo4jDriver(), t.Statement, paramsMap, - neo4j.EagerResultTransformer, config) - if err != nil 
{ - return nil, fmt.Errorf("unable to execute query: %w", err) - } - - var out []any - keys := results.Keys - records := results.Records - for _, record := range records { - vMap := make(map[string]any) - for col, value := range record.Values { - vMap[keys[col]] = helpers.ConvertValue(value) - } - out = append(out, vMap) - } - - return out, nil + return source.RunQuery(ctx, t.Statement, paramsMap, false, false) } func (t Tool) ParseParams(data map[string]any, claimsMap map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/neo4j/neo4jexecutecypher/neo4jexecutecypher.go b/internal/tools/neo4j/neo4jexecutecypher/neo4jexecutecypher.go index 0bf2b8f34e..2ca95dc822 100644 --- a/internal/tools/neo4j/neo4jexecutecypher/neo4jexecutecypher.go +++ b/internal/tools/neo4j/neo4jexecutecypher/neo4jexecutecypher.go @@ -21,10 +21,7 @@ import ( "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" - "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jexecutecypher/classifier" - "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jschema/helpers" "github.com/googleapis/genai-toolbox/internal/util/parameters" - "github.com/neo4j/neo4j-go-driver/v5/neo4j" ) const kind string = "neo4j-execute-cypher" @@ -44,8 +41,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T } type compatibleSource interface { - Neo4jDriver() neo4j.DriverWithContext - Neo4jDatabase() string + Neo4jDatabase() string // kept to ensure neo4j source + RunQuery(context.Context, string, map[string]any, bool, bool) (any, error) } type Config struct { @@ -80,7 +77,6 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) t := Tool{ Config: cfg, Parameters: params, - classifier: classifier.NewQueryClassifier(), manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, 
mcpManifest: mcpManifest, } @@ -93,7 +89,6 @@ var _ tools.Tool = Tool{} type Tool struct { Config Parameters parameters.Parameters `yaml:"parameters"` - classifier *classifier.QueryClassifier manifest tools.Manifest mcpManifest tools.McpManifest } @@ -119,59 +114,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("unable to cast dry_run parameter %s", paramsMap["dry_run"]) } - // validate the cypher query before executing - cf := t.classifier.Classify(cypherStr) - if cf.Error != nil { - return nil, cf.Error - } - - if cf.Type == classifier.WriteQuery && t.ReadOnly { - return nil, fmt.Errorf("this tool is read-only and cannot execute write queries") - } - - if dryRun { - // Add EXPLAIN to the beginning of the query to validate it without executing - cypherStr = "EXPLAIN " + cypherStr - } - - config := neo4j.ExecuteQueryWithDatabase(source.Neo4jDatabase()) - results, err := neo4j.ExecuteQuery(ctx, source.Neo4jDriver(), cypherStr, nil, - neo4j.EagerResultTransformer, config) - if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - - // If dry run, return the summary information only - if dryRun { - summary := results.Summary - plan := summary.Plan() - execPlan := map[string]any{ - "queryType": cf.Type.String(), - "statementType": summary.StatementType(), - "operator": plan.Operator(), - "arguments": plan.Arguments(), - "identifiers": plan.Identifiers(), - "childrenCount": len(plan.Children()), - } - if len(plan.Children()) > 0 { - execPlan["children"] = addPlanChildren(plan) - } - return []map[string]any{execPlan}, nil - } - - var out []any - keys := results.Keys - records := results.Records - - for _, record := range records { - vMap := make(map[string]any) - for col, value := range record.Values { - vMap[keys[col]] = helpers.ConvertValue(value) - } - out = append(out, vMap) - } - - return out, nil + return source.RunQuery(ctx, cypherStr, nil, t.ReadOnly, dryRun) } func (t Tool) 
ParseParams(data map[string]any, claimsMap map[string]map[string]any) (parameters.ParamValues, error) { @@ -194,24 +137,6 @@ func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (boo return false, nil } -// Recursive function to add plan children -func addPlanChildren(p neo4j.Plan) []map[string]any { - var children []map[string]any - for _, child := range p.Children() { - childMap := map[string]any{ - "operator": child.Operator(), - "arguments": child.Arguments(), - "identifiers": child.Identifiers(), - "children_count": len(child.Children()), - } - if len(child.Children()) > 0 { - childMap["children"] = addPlanChildren(child) - } - children = append(children, childMap) - } - return children -} - func (t Tool) ToConfig() tools.ToolConfig { return t.Config } diff --git a/internal/tools/oceanbase/oceanbaseexecutesql/oceanbaseexecutesql.go b/internal/tools/oceanbase/oceanbaseexecutesql/oceanbaseexecutesql.go index fa8d7a96a9..46aa9cc998 100644 --- a/internal/tools/oceanbase/oceanbaseexecutesql/oceanbaseexecutesql.go +++ b/internal/tools/oceanbase/oceanbaseexecutesql/oceanbaseexecutesql.go @@ -22,7 +22,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" - "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -36,6 +35,7 @@ func init() { type compatibleSource interface { OceanBasePool() *sql.DB + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -99,58 +99,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if !ok { return nil, fmt.Errorf("unable to get cast %s", sliceParams[0]) } - - results, err := source.OceanBasePool().QueryContext(ctx, sqlStr) - if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - defer results.Close() - - cols, err := results.Columns() - if err != nil 
{ - return nil, fmt.Errorf("unable to retrieve rows column name: %w", err) - } - - // create an array of values for each column, which can be re-used to scan each row - rawValues := make([]any, len(cols)) - values := make([]any, len(cols)) - for i := range rawValues { - values[i] = &rawValues[i] - } - - colTypes, err := results.ColumnTypes() - if err != nil { - return nil, fmt.Errorf("unable to get column types: %w", err) - } - - var out []any - for results.Next() { - err := results.Scan(values...) - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - vMap := make(map[string]any) - for i, name := range cols { - val := rawValues[i] - if val == nil { - vMap[name] = nil - continue - } - - // oceanbase uses mysql driver - vMap[name], err = mysqlcommon.ConvertToType(colTypes[i], val) - if err != nil { - return nil, fmt.Errorf("errors encountered when converting values: %w", err) - } - } - out = append(out, vMap) - } - - if err := results.Err(); err != nil { - return nil, fmt.Errorf("errors encountered during row iteration: %w", err) - } - - return out, nil + return source.RunSQL(ctx, sqlStr, nil) } // ParseParams parses the input parameters for the tool. 
diff --git a/internal/tools/oceanbase/oceanbasesql/oceanbasesql.go b/internal/tools/oceanbase/oceanbasesql/oceanbasesql.go index 10a4dc17de..db273642a6 100644 --- a/internal/tools/oceanbase/oceanbasesql/oceanbasesql.go +++ b/internal/tools/oceanbase/oceanbasesql/oceanbasesql.go @@ -22,7 +22,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" - "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -36,6 +35,7 @@ func init() { type compatibleSource interface { OceanBasePool() *sql.DB + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -109,59 +109,8 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if err != nil { return nil, fmt.Errorf("unable to extract standard params %w", err) } - sliceParams := newParams.AsSlice() - results, err := source.OceanBasePool().QueryContext(ctx, newStatement, sliceParams...) - if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - - cols, err := results.Columns() - if err != nil { - return nil, fmt.Errorf("unable to retrieve rows column name: %w", err) - } - - // create an array of values for each column, which can be re-used to scan each row - rawValues := make([]any, len(cols)) - values := make([]any, len(cols)) - for i := range rawValues { - values[i] = &rawValues[i] - } - defer results.Close() - - colTypes, err := results.ColumnTypes() - if err != nil { - return nil, fmt.Errorf("unable to get column types: %w", err) - } - - var out []any - for results.Next() { - err := results.Scan(values...) 
- if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - vMap := make(map[string]any) - for i, name := range cols { - val := rawValues[i] - if val == nil { - vMap[name] = nil - continue - } - - // oceanbase uses mysql driver - vMap[name], err = mysqlcommon.ConvertToType(colTypes[i], val) - if err != nil { - return nil, fmt.Errorf("errors encountered when converting values: %w", err) - } - } - out = append(out, vMap) - } - - if err := results.Err(); err != nil { - return nil, fmt.Errorf("errors encountered during row iteration: %w", err) - } - - return out, nil + return source.RunSQL(ctx, newStatement, sliceParams) } // ParseParams parses the input parameters for the tool. diff --git a/internal/tools/oracle/oracleexecutesql/oracleexecutesql.go b/internal/tools/oracle/oracleexecutesql/oracleexecutesql.go index 447f9362e9..211f7791d0 100644 --- a/internal/tools/oracle/oracleexecutesql/oracleexecutesql.go +++ b/internal/tools/oracle/oracleexecutesql/oracleexecutesql.go @@ -5,9 +5,7 @@ package oracleexecutesql import ( "context" "database/sql" - "encoding/json" "fmt" - "strings" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" @@ -34,6 +32,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { OracleDB() *sql.DB + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -95,107 +94,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("error getting logger: %s", err) } logger.DebugContext(ctx, "executing `%s` tool query: %s", kind, sqlParam) - - results, err := source.OracleDB().QueryContext(ctx, sqlParam) - if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - defer results.Close() - - // If Columns() errors, it might be a DDL/DML without an OUTPUT clause. - // We proceed, and results.Err() will catch actual query execution errors. 
- // 'out' will remain nil if cols is empty or err is not nil here. - cols, _ := results.Columns() - - // Get Column types - colTypes, err := results.ColumnTypes() - if err != nil { - if err := results.Err(); err != nil { - return nil, fmt.Errorf("query execution error: %w", err) - } - return []any{}, nil - } - - var out []any - for results.Next() { - // Create slice to hold values - values := make([]any, len(cols)) - for i, colType := range colTypes { - // Based on the database type, we prepare a pointer to a Go type. - switch strings.ToUpper(colType.DatabaseTypeName()) { - case "NUMBER", "FLOAT", "BINARY_FLOAT", "BINARY_DOUBLE": - if _, scale, ok := colType.DecimalSize(); ok && scale == 0 { - // Scale is 0, treat as an integer. - values[i] = new(sql.NullInt64) - } else { - // Scale is non-zero or unknown, treat as a float. - values[i] = new(sql.NullFloat64) - } - case "DATE", "TIMESTAMP", "TIMESTAMP WITH TIME ZONE", "TIMESTAMP WITH LOCAL TIME ZONE": - values[i] = new(sql.NullTime) - case "JSON": - values[i] = new(sql.RawBytes) - default: - values[i] = new(sql.NullString) - } - } - - if err := results.Scan(values...); err != nil { - return nil, fmt.Errorf("unable to scan row: %w", err) - } - - vMap := make(map[string]any) - for i, col := range cols { - receiver := values[i] - - // Dereference the pointer and check for validity (not NULL). 
- switch v := receiver.(type) { - case *sql.NullInt64: - if v.Valid { - vMap[col] = v.Int64 - } else { - vMap[col] = nil - } - case *sql.NullFloat64: - if v.Valid { - vMap[col] = v.Float64 - } else { - vMap[col] = nil - } - case *sql.NullString: - if v.Valid { - vMap[col] = v.String - } else { - vMap[col] = nil - } - case *sql.NullTime: - if v.Valid { - vMap[col] = v.Time - } else { - vMap[col] = nil - } - case *sql.RawBytes: - if *v != nil { - var unmarshaledData any - if err := json.Unmarshal(*v, &unmarshaledData); err != nil { - return nil, fmt.Errorf("unable to unmarshal json data for column %s", col) - } - vMap[col] = unmarshaledData - } else { - vMap[col] = nil - } - default: - return nil, fmt.Errorf("unexpected receiver type: %T", v) - } - } - out = append(out, vMap) - } - - if err := results.Err(); err != nil { - return nil, fmt.Errorf("errors encountered during query execution or row processing: %w", err) - } - - return out, nil + return source.RunSQL(ctx, sqlParam, nil) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/oracle/oraclesql/oraclesql.go b/internal/tools/oracle/oraclesql/oraclesql.go index 1ba87b47bd..d6e536a637 100644 --- a/internal/tools/oracle/oraclesql/oraclesql.go +++ b/internal/tools/oracle/oraclesql/oraclesql.go @@ -5,9 +5,7 @@ package oraclesql import ( "context" "database/sql" - "encoding/json" "fmt" - "strings" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" @@ -33,6 +31,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { OracleDB() *sql.DB + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -103,99 +102,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para fmt.Printf("[%d]=%T ", i, p) } fmt.Printf("\n") - - rows, err := source.OracleDB().QueryContext(ctx, newStatement, 
sliceParams...) - if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - defer rows.Close() - - cols, _ := rows.Columns() - - // Get Column types - colTypes, err := rows.ColumnTypes() - if err != nil { - return nil, fmt.Errorf("unable to get column types: %w", err) - } - - var out []any - for rows.Next() { - values := make([]any, len(cols)) - for i, colType := range colTypes { - switch strings.ToUpper(colType.DatabaseTypeName()) { - case "NUMBER", "FLOAT", "BINARY_FLOAT", "BINARY_DOUBLE": - if _, scale, ok := colType.DecimalSize(); ok && scale == 0 { - // Scale is 0, treat it as an integer. - values[i] = new(sql.NullInt64) - } else { - // Scale is non-zero or unknown, treat - // it as a float. - values[i] = new(sql.NullFloat64) - } - case "DATE", "TIMESTAMP", "TIMESTAMP WITH TIME ZONE", "TIMESTAMP WITH LOCAL TIME ZONE": - values[i] = new(sql.NullTime) - case "JSON": - values[i] = new(sql.RawBytes) - default: - values[i] = new(sql.NullString) - } - } - - if err := rows.Scan(values...); err != nil { - return nil, fmt.Errorf("unable to scan row: %w", err) - } - - vMap := make(map[string]any) - for i, col := range cols { - receiver := values[i] - - switch v := receiver.(type) { - case *sql.NullInt64: - if v.Valid { - vMap[col] = v.Int64 - } else { - vMap[col] = nil - } - case *sql.NullFloat64: - if v.Valid { - vMap[col] = v.Float64 - } else { - vMap[col] = nil - } - case *sql.NullString: - if v.Valid { - vMap[col] = v.String - } else { - vMap[col] = nil - } - case *sql.NullTime: - if v.Valid { - vMap[col] = v.Time - } else { - vMap[col] = nil - } - case *sql.RawBytes: - if *v != nil { - var unmarshaledData any - if err := json.Unmarshal(*v, &unmarshaledData); err != nil { - return nil, fmt.Errorf("unable to unmarshal json data for column %s", col) - } - vMap[col] = unmarshaledData - } else { - vMap[col] = nil - } - default: - return nil, fmt.Errorf("unexpected receiver type: %T", v) - } - } - out = append(out, vMap) - } - - if err := 
rows.Err(); err != nil { - return nil, fmt.Errorf("errors encountered during query execution or row processing: %w", err) - } - - return out, nil + return source.RunSQL(ctx, newStatement, sliceParams) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgresdatabaseoverview/postgresdatabaseoverview.go b/internal/tools/postgres/postgresdatabaseoverview/postgresdatabaseoverview.go index 4e8a0a29ce..6668ab0795 100644 --- a/internal/tools/postgres/postgresdatabaseoverview/postgresdatabaseoverview.go +++ b/internal/tools/postgres/postgresdatabaseoverview/postgresdatabaseoverview.go @@ -56,7 +56,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T } type compatibleSource interface { - PostgresPool() *pgxpool.Pool + PostgresPool() *pgxpool.Pool // keep this so that sources are postgres compatible + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -121,34 +122,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("unable to extract standard params %w", err) } sliceParams := newParams.AsSlice() - - results, err := source.PostgresPool().Query(ctx, databaseOverviewStatement, sliceParams...) 
- if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - defer results.Close() - - fields := results.FieldDescriptions() - var out []map[string]any - - for results.Next() { - values, err := results.Values() - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - rowMap := make(map[string]any) - for i, field := range fields { - rowMap[string(field.Name)] = values[i] - } - out = append(out, rowMap) - } - - // this will catch actual query execution errors - if err := results.Err(); err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - - return out, nil + return source.RunSQL(ctx, databaseOverviewStatement, sliceParams) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgresexecutesql/postgresexecutesql.go b/internal/tools/postgres/postgresexecutesql/postgresexecutesql.go index 73afd2a6ee..21f7baf6af 100644 --- a/internal/tools/postgres/postgresexecutesql/postgresexecutesql.go +++ b/internal/tools/postgres/postgresexecutesql/postgresexecutesql.go @@ -22,7 +22,6 @@ import ( "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util" - "github.com/googleapis/genai-toolbox/internal/util/orderedmap" "github.com/googleapis/genai-toolbox/internal/util/parameters" "github.com/jackc/pgx/v5/pgxpool" ) @@ -45,6 +44,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { PostgresPool() *pgxpool.Pool + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -106,32 +106,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sql)) - results, err := source.PostgresPool().Query(ctx, sql) - if err != nil { - 
return nil, fmt.Errorf("unable to execute query: %w", err) - } - defer results.Close() - - fields := results.FieldDescriptions() - - var out []any - for results.Next() { - v, err := results.Values() - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - row := orderedmap.Row{} - for i, f := range fields { - row.Add(f.Name, v[i]) - } - out = append(out, row) - } - - if err := results.Err(); err != nil { - return err.Error(), fmt.Errorf("unable to execute query: %w", err) - } - - return out, nil + return source.RunSQL(ctx, sql, nil) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgresgetcolumncardinality/postgresgetcolumncardinality.go b/internal/tools/postgres/postgresgetcolumncardinality/postgresgetcolumncardinality.go index f96654fbc6..8ace1b9d88 100644 --- a/internal/tools/postgres/postgresgetcolumncardinality/postgresgetcolumncardinality.go +++ b/internal/tools/postgres/postgresgetcolumncardinality/postgresgetcolumncardinality.go @@ -62,6 +62,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { PostgresPool() *pgxpool.Pool + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -133,33 +134,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("unable to extract standard params %w", err) } sliceParams := newParams.AsSlice() - - results, err := source.PostgresPool().Query(ctx, getColumnCardinality, sliceParams...) 
- if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - defer results.Close() - - fields := results.FieldDescriptions() - var out []map[string]any - - for results.Next() { - values, err := results.Values() - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - rowMap := make(map[string]any) - for i, field := range fields { - rowMap[string(field.Name)] = values[i] - } - out = append(out, rowMap) - } - - if err := results.Err(); err != nil { - return err.Error(), fmt.Errorf("unable to execute query: %w", err) - } - - return out, nil + return source.RunSQL(ctx, getColumnCardinality, sliceParams) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslistactivequeries/postgreslistactivequeries.go b/internal/tools/postgres/postgreslistactivequeries/postgreslistactivequeries.go index 6ad5bff569..97f27446d1 100644 --- a/internal/tools/postgres/postgreslistactivequeries/postgreslistactivequeries.go +++ b/internal/tools/postgres/postgreslistactivequeries/postgreslistactivequeries.go @@ -66,6 +66,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { PostgresPool() *pgxpool.Pool + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -130,33 +131,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } sliceParams := newParams.AsSlice() - results, err := source.PostgresPool().Query(ctx, listActiveQueriesStatement, sliceParams...) 
- if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - defer results.Close() - - fields := results.FieldDescriptions() - var out []map[string]any - - for results.Next() { - values, err := results.Values() - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - rowMap := make(map[string]any) - for i, field := range fields { - rowMap[string(field.Name)] = values[i] - } - out = append(out, rowMap) - } - - // this will catch actual query execution errors - if err := results.Err(); err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - - return out, nil + return source.RunSQL(ctx, listActiveQueriesStatement, sliceParams) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslistavailableextensions/postgreslistavailableextensions.go b/internal/tools/postgres/postgreslistavailableextensions/postgreslistavailableextensions.go index 1440509cbb..909b2d3542 100644 --- a/internal/tools/postgres/postgreslistavailableextensions/postgreslistavailableextensions.go +++ b/internal/tools/postgres/postgreslistavailableextensions/postgreslistavailableextensions.go @@ -53,6 +53,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { PostgresPool() *pgxpool.Pool + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -101,33 +102,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if err != nil { return nil, err } - - results, err := source.PostgresPool().Query(ctx, listAvailableExtensionsQuery) - if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - - fields := results.FieldDescriptions() - - var out []any - for results.Next() { - v, err := results.Values() - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - vMap := 
make(map[string]any) - for i, f := range fields { - vMap[f.Name] = v[i] - } - out = append(out, vMap) - } - - // this will catch actual query execution errors - if err := results.Err(); err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - - return out, nil + return source.RunSQL(ctx, listAvailableExtensionsQuery, nil) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslistdatabasestats/postgreslistdatabasestats.go b/internal/tools/postgres/postgreslistdatabasestats/postgreslistdatabasestats.go index 27cc16c1ed..393691c049 100644 --- a/internal/tools/postgres/postgreslistdatabasestats/postgreslistdatabasestats.go +++ b/internal/tools/postgres/postgreslistdatabasestats/postgreslistdatabasestats.go @@ -110,6 +110,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { PostgresPool() *pgxpool.Pool + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -199,33 +200,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } sliceParams := newParams.AsSlice() - results, err := source.PostgresPool().Query(ctx, listDatabaseStats, sliceParams...) 
- if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - defer results.Close() - - fields := results.FieldDescriptions() - var out []map[string]any - - for results.Next() { - values, err := results.Values() - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - rowMap := make(map[string]any) - for i, field := range fields { - rowMap[string(field.Name)] = values[i] - } - out = append(out, rowMap) - } - - // this will catch actual query execution errors - if err := results.Err(); err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - - return out, nil + return source.RunSQL(ctx, listDatabaseStats, sliceParams) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslistindexes/postgreslistindexes.go b/internal/tools/postgres/postgreslistindexes/postgreslistindexes.go index 0f85a0e46c..6d9464f4bc 100644 --- a/internal/tools/postgres/postgreslistindexes/postgreslistindexes.go +++ b/internal/tools/postgres/postgreslistindexes/postgreslistindexes.go @@ -89,6 +89,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { PostgresPool() *pgxpool.Pool + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -161,33 +162,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } sliceParams := newParams.AsSlice() - results, err := source.PostgresPool().Query(ctx, listIndexesStatement, sliceParams...) 
- if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - defer results.Close() - - fields := results.FieldDescriptions() - var out []map[string]any - - for results.Next() { - values, err := results.Values() - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - rowMap := make(map[string]any) - for i, field := range fields { - rowMap[string(field.Name)] = values[i] - } - out = append(out, rowMap) - } - - // this will catch actual query execution errors - if err := results.Err(); err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - - return out, nil + return source.RunSQL(ctx, listIndexesStatement, sliceParams) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslistinstalledextensions/postgreslistinstalledextensions.go b/internal/tools/postgres/postgreslistinstalledextensions/postgreslistinstalledextensions.go index effa306f46..9894e2ecd1 100644 --- a/internal/tools/postgres/postgreslistinstalledextensions/postgreslistinstalledextensions.go +++ b/internal/tools/postgres/postgreslistinstalledextensions/postgreslistinstalledextensions.go @@ -64,6 +64,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { PostgresPool() *pgxpool.Pool + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -112,33 +113,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if err != nil { return nil, err } - - results, err := source.PostgresPool().Query(ctx, listAvailableExtensionsQuery) - if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - - fields := results.FieldDescriptions() - - var out []any - for results.Next() { - v, err := results.Values() - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - vMap := 
make(map[string]any) - for i, f := range fields { - vMap[f.Name] = v[i] - } - out = append(out, vMap) - } - - // this will catch actual query execution errors - if err := results.Err(); err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - - return out, nil + return source.RunSQL(ctx, listAvailableExtensionsQuery, nil) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslistlocks/postgreslistlocks.go b/internal/tools/postgres/postgreslistlocks/postgreslistlocks.go index 881962e2be..6105801533 100644 --- a/internal/tools/postgres/postgreslistlocks/postgreslistlocks.go +++ b/internal/tools/postgres/postgreslistlocks/postgreslistlocks.go @@ -64,6 +64,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { PostgresPool() *pgxpool.Pool + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -132,28 +133,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } sliceParams := newParams.AsSlice() - results, err := source.PostgresPool().Query(ctx, listLocks, sliceParams...) 
- if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - defer results.Close() - - fields := results.FieldDescriptions() - var out []map[string]any - - for results.Next() { - values, err := results.Values() - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - rowMap := make(map[string]any) - for i, field := range fields { - rowMap[string(field.Name)] = values[i] - } - out = append(out, rowMap) - } - - return out, nil + return source.RunSQL(ctx, listLocks, sliceParams) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslistpgsettings/postgreslistpgsettings.go b/internal/tools/postgres/postgreslistpgsettings/postgreslistpgsettings.go index 05fccc3d6e..5a4dd17ba1 100644 --- a/internal/tools/postgres/postgreslistpgsettings/postgreslistpgsettings.go +++ b/internal/tools/postgres/postgreslistpgsettings/postgreslistpgsettings.go @@ -62,6 +62,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { PostgresPool() *pgxpool.Pool + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -125,34 +126,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("unable to extract standard params %w", err) } sliceParams := newParams.AsSlice() - - results, err := source.PostgresPool().Query(ctx, listPgSettingsStatement, sliceParams...) 
- if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - defer results.Close() - - fields := results.FieldDescriptions() - var out []map[string]any - - for results.Next() { - values, err := results.Values() - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - rowMap := make(map[string]any) - for i, field := range fields { - rowMap[string(field.Name)] = values[i] - } - out = append(out, rowMap) - } - - // this will catch actual query execution errors - if err := results.Err(); err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - - return out, nil + return source.RunSQL(ctx, listPgSettingsStatement, sliceParams) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslistpublicationtables/postgreslistpublicationtables.go b/internal/tools/postgres/postgreslistpublicationtables/postgreslistpublicationtables.go index 9b1d48fdea..c0b154d4c1 100644 --- a/internal/tools/postgres/postgreslistpublicationtables/postgreslistpublicationtables.go +++ b/internal/tools/postgres/postgreslistpublicationtables/postgreslistpublicationtables.go @@ -73,6 +73,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { PostgresPool() *pgxpool.Pool + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -139,33 +140,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("unable to extract standard params %w", err) } sliceParams := newParams.AsSlice() - - results, err := source.PostgresPool().Query(ctx, listPublicationTablesStatement, sliceParams...) 
- if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - defer results.Close() - - fields := results.FieldDescriptions() - var out []map[string]any - - for results.Next() { - values, err := results.Values() - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - rowMap := make(map[string]any) - for i, field := range fields { - rowMap[string(field.Name)] = values[i] - } - out = append(out, rowMap) - } - - if err := results.Err(); err != nil { - return err.Error(), fmt.Errorf("unable to execute query: %w", err) - } - - return out, nil + return source.RunSQL(ctx, listPublicationTablesStatement, sliceParams) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslistquerystats/postgreslistquerystats.go b/internal/tools/postgres/postgreslistquerystats/postgreslistquerystats.go index e2a26e496b..2a4f808779 100644 --- a/internal/tools/postgres/postgreslistquerystats/postgreslistquerystats.go +++ b/internal/tools/postgres/postgreslistquerystats/postgreslistquerystats.go @@ -63,6 +63,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { PostgresPool() *pgxpool.Pool + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -132,33 +133,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("unable to extract standard params %w", err) } sliceParams := newParams.AsSlice() - - results, err := source.PostgresPool().Query(ctx, listQueryStats, sliceParams...) 
- if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - defer results.Close() - - fields := results.FieldDescriptions() - var out []map[string]any - - for results.Next() { - values, err := results.Values() - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - rowMap := make(map[string]any) - for i, field := range fields { - rowMap[string(field.Name)] = values[i] - } - out = append(out, rowMap) - } - - if err := results.Err(); err != nil { - return err.Error(), fmt.Errorf("unable to execute query: %w", err) - } - - return out, nil + return source.RunSQL(ctx, listQueryStats, sliceParams) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslistroles/postgreslistroles.go b/internal/tools/postgres/postgreslistroles/postgreslistroles.go index 160aebb31a..975f073199 100644 --- a/internal/tools/postgres/postgreslistroles/postgreslistroles.go +++ b/internal/tools/postgres/postgreslistroles/postgreslistroles.go @@ -85,6 +85,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { PostgresPool() *pgxpool.Pool + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -154,34 +155,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("unable to extract standard params %w", err) } sliceParams := newParams.AsSlice() - - results, err := source.PostgresPool().Query(ctx, listRolesStatement, sliceParams...) 
- if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - defer results.Close() - - fields := results.FieldDescriptions() - var out []map[string]any - - for results.Next() { - values, err := results.Values() - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - rowMap := make(map[string]any) - for i, field := range fields { - rowMap[string(field.Name)] = values[i] - } - out = append(out, rowMap) - } - - // this will catch actual query execution errors - if err := results.Err(); err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - - return out, nil + return source.RunSQL(ctx, listRolesStatement, sliceParams) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslistschemas/postgreslistschemas.go b/internal/tools/postgres/postgreslistschemas/postgreslistschemas.go index 729a4af1b4..b40e763bb9 100644 --- a/internal/tools/postgres/postgreslistschemas/postgreslistschemas.go +++ b/internal/tools/postgres/postgreslistschemas/postgreslistschemas.go @@ -97,6 +97,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { PostgresPool() *pgxpool.Pool + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -162,34 +163,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("unable to extract standard params %w", err) } sliceParams := newParams.AsSlice() - - results, err := source.PostgresPool().Query(ctx, listSchemasStatement, sliceParams...) 
- if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - defer results.Close() - - fields := results.FieldDescriptions() - var out []map[string]any - - for results.Next() { - values, err := results.Values() - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - rowMap := make(map[string]any) - for i, field := range fields { - rowMap[string(field.Name)] = values[i] - } - out = append(out, rowMap) - } - - // this will catch actual query execution errors - if err := results.Err(); err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - - return out, nil + return source.RunSQL(ctx, listSchemasStatement, sliceParams) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslistsequences/postgreslistsequences.go b/internal/tools/postgres/postgreslistsequences/postgreslistsequences.go index a8877ab6f7..bfdf53d143 100644 --- a/internal/tools/postgres/postgreslistsequences/postgreslistsequences.go +++ b/internal/tools/postgres/postgreslistsequences/postgreslistsequences.go @@ -63,6 +63,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { PostgresPool() *pgxpool.Pool + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -133,33 +134,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } sliceParams := newParams.AsSlice() - results, err := source.PostgresPool().Query(ctx, listSequencesStatement, sliceParams...) 
- if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - defer results.Close() - - fields := results.FieldDescriptions() - var out []map[string]any - - for results.Next() { - values, err := results.Values() - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - rowMap := make(map[string]any) - for i, field := range fields { - rowMap[string(field.Name)] = values[i] - } - out = append(out, rowMap) - } - - // this will catch actual query execution errors - if err := results.Err(); err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - - return out, nil + return source.RunSQL(ctx, listSequencesStatement, sliceParams) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslisttables/postgreslisttables.go b/internal/tools/postgres/postgreslisttables/postgreslisttables.go index 264983edb6..b5d7bb7776 100644 --- a/internal/tools/postgres/postgreslisttables/postgreslisttables.go +++ b/internal/tools/postgres/postgreslisttables/postgreslisttables.go @@ -121,6 +121,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { PostgresPool() *pgxpool.Pool + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -182,33 +183,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if outputFormat != "simple" && outputFormat != "detailed" { return nil, fmt.Errorf("invalid value for output_format: must be 'simple' or 'detailed', but got %q", outputFormat) } - - results, err := source.PostgresPool().Query(ctx, listTablesStatement, tableNames, outputFormat) + resp, err := source.RunSQL(ctx, listTablesStatement, []any{tableNames, outputFormat}) if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) + return nil, err } - defer results.Close() - - fields := 
results.FieldDescriptions()
-	out := []map[string]any{}
-
-	for results.Next() {
-		values, err := results.Values()
-		if err != nil {
-			return nil, fmt.Errorf("unable to parse row: %w", err)
-		}
-		rowMap := make(map[string]any)
-		for i, field := range fields {
-			rowMap[string(field.Name)] = values[i]
-		}
-		out = append(out, rowMap)
+	resSlice, ok := resp.([]any)
+	if !ok || len(resSlice) == 0 {
+		return []any{}, nil
 	}
-
-	if err := results.Err(); err != nil {
-		return nil, fmt.Errorf("error reading query results: %w", err)
-	}
-
-	return out, nil
+	return resp, nil
 }
 
 func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
diff --git a/internal/tools/postgres/postgreslisttablespaces/postgreslisttablespaces.go b/internal/tools/postgres/postgreslisttablespaces/postgreslisttablespaces.go
index 8e2d0e700d..3271a76bdc 100644
--- a/internal/tools/postgres/postgreslisttablespaces/postgreslisttablespaces.go
+++ b/internal/tools/postgres/postgreslisttablespaces/postgreslisttablespaces.go
@@ -69,6 +69,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
 
 type compatibleSource interface {
 	PostgresPool() *pgxpool.Pool
+	RunSQL(context.Context, string, []any) (any, error)
 }
 
 type Config struct {
@@ -141,32 +142,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 		return nil, fmt.Errorf("invalid 'limit' parameter; expected an integer")
 	}
 
-	results, err := source.PostgresPool().Query(ctx, listTableSpacesStatement, tablespaceName, limit)
-	if err != nil {
-		return nil, fmt.Errorf("unable to execute query: %w", err)
-	}
-	defer results.Close()
-
-	fields := results.FieldDescriptions()
-	var out []map[string]any
-
-	for results.Next() {
-		values, err := results.Values()
-		if err != nil {
-			return nil, fmt.Errorf("unable to parse row: %w", err)
-		}
-		rowMap := make(map[string]any)
-		for i, field := range fields {
-			rowMap[string(field.Name)] = values[i]
-		}
-		out = 
append(out, rowMap) - } - // this will catch actual query execution errors - if err := results.Err(); err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - - return out, nil + return source.RunSQL(ctx, listTableSpacesStatement, []any{tablespaceName, limit}) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslisttablestats/postgreslisttablestats.go b/internal/tools/postgres/postgreslisttablestats/postgreslisttablestats.go index 69a953e654..643775319c 100644 --- a/internal/tools/postgres/postgreslisttablestats/postgreslisttablestats.go +++ b/internal/tools/postgres/postgreslisttablestats/postgreslisttablestats.go @@ -90,6 +90,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { PostgresPool() *pgxpool.Pool + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -172,28 +173,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } sliceParams := newParams.AsSlice() - results, err := source.PostgresPool().Query(ctx, listTableStats, sliceParams...) 
- if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - defer results.Close() - - fields := results.FieldDescriptions() - var out []map[string]any - - for results.Next() { - values, err := results.Values() - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - rowMap := make(map[string]any) - for i, field := range fields { - rowMap[string(field.Name)] = values[i] - } - out = append(out, rowMap) - } - - return out, nil + return source.RunSQL(ctx, listTableStats, sliceParams) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslisttriggers/postgreslisttriggers.go b/internal/tools/postgres/postgreslisttriggers/postgreslisttriggers.go index 8fc4944f73..9a14b196a3 100644 --- a/internal/tools/postgres/postgreslisttriggers/postgreslisttriggers.go +++ b/internal/tools/postgres/postgreslisttriggers/postgreslisttriggers.go @@ -89,6 +89,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { PostgresPool() *pgxpool.Pool + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -159,34 +160,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("unable to extract standard params %w", err) } sliceParams := newParams.AsSlice() - - results, err := source.PostgresPool().Query(ctx, listTriggersStatement, sliceParams...) 
- if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - defer results.Close() - - fields := results.FieldDescriptions() - var out []map[string]any - - for results.Next() { - values, err := results.Values() - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - rowMap := make(map[string]any) - for i, field := range fields { - rowMap[string(field.Name)] = values[i] - } - out = append(out, rowMap) - } - - // this will catch actual query execution errors - if err := results.Err(); err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - - return out, nil + return source.RunSQL(ctx, listTriggersStatement, sliceParams) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslistviews/postgreslistviews.go b/internal/tools/postgres/postgreslistviews/postgreslistviews.go index d0aa2438d1..53f5b8fcbf 100644 --- a/internal/tools/postgres/postgreslistviews/postgreslistviews.go +++ b/internal/tools/postgres/postgreslistviews/postgreslistviews.go @@ -64,6 +64,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { PostgresPool() *pgxpool.Pool + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -129,34 +130,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("unable to extract standard params %w", err) } sliceParams := newParams.AsSlice() - - results, err := source.PostgresPool().Query(ctx, listViewsStatement, sliceParams...) 
- if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - defer results.Close() - - fields := results.FieldDescriptions() - var out []map[string]any - - for results.Next() { - values, err := results.Values() - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - rowMap := make(map[string]any) - for i, field := range fields { - rowMap[string(field.Name)] = values[i] - } - out = append(out, rowMap) - } - - // this will catch actual query execution errors - if err := results.Err(); err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - - return out, nil + return source.RunSQL(ctx, listViewsStatement, sliceParams) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgreslongrunningtransactions/postgreslongrunningtransactions.go b/internal/tools/postgres/postgreslongrunningtransactions/postgreslongrunningtransactions.go index 1b2434679d..ad2e3869a1 100644 --- a/internal/tools/postgres/postgreslongrunningtransactions/postgreslongrunningtransactions.go +++ b/internal/tools/postgres/postgreslongrunningtransactions/postgreslongrunningtransactions.go @@ -71,6 +71,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { PostgresPool() *pgxpool.Pool + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -141,29 +142,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("unable to extract standard params %w", err) } sliceParams := newParams.AsSlice() - - results, err := source.PostgresPool().Query(ctx, longRunningTransactions, sliceParams...) 
- if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - defer results.Close() - - fields := results.FieldDescriptions() - var out []map[string]any - - for results.Next() { - values, err := results.Values() - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - rowMap := make(map[string]any) - for i, field := range fields { - rowMap[string(field.Name)] = values[i] - } - out = append(out, rowMap) - } - - return out, nil + return source.RunSQL(ctx, longRunningTransactions, sliceParams) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgresreplicationstats/postgresreplicationstats.go b/internal/tools/postgres/postgresreplicationstats/postgresreplicationstats.go index 4280f1a0a3..d12c805d17 100644 --- a/internal/tools/postgres/postgresreplicationstats/postgresreplicationstats.go +++ b/internal/tools/postgres/postgresreplicationstats/postgresreplicationstats.go @@ -61,6 +61,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { PostgresPool() *pgxpool.Pool + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -128,29 +129,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("unable to extract standard params %w", err) } sliceParams := newParams.AsSlice() - - results, err := source.PostgresPool().Query(ctx, replicationStats, sliceParams...) 
- if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - defer results.Close() - - fields := results.FieldDescriptions() - var out []map[string]any - - for results.Next() { - values, err := results.Values() - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - rowMap := make(map[string]any) - for i, field := range fields { - rowMap[string(field.Name)] = values[i] - } - out = append(out, rowMap) - } - - return out, nil + return source.RunSQL(ctx, replicationStats, sliceParams) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/postgres/postgressql/postgressql.go b/internal/tools/postgres/postgressql/postgressql.go index 1de22a5a82..57a4d81c54 100644 --- a/internal/tools/postgres/postgressql/postgressql.go +++ b/internal/tools/postgres/postgressql/postgressql.go @@ -43,6 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { PostgresPool() *pgxpool.Pool + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -108,32 +109,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("unable to extract standard params %w", err) } sliceParams := newParams.AsSlice() - results, err := source.PostgresPool().Query(ctx, newStatement, sliceParams...) 
- if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - - fields := results.FieldDescriptions() - - var out []any - for results.Next() { - v, err := results.Values() - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - vMap := make(map[string]any) - for i, f := range fields { - vMap[f.Name] = v[i] - } - out = append(out, vMap) - } - - // this will catch actual query execution errors - if err := results.Err(); err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - - return out, nil + return source.RunSQL(ctx, newStatement, sliceParams) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/redis/redis.go b/internal/tools/redis/redis.go index 6995163a6a..c9beba3bc7 100644 --- a/internal/tools/redis/redis.go +++ b/internal/tools/redis/redis.go @@ -22,8 +22,6 @@ import ( redissrc "github.com/googleapis/genai-toolbox/internal/sources/redis" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" - jsoniter "github.com/json-iterator/go" - "github.com/redis/go-redis/v9" ) const kind string = "redis" @@ -44,6 +42,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { RedisClient() redissrc.RedisClient + RunCommand(context.Context, [][]any) (any, error) } type Config struct { @@ -94,44 +93,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if err != nil { return nil, fmt.Errorf("error replacing commands' parameters: %s", err) } - - // Execute commands - responses := make([]*redis.Cmd, len(cmds)) - for i, cmd := range cmds { - responses[i] = source.RedisClient().Do(ctx, cmd...) 
- } - // Parse responses - out := make([]any, len(t.Commands)) - for i, resp := range responses { - if err := resp.Err(); err != nil { - // Add error from each command to `errSum` - errString := fmt.Sprintf("error from executing command at index %d: %s", i, err) - out[i] = errString - continue - } - val, err := resp.Result() - if err != nil { - return nil, fmt.Errorf("error getting result: %s", err) - } - // If result is a map, convert map[any]any to map[string]any - // Because the Go's built-in json/encoding marshalling doesn't support - // map[any]any as an input - var strMap map[string]any - var json = jsoniter.ConfigCompatibleWithStandardLibrary - mapStr, err := json.Marshal(val) - if err != nil { - return nil, fmt.Errorf("error marshalling result: %s", err) - } - err = json.Unmarshal(mapStr, &strMap) - if err != nil { - // result is not a map - out[i] = val - continue - } - out[i] = strMap - } - - return out, nil + return source.RunCommand(ctx, cmds) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/singlestore/singlestoreexecutesql/singlestoreexecutesql.go b/internal/tools/singlestore/singlestoreexecutesql/singlestoreexecutesql.go index 7ab352b195..2b9c484c26 100644 --- a/internal/tools/singlestore/singlestoreexecutesql/singlestoreexecutesql.go +++ b/internal/tools/singlestore/singlestoreexecutesql/singlestoreexecutesql.go @@ -22,7 +22,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" - "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon" "github.com/googleapis/genai-toolbox/internal/util" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -45,6 +44,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { SingleStorePool() *sql.DB + RunSQL(context.Context, string, 
[]any) (any, error) } // Config represents the configuration for the singlestore-execute-sql tool. @@ -115,57 +115,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("error getting logger: %s", err) } logger.DebugContext(ctx, "executing `%s` tool query: %s", kind, sql) - - results, err := source.SingleStorePool().QueryContext(ctx, sql) - if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - defer results.Close() - - cols, err := results.Columns() - if err != nil { - return nil, fmt.Errorf("unable to retrieve rows column name: %w", err) - } - - // create an array of values for each column, which can be re-used to scan each row - rawValues := make([]any, len(cols)) - values := make([]any, len(cols)) - for i := range rawValues { - values[i] = &rawValues[i] - } - - colTypes, err := results.ColumnTypes() - if err != nil { - return nil, fmt.Errorf("unable to get column types: %w", err) - } - - var out []any - for results.Next() { - err := results.Scan(values...) 
- if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - vMap := make(map[string]any) - for i, name := range cols { - val := rawValues[i] - if val == nil { - vMap[name] = nil - continue - } - - vMap[name], err = mysqlcommon.ConvertToType(colTypes[i], val) - if err != nil { - return nil, fmt.Errorf("errors encountered when converting values: %w", err) - } - } - out = append(out, vMap) - } - - if err := results.Err(); err != nil { - return nil, fmt.Errorf("errors encountered during row iteration: %w", err) - } - - return out, nil + return source.RunSQL(ctx, sql, nil) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/singlestore/singlestoresql/singlestoresql.go b/internal/tools/singlestore/singlestoresql/singlestoresql.go index 55adfe2dbf..1c6f6a7e15 100644 --- a/internal/tools/singlestore/singlestoresql/singlestoresql.go +++ b/internal/tools/singlestore/singlestoresql/singlestoresql.go @@ -22,7 +22,6 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" - "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -44,6 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { SingleStorePool() *sql.DB + RunSQL(context.Context, string, []any) (any, error) } // Config defines the configuration for a SingleStore SQL tool. @@ -143,56 +143,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } sliceParams := newParams.AsSlice() - results, err := source.SingleStorePool().QueryContext(ctx, newStatement, sliceParams...) 
- if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - - cols, err := results.Columns() - if err != nil { - return nil, fmt.Errorf("unable to retrieve rows column name: %w", err) - } - - // create an array of values for each column, which can be re-used to scan each row - rawValues := make([]any, len(cols)) - values := make([]any, len(cols)) - for i := range rawValues { - values[i] = &rawValues[i] - } - defer results.Close() - - colTypes, err := results.ColumnTypes() - if err != nil { - return nil, fmt.Errorf("unable to get column types: %w", err) - } - - var out []any - for results.Next() { - err := results.Scan(values...) - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - vMap := make(map[string]any) - for i, name := range cols { - val := rawValues[i] - if val == nil { - vMap[name] = nil - continue - } - - vMap[name], err = mysqlcommon.ConvertToType(colTypes[i], val) - if err != nil { - return nil, fmt.Errorf("errors encountered when converting values: %w", err) - } - } - out = append(out, vMap) - } - - if err := results.Err(); err != nil { - return nil, fmt.Errorf("errors encountered during row iteration: %w", err) - } - - return out, nil + return source.RunSQL(ctx, newStatement, sliceParams) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/spanner/spannerexecutesql/spannerexecutesql.go b/internal/tools/spanner/spannerexecutesql/spannerexecutesql.go index f0c4ce2460..68bf751348 100644 --- a/internal/tools/spanner/spannerexecutesql/spannerexecutesql.go +++ b/internal/tools/spanner/spannerexecutesql/spannerexecutesql.go @@ -23,9 +23,7 @@ import ( "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util" - "github.com/googleapis/genai-toolbox/internal/util/orderedmap" 
"github.com/googleapis/genai-toolbox/internal/util/parameters" - "google.golang.org/api/iterator" ) const kind string = "spanner-execute-sql" @@ -47,6 +45,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { SpannerClient() *spanner.Client DatabaseDialect() string + RunSQL(context.Context, bool, string, map[string]any) (any, error) } type Config struct { @@ -91,30 +90,6 @@ type Tool struct { mcpManifest tools.McpManifest } -// processRows iterates over the spanner.RowIterator and converts each row to a map[string]any. -func processRows(iter *spanner.RowIterator) ([]any, error) { - var out []any - defer iter.Stop() - - for { - row, err := iter.Next() - if err == iterator.Done { - break - } - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - - rowMap := orderedmap.Row{} - cols := row.ColumnNames() - for i, c := range cols { - rowMap.Add(c, row.ColumnValue(i)) - } - out = append(out, rowMap) - } - return out, nil -} - func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) if err != nil { @@ -133,31 +108,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("error getting logger: %s", err) } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sql)) - - var results []any - var opErr error - stmt := spanner.Statement{SQL: sql} - - if t.ReadOnly { - iter := source.SpannerClient().Single().Query(ctx, stmt) - results, opErr = processRows(iter) - } else { - _, opErr = source.SpannerClient().ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { - var err error - iter := txn.Query(ctx, stmt) - results, err = processRows(iter) - if err != nil { - return err - } - return nil 
- }) - } - - if opErr != nil { - return nil, fmt.Errorf("unable to execute query: %w", opErr) - } - - return results, nil + return source.RunSQL(ctx, t.ReadOnly, sql, nil) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/spanner/spannerlistgraphs/spannerlistgraphs.go b/internal/tools/spanner/spannerlistgraphs/spannerlistgraphs.go index b9e94408e2..ca5a7572bd 100644 --- a/internal/tools/spanner/spannerlistgraphs/spannerlistgraphs.go +++ b/internal/tools/spanner/spannerlistgraphs/spannerlistgraphs.go @@ -16,7 +16,6 @@ package spannerlistgraphs import ( "context" - "encoding/json" "fmt" "strings" @@ -25,7 +24,6 @@ import ( "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" - "google.golang.org/api/iterator" ) const kind string = "spanner-list-graphs" @@ -47,6 +45,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { SpannerClient() *spanner.Client DatabaseDialect() string + RunSQL(context.Context, bool, string, map[string]any) (any, error) } type Config struct { @@ -105,39 +104,6 @@ type Tool struct { mcpManifest tools.McpManifest } -// processRows iterates over the spanner.RowIterator and converts each row to a map[string]any. 
-func processRows(iter *spanner.RowIterator) ([]any, error) { - var out []any - defer iter.Stop() - - for { - row, err := iter.Next() - if err == iterator.Done { - break - } - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - - vMap := make(map[string]any) - cols := row.ColumnNames() - for i, c := range cols { - if c == "object_details" { - jsonString := row.ColumnValue(i).AsInterface().(string) - var details map[string]interface{} - if err := json.Unmarshal([]byte(jsonString), &details); err != nil { - return nil, fmt.Errorf("unable to unmarshal JSON: %w", err) - } - vMap[c] = details - } else { - vMap[c] = row.ColumnValue(i) - } - } - out = append(out, vMap) - } - return out, nil -} - func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) if err != nil { @@ -161,20 +127,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para "graph_names": graphNames, "output_format": outputFormat, } - - stmt := spanner.Statement{ - SQL: googleSQLStatement, - Params: stmtParams, - } - - // Execute the query (read-only) - iter := source.SpannerClient().Single().Query(ctx, stmt) - results, err := processRows(iter) - if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - - return results, nil + return source.RunSQL(ctx, true, googleSQLStatement, stmtParams) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/spanner/spannerlisttables/spannerlisttables.go b/internal/tools/spanner/spannerlisttables/spannerlisttables.go index bd41479fed..03230358f9 100644 --- a/internal/tools/spanner/spannerlisttables/spannerlisttables.go +++ b/internal/tools/spanner/spannerlisttables/spannerlisttables.go @@ -16,7 +16,6 @@ package 
spannerlisttables import ( "context" - "encoding/json" "fmt" "strings" @@ -25,7 +24,6 @@ import ( "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" - "google.golang.org/api/iterator" ) const kind string = "spanner-list-tables" @@ -47,6 +45,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { SpannerClient() *spanner.Client DatabaseDialect() string + RunSQL(context.Context, bool, string, map[string]any) (any, error) } type Config struct { @@ -105,41 +104,8 @@ type Tool struct { mcpManifest tools.McpManifest } -// processRows iterates over the spanner.RowIterator and converts each row to a map[string]any. -func processRows(iter *spanner.RowIterator) ([]any, error) { - out := []any{} - defer iter.Stop() - - for { - row, err := iter.Next() - if err == iterator.Done { - break - } - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - - vMap := make(map[string]any) - cols := row.ColumnNames() - for i, c := range cols { - if c == "object_details" { - jsonString := row.ColumnValue(i).AsInterface().(string) - var details map[string]interface{} - if err := json.Unmarshal([]byte(jsonString), &details); err != nil { - return nil, fmt.Errorf("unable to unmarshal JSON: %w", err) - } - vMap[c] = details - } else { - vMap[c] = row.ColumnValue(i) - } - } - out = append(out, vMap) - } - return out, nil -} - -func (t Tool) getStatement(source compatibleSource) string { - switch strings.ToLower(source.DatabaseDialect()) { +func getStatement(dialect string) string { + switch strings.ToLower(dialect) { case "postgresql": return postgresqlStatement case "googlesql": @@ -159,7 +125,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para paramsMap := params.AsMap() // Get the appropriate SQL statement based on dialect - statement := 
t.getStatement(source) + statement := getStatement(source.DatabaseDialect()) // Prepare parameters based on dialect var stmtParams map[string]interface{} @@ -177,7 +143,6 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para "p1": tableNames, "p2": outputFormat, } - case "googlesql": // GoogleSQL uses named parameters (@table_names, @output_format) stmtParams = map[string]interface{}{ @@ -188,19 +153,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("unsupported dialect: %s", source.DatabaseDialect()) } - stmt := spanner.Statement{ - SQL: statement, - Params: stmtParams, - } - - // Execute the query (read-only) - iter := source.SpannerClient().Single().Query(ctx, stmt) - results, err := processRows(iter) - if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - - return results, nil + return source.RunSQL(ctx, true, statement, stmtParams) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/spanner/spannersql/spannersql.go b/internal/tools/spanner/spannersql/spannersql.go index d1b7c1ab54..eea2d89667 100644 --- a/internal/tools/spanner/spannersql/spannersql.go +++ b/internal/tools/spanner/spannersql/spannersql.go @@ -24,7 +24,6 @@ import ( "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util/parameters" - "google.golang.org/api/iterator" ) const kind string = "spanner-sql" @@ -46,6 +45,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { SpannerClient() *spanner.Client DatabaseDialect() string + RunSQL(context.Context, bool, string, map[string]any) (any, error) } type Config struct { @@ -106,30 +106,6 @@ func getMapParams(params parameters.ParamValues, dialect string) (map[string]int } } -// 
processRows iterates over the spanner.RowIterator and converts each row to a map[string]any. -func processRows(iter *spanner.RowIterator) ([]any, error) { - var out []any - defer iter.Stop() - - for { - row, err := iter.Next() - if err == iterator.Done { - break - } - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - - vMap := make(map[string]any) - cols := row.ColumnNames() - for i, c := range cols { - vMap[c] = row.ColumnValue(i) - } - out = append(out, vMap) - } - return out, nil -} - func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind) if err != nil { @@ -174,33 +150,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if err != nil { return nil, fmt.Errorf("fail to get map params: %w", err) } - - var results []any - var opErr error - stmt := spanner.Statement{ - SQL: newStatement, - Params: mapParams, - } - - if t.ReadOnly { - iter := source.SpannerClient().Single().Query(ctx, stmt) - results, opErr = processRows(iter) - } else { - _, opErr = source.SpannerClient().ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { - iter := txn.Query(ctx, stmt) - results, err = processRows(iter) - if err != nil { - return err - } - return nil - }) - } - - if opErr != nil { - return nil, fmt.Errorf("unable to execute client: %w", opErr) - } - - return results, nil + return source.RunSQL(ctx, t.ReadOnly, newStatement, mapParams) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/sqlite/sqliteexecutesql/sqliteexecutesql.go b/internal/tools/sqlite/sqliteexecutesql/sqliteexecutesql.go index e2c03a224a..f8a7e78527 100644 --- a/internal/tools/sqlite/sqliteexecutesql/sqliteexecutesql.go +++ 
b/internal/tools/sqlite/sqliteexecutesql/sqliteexecutesql.go @@ -17,14 +17,12 @@ package sqliteexecutesql import ( "context" "database/sql" - "encoding/json" "fmt" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" "github.com/googleapis/genai-toolbox/internal/util" - "github.com/googleapis/genai-toolbox/internal/util/orderedmap" "github.com/googleapis/genai-toolbox/internal/util/parameters" ) @@ -46,6 +44,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { SQLiteDB() *sql.DB + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -108,65 +107,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("error getting logger: %s", err) } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sql)) - - results, err := source.SQLiteDB().QueryContext(ctx, sql) - if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - - cols, err := results.Columns() - if err != nil { - return nil, fmt.Errorf("unable to retrieve rows column name: %w", err) - } - - // The sqlite driver does not support ColumnTypes, so we can't get the - // underlying database type of the columns. We'll have to rely on the - // generic `any` type and then handle the JSON data separately. - - // create an array of values for each column, which can be re-used to scan each row - rawValues := make([]any, len(cols)) - values := make([]any, len(cols)) - for i := range rawValues { - values[i] = &rawValues[i] - } - defer results.Close() - - var out []any - for results.Next() { - err := results.Scan(values...) 
- if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - row := orderedmap.Row{} - for i, name := range cols { - val := rawValues[i] - if val == nil { - row.Add(name, nil) - continue - } - - // Handle JSON data - if jsonString, ok := val.(string); ok { - var unmarshaledData any - if json.Unmarshal([]byte(jsonString), &unmarshaledData) == nil { - row.Add(name, unmarshaledData) - continue - } - } - row.Add(name, val) - } - out = append(out, row) - } - - if err := results.Err(); err != nil { - return nil, fmt.Errorf("errors encountered during row iteration: %w", err) - } - - if len(out) == 0 { - return nil, nil - } - - return out, nil + return source.RunSQL(ctx, sql, nil) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/sqlite/sqlitesql/sqlitesql.go b/internal/tools/sqlite/sqlitesql/sqlitesql.go index e715252dc4..d61038a94b 100644 --- a/internal/tools/sqlite/sqlitesql/sqlitesql.go +++ b/internal/tools/sqlite/sqlitesql/sqlitesql.go @@ -17,7 +17,6 @@ package sqlitesql import ( "context" "database/sql" - "encoding/json" "fmt" yaml "github.com/goccy/go-yaml" @@ -44,6 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { SQLiteDB() *sql.DB + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -108,64 +108,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if err != nil { return nil, fmt.Errorf("unable to extract standard params %w", err) } - - // Execute the SQL query with parameters - rows, err := source.SQLiteDB().QueryContext(ctx, newStatement, newParams.AsSlice()...) 
- if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - defer rows.Close() - - // Get column names - cols, err := rows.Columns() - if err != nil { - return nil, fmt.Errorf("unable to get column names: %w", err) - } - - // The sqlite driver does not support ColumnTypes, so we can't get the - // underlying database type of the columns. We'll have to rely on the - // generic `any` type and then handle the JSON data separately. - rawValues := make([]any, len(cols)) - values := make([]any, len(cols)) - for i := range rawValues { - values[i] = &rawValues[i] - } - - // Prepare the result slice - var out []any - for rows.Next() { - if err := rows.Scan(values...); err != nil { - return nil, fmt.Errorf("unable to scan row: %w", err) - } - - // Create a map for this row - vMap := make(map[string]any) - for i, name := range cols { - val := rawValues[i] - // Handle nil values - if val == nil { - vMap[name] = nil - continue - } - // Handle JSON data - if jsonString, ok := val.(string); ok { - var unmarshaledData any - if json.Unmarshal([]byte(jsonString), &unmarshaledData) == nil { - vMap[name] = unmarshaledData - continue - } - } - // Store the value in the map - vMap[name] = val - } - out = append(out, vMap) - } - - if err = rows.Err(); err != nil { - return nil, fmt.Errorf("error iterating rows: %w", err) - } - - return out, nil + return source.RunSQL(ctx, newStatement, newParams.AsSlice()) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/tidb/tidbexecutesql/tidbexecutesql.go b/internal/tools/tidb/tidbexecutesql/tidbexecutesql.go index b452de841d..5c2bf22b49 100644 --- a/internal/tools/tidb/tidbexecutesql/tidbexecutesql.go +++ b/internal/tools/tidb/tidbexecutesql/tidbexecutesql.go @@ -44,6 +44,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { TiDBPool() *sql.DB + 
RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -105,61 +106,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("error getting logger: %s", err) } logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sql)) - - results, err := source.TiDBPool().QueryContext(ctx, sql) - if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - defer results.Close() - - cols, err := results.Columns() - if err != nil { - return nil, fmt.Errorf("unable to retrieve rows column name: %w", err) - } - - // create an array of values for each column, which can be re-used to scan each row - rawValues := make([]any, len(cols)) - values := make([]any, len(cols)) - for i := range rawValues { - values[i] = &rawValues[i] - } - - colTypes, err := results.ColumnTypes() - if err != nil { - return nil, fmt.Errorf("unable to get column types: %w", err) - } - - var out []any - for results.Next() { - err := results.Scan(values...) 
- if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - vMap := make(map[string]any) - for i, name := range cols { - val := rawValues[i] - if val == nil { - vMap[name] = nil - continue - } - - // mysql driver return []uint8 type for "TEXT", "VARCHAR", and "NVARCHAR" - // we'll need to cast it back to string - switch colTypes[i].DatabaseTypeName() { - case "TEXT", "VARCHAR", "NVARCHAR": - vMap[name] = string(val.([]byte)) - default: - vMap[name] = val - } - } - out = append(out, vMap) - } - - if err := results.Err(); err != nil { - return nil, fmt.Errorf("errors encountered during row iteration: %w", err) - } - - return out, nil + return source.RunSQL(ctx, sql, nil) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/tidb/tidbsql/tidbsql.go b/internal/tools/tidb/tidbsql/tidbsql.go index f35d0a61db..ab0968de67 100644 --- a/internal/tools/tidb/tidbsql/tidbsql.go +++ b/internal/tools/tidb/tidbsql/tidbsql.go @@ -17,7 +17,6 @@ package tidbsql import ( "context" "database/sql" - "encoding/json" "fmt" yaml "github.com/goccy/go-yaml" @@ -44,6 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { TiDBPool() *sql.DB + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -110,68 +110,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para } sliceParams := newParams.AsSlice() - results, err := source.TiDBPool().QueryContext(ctx, newStatement, sliceParams...) 
- if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - - cols, err := results.Columns() - if err != nil { - return nil, fmt.Errorf("unable to retrieve rows column name: %w", err) - } - - // create an array of values for each column, which can be re-used to scan each row - rawValues := make([]any, len(cols)) - values := make([]any, len(cols)) - for i := range rawValues { - values[i] = &rawValues[i] - } - defer results.Close() - - colTypes, err := results.ColumnTypes() - if err != nil { - return nil, fmt.Errorf("unable to get column types: %w", err) - } - - var out []any - for results.Next() { - err := results.Scan(values...) - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - vMap := make(map[string]any) - for i, name := range cols { - val := rawValues[i] - if val == nil { - vMap[name] = nil - continue - } - - // mysql driver return []uint8 type for "TEXT", "VARCHAR", and "NVARCHAR" - // we'll need to cast it back to string - switch colTypes[i].DatabaseTypeName() { - case "JSON": - // unmarshal JSON data before storing to prevent double marshaling - var unmarshaledData any - err := json.Unmarshal(val.([]byte), &unmarshaledData) - if err != nil { - return nil, fmt.Errorf("unable to unmarshal json data %s", val) - } - vMap[name] = unmarshaledData - case "TEXT", "VARCHAR", "NVARCHAR": - vMap[name] = string(val.([]byte)) - default: - vMap[name] = val - } - } - out = append(out, vMap) - } - - if err := results.Err(); err != nil { - return nil, fmt.Errorf("errors encountered during row iteration: %w", err) - } - - return out, nil + return source.RunSQL(ctx, newStatement, sliceParams) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/trino/trinoexecutesql/trinoexecutesql.go b/internal/tools/trino/trinoexecutesql/trinoexecutesql.go index f9f396bd03..6a477a5e31 100644 --- 
a/internal/tools/trino/trinoexecutesql/trinoexecutesql.go +++ b/internal/tools/trino/trinoexecutesql/trinoexecutesql.go @@ -43,6 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { TrinoDB() *sql.DB + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -97,54 +98,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if !ok { return nil, fmt.Errorf("unable to cast sql parameter: %v", sliceParams[0]) } - - results, err := source.TrinoDB().QueryContext(ctx, sql) - if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - defer results.Close() - - cols, err := results.Columns() - if err != nil { - return nil, fmt.Errorf("unable to retrieve column names: %w", err) - } - - // create an array of values for each column, which can be re-used to scan each row - rawValues := make([]any, len(cols)) - values := make([]any, len(cols)) - for i := range rawValues { - values[i] = &rawValues[i] - } - - var out []any - for results.Next() { - err := results.Scan(values...) 
- if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - vMap := make(map[string]any) - for i, name := range cols { - val := rawValues[i] - if val == nil { - vMap[name] = nil - continue - } - - // Convert byte arrays to strings for text fields - if b, ok := val.([]byte); ok { - vMap[name] = string(b) - } else { - vMap[name] = val - } - } - out = append(out, vMap) - } - - if err := results.Err(); err != nil { - return nil, fmt.Errorf("errors encountered during row iteration: %w", err) - } - - return out, nil + return source.RunSQL(ctx, sql, nil) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/trino/trinosql/trinosql.go b/internal/tools/trino/trinosql/trinosql.go index 7dd06d505c..24d9a9195b 100644 --- a/internal/tools/trino/trinosql/trinosql.go +++ b/internal/tools/trino/trinosql/trinosql.go @@ -43,6 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { TrinoDB() *sql.DB + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -107,53 +108,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("unable to extract standard params %w", err) } sliceParams := newParams.AsSlice() - results, err := source.TrinoDB().QueryContext(ctx, newStatement, sliceParams...) - if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - defer results.Close() - - cols, err := results.Columns() - if err != nil { - return nil, fmt.Errorf("unable to retrieve column names: %w", err) - } - - // create an array of values for each column, which can be re-used to scan each row - rawValues := make([]any, len(cols)) - values := make([]any, len(cols)) - for i := range rawValues { - values[i] = &rawValues[i] - } - - var out []any - for results.Next() { - err := results.Scan(values...) 
- if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - vMap := make(map[string]any) - for i, name := range cols { - val := rawValues[i] - if val == nil { - vMap[name] = nil - continue - } - - // Convert byte arrays to strings for text fields - if b, ok := val.([]byte); ok { - vMap[name] = string(b) - } else { - vMap[name] = val - } - } - out = append(out, vMap) - } - - if err := results.Err(); err != nil { - return nil, fmt.Errorf("errors encountered during row iteration: %w", err) - } - - return out, nil + return source.RunSQL(ctx, newStatement, sliceParams) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/internal/tools/valkey/valkey.go b/internal/tools/valkey/valkey.go index 8f9d90c264..354a26e813 100644 --- a/internal/tools/valkey/valkey.go +++ b/internal/tools/valkey/valkey.go @@ -42,6 +42,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { ValkeyClient() valkey.Client + RunCommand(context.Context, [][]string) (any, error) } type Config struct { @@ -93,38 +94,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para if err != nil { return nil, fmt.Errorf("error replacing commands' parameters: %s", err) } - - // Build commands - builtCmds := make(valkey.Commands, len(commands)) - - for i, cmd := range commands { - builtCmds[i] = source.ValkeyClient().B().Arbitrary(cmd...).Build() - } - - if len(builtCmds) == 0 { - return nil, fmt.Errorf("no valid commands were built to execute") - } - - // Execute commands - responses := source.ValkeyClient().DoMulti(ctx, builtCmds...) 
- - // Parse responses - out := make([]any, len(t.Commands)) - for i, resp := range responses { - if err := resp.Error(); err != nil { - // Add error from each command to `errSum` - out[i] = fmt.Sprintf("error from executing command at index %d: %s", i, err) - continue - } - val, err := resp.ToAny() - if err != nil { - out[i] = fmt.Sprintf("error parsing response: %s", err) - continue - } - out[i] = val - } - - return out, nil + return source.RunCommand(ctx, commands) } // replaceCommandsParams is a helper function to replace parameters in the commands diff --git a/internal/tools/yugabytedbsql/yugabytedbsql.go b/internal/tools/yugabytedbsql/yugabytedbsql.go index 3b774ac366..0055e106a6 100644 --- a/internal/tools/yugabytedbsql/yugabytedbsql.go +++ b/internal/tools/yugabytedbsql/yugabytedbsql.go @@ -43,6 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { YugabyteDBPool() *pgxpool.Pool + RunSQL(context.Context, string, []any) (any, error) } type Config struct { @@ -108,32 +109,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, fmt.Errorf("unable to extract standard params %w", err) } sliceParams := newParams.AsSlice() - results, err := source.YugabyteDBPool().Query(ctx, newStatement, sliceParams...) 
- if err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - - fields := results.FieldDescriptions() - - var out []any - for results.Next() { - v, err := results.Values() - if err != nil { - return nil, fmt.Errorf("unable to parse row: %w", err) - } - vMap := make(map[string]any) - for i, f := range fields { - vMap[f.Name] = v[i] - } - out = append(out, vMap) - } - - // this will catch actual query execution errors - if err := results.Err(); err != nil { - return nil, fmt.Errorf("unable to execute query: %w", err) - } - - return out, nil + return source.RunSQL(ctx, newStatement, sliceParams) } func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { diff --git a/tests/bigquery/bigquery_integration_test.go b/tests/bigquery/bigquery_integration_test.go index f79b8b7d31..de5126cd24 100644 --- a/tests/bigquery/bigquery_integration_test.go +++ b/tests/bigquery/bigquery_integration_test.go @@ -1701,7 +1701,7 @@ func runBigQueryDataTypeTests(t *testing.T) { api: "http://127.0.0.1:5000/api/tool/my-scalar-datatype-tool/invoke", requestHeader: map[string]string{}, requestBody: bytes.NewBuffer([]byte(`{"int_val": 123, "string_val": "hello", "float_val": 3.14, "bool_val": true}`)), - want: `[{"bool_val":true,"float_val":3.14,"id":1,"int_val":123,"string_val":"hello"}]`, + want: `[{"id":1,"int_val":123,"string_val":"hello","float_val":3.14,"bool_val":true}]`, isErr: false, }, { @@ -1716,7 +1716,7 @@ func runBigQueryDataTypeTests(t *testing.T) { api: "http://127.0.0.1:5000/api/tool/my-array-datatype-tool/invoke", requestHeader: map[string]string{}, requestBody: bytes.NewBuffer([]byte(`{"int_array": [123, 789], "string_array": ["hello", "test"], "float_array": [3.14, 100.1], "bool_array": [true]}`)), - want: `[{"bool_val":true,"float_val":3.14,"id":1,"int_val":123,"string_val":"hello"},{"bool_val":true,"float_val":100.1,"id":3,"int_val":789,"string_val":"test"}]`, + want: 
`[{"id":1,"int_val":123,"string_val":"hello","float_val":3.14,"bool_val":true},{"id":3,"int_val":789,"string_val":"test","float_val":100.1,"bool_val":true}]`, isErr: false, }, }