Compare commits

..

8 Commits

Author SHA1 Message Date
Wenxin Du
a30f4a02d3 Merge branch 'main' into dp 2025-10-01 15:46:34 -04:00
shuzhou-gc
0e04381ed7 docs(prebuilt): Update prebuilt tools document with newly added tools for MySQL and Cloud SQL for MySQL (#1580)
## Description

---
Update the prebuilt tools documentation with added tools for MySQL &
Cloud SQL for MySQL

## PR Checklist

---
> Thank you for opening a Pull Request! Before submitting your PR, there
are a
> few things you can do to make sure it goes smoothly:

- [x] Make sure you reviewed

[CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md)
- [x] Make sure to open an issue as a

[bug/issue](https://github.com/googleapis/genai-toolbox/issues/new/choose)
before writing your code! That way we can discuss the change, evaluate
  designs, and agree on the general idea
- [ ] Ensure the tests and linter pass
- [ ] Code coverage does not decrease (if any source code was changed)
- [ ] Appropriate docs were updated (if necessary)
- [ ] Make sure to add `!` if this involve a breaking change

🛠️ Fixes #1516
2025-10-01 19:29:33 +00:00
Dr. Strangelove
1afd9a95da docs(sources/looker): Fix typos in Looker tools (#1608) (#1610)
Fix descriptions in Looker tools docs

Co-authored-by: David Szajngarten <davidszajngarten@gmail.com>
2025-10-01 15:13:38 -04:00
Wenxin Du
4f2754b58b Merge branch 'main' into dp 2025-10-01 15:06:28 -04:00
duwenxin
17d1107a3c ci: Disable Dataplex integration test temp 2025-10-01 13:57:31 -04:00
Sri Varshitha
95efdc847f chore: Update list tables test cases and cleanup test database (#1600)
## Description

---
This change updates the list tables(`postgres`, `mysql` and `mssql`)
tests with test cases for listing all tables. The test schemas are
cleaned at the beginning of the test run to ensure deterministic output
for the list_tables tool.
## PR Checklist

---
> Thank you for opening a Pull Request! Before submitting your PR, there
are a
> few things you can do to make sure it goes smoothly:

- [x] Make sure you reviewed

[CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md)
- [ ] Make sure to open an issue as a

[bug/issue](https://github.com/googleapis/genai-toolbox/issues/new/choose)
before writing your code! That way we can discuss the change, evaluate
  designs, and agree on the general idea
- [x] Ensure the tests and linter pass
- [x] Code coverage does not decrease (if any source code was changed)
- [x] Appropriate docs were updated (if necessary)
- [x] Make sure to add `!` if this involve a breaking change

🛠️ Fixes #<issue_number_goes_here>
2025-10-01 09:49:24 +05:30
Mend Renovate
73a96b1b63 chore(deps): update module cloud.google.com/go/spanner to v1.86.0 (#1587)
This PR contains the following updates:

| Package | Change | Age | Confidence |
|---|---|---|---|
|
[cloud.google.com/go/spanner](https://redirect.github.com/googleapis/google-cloud-go)
| `v1.85.1` -> `v1.86.0` |
[![age](https://developer.mend.io/api/mc/badges/age/go/cloud.google.com%2fgo%2fspanner/v1.86.0?slim=true)](https://docs.renovatebot.com/merge-confidence/)
|
[![confidence](https://developer.mend.io/api/mc/badges/confidence/go/cloud.google.com%2fgo%2fspanner/v1.85.1/v1.86.0?slim=true)](https://docs.renovatebot.com/merge-confidence/)
|

---

### Configuration

📅 **Schedule**: Branch creation - At any time (no schedule defined),
Automerge - At any time (no schedule defined).

🚦 **Automerge**: Disabled by config. Please merge this manually once you
are satisfied.

♻ **Rebasing**: Whenever PR becomes conflicted, or you tick the
rebase/retry checkbox.

🔕 **Ignore**: Close this PR and you won't be reminded about this update
again.

---

- [ ] <!-- rebase-check -->If you want to rebase/retry this PR, check
this box

---

This PR was generated by [Mend Renovate](https://mend.io/renovate/).
View the [repository job
log](https://developer.mend.io/github/googleapis/genai-toolbox).

<!--renovate-debug:eyJjcmVhdGVkSW5WZXIiOiI0MS4xMzEuOSIsInVwZGF0ZWRJblZlciI6IjQxLjEzMS45IiwidGFyZ2V0QnJhbmNoIjoibWFpbiIsImxhYmVscyI6W119-->

Co-authored-by: Yuan Teoh <45984206+Yuan325@users.noreply.github.com>
2025-09-30 14:04:14 -07:00
Mend Renovate
3553bf0ccf chore(deps): update google.golang.org/genproto digest to 57b25ae (#1599)
This PR contains the following updates:

| Package | Type | Update | Change |
|---|---|---|---|
|
[google.golang.org/genproto](https://redirect.github.com/googleapis/go-genproto)
| require | digest | `9219d12` -> `57b25ae` |

---

### Configuration

📅 **Schedule**: Branch creation - At any time (no schedule defined),
Automerge - At any time (no schedule defined).

🚦 **Automerge**: Disabled by config. Please merge this manually once you
are satisfied.

♻ **Rebasing**: Whenever PR becomes conflicted, or you tick the
rebase/retry checkbox.

🔕 **Ignore**: Close this PR and you won't be reminded about this update
again.

---

- [ ] <!-- rebase-check -->If you want to rebase/retry this PR, check
this box

---

This PR was generated by [Mend Renovate](https://mend.io/renovate/).
View the [repository job
log](https://developer.mend.io/github/googleapis/genai-toolbox).

<!--renovate-debug:eyJjcmVhdGVkSW5WZXIiOiI0MS4xMzEuOSIsInVwZGF0ZWRJblZlciI6IjQxLjEzMS45IiwidGFyZ2V0QnJhbmNoIjoibWFpbiIsImxhYmVscyI6W119-->

Co-authored-by: Yuan Teoh <45984206+Yuan325@users.noreply.github.com>
2025-09-30 13:52:05 -07:00
24 changed files with 609 additions and 287 deletions

View File

@@ -174,25 +174,25 @@ steps:
bigquery \
bigquery
- id: "dataplex"
name: golang:1
waitFor: ["compile-test-binary"]
entrypoint: /bin/bash
env:
- "GOPATH=/gopath"
- "DATAPLEX_PROJECT=$PROJECT_ID"
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
secretEnv: ["CLIENT_ID"]
volumes:
- name: "go"
path: "/gopath"
args:
- -c
- |
.ci/test_with_coverage.sh \
"Dataplex" \
dataplex \
dataplex
# - id: "dataplex"
# name: golang:1
# waitFor: ["compile-test-binary"]
# entrypoint: /bin/bash
# env:
# - "GOPATH=/gopath"
# - "DATAPLEX_PROJECT=$PROJECT_ID"
# - "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
# secretEnv: ["CLIENT_ID"]
# volumes:
# - name: "go"
# path: "/gopath"
# args:
# - -c
# - |
# .ci/test_with_coverage.sh \
# "Dataplex" \
# dataplex \
# dataplex
- id: "postgres"
name: golang:1

View File

@@ -132,6 +132,10 @@ details on how to connect your AI tools (IDEs) to databases via Toolbox and MCP.
* `list_tables`: Lists tables in the database.
* `get_query_plan`: Provides information about how MySQL executes a SQL
statement.
* `list_active_queries`: Lists ongoing queries.
* `list_tables_missing_unique_indexes`: Looks for tables that do not have
primary or unique key contraint.
* `list_table_fragmentation`: Displays table fragmentation in MySQL.
## Cloud SQL for MySQL Observability
@@ -407,6 +411,10 @@ details on how to connect your AI tools (IDEs) to databases via Toolbox and MCP.
* `list_tables`: Lists tables in the database.
* `get_query_plan`: Provides information about how MySQL executes a SQL
statement.
* `list_active_queries`: Lists ongoing queries.
* `list_tables_missing_unique_indexes`: Looks for tables that do not have
primary or unique key contraint.
* `list_table_fragmentation`: Displays table fragmentation in MySQL.
## OceanBase

View File

@@ -3,8 +3,7 @@ title: "looker-add-dashboard-element"
type: docs
weight: 1
description: >
"looker-add-dashboard-element" generates a Looker look in the users personal folder in
Looker
"looker-add-dashboard-element" creates a dashboard element in the given dashboard.
aliases:
- /resources/tools/looker-add-dashboard-element
---

View File

@@ -3,8 +3,7 @@ title: "looker-get-dashboards"
type: docs
weight: 1
description: >
"looker-get-dashboards" searches for saved Looks in a Looker
source.
"looker-get-dashboards" tool searches for a saved Dashboard by name or description.
aliases:
- /resources/tools/looker-get-dashboards
---

10
go.mod
View File

@@ -7,11 +7,11 @@ toolchain go1.25.1
require (
cloud.google.com/go/alloydbconn v1.15.5
cloud.google.com/go/bigquery v1.71.0
cloud.google.com/go/bigtable v1.40.0
cloud.google.com/go/bigtable v1.40.1
cloud.google.com/go/cloudsqlconn v1.18.1
cloud.google.com/go/dataplex v1.27.1
cloud.google.com/go/firestore v1.18.0
cloud.google.com/go/spanner v1.85.1
cloud.google.com/go/spanner v1.86.0
github.com/ClickHouse/clickhouse-go/v2 v2.40.3
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.53.0
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/trace v1.29.0
@@ -52,7 +52,7 @@ require (
go.opentelemetry.io/otel/trace v1.38.0
golang.org/x/oauth2 v0.31.0
google.golang.org/api v0.250.0
google.golang.org/genproto v0.0.0-20250922171735-9219d122eba9
google.golang.org/genproto v0.0.0-20250929231259-57b25ae835d4
modernc.org/sqlite v1.39.0
)
@@ -174,8 +174,8 @@ require (
golang.org/x/time v0.13.0 // indirect
golang.org/x/tools v0.36.0 // indirect
golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20250908214217-97024824d090 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250908214217-97024824d090 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20250922171735-9219d122eba9 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250922171735-9219d122eba9 // indirect
google.golang.org/grpc v1.75.1 // indirect
google.golang.org/protobuf v1.36.9 // indirect
gopkg.in/inf.v0 v0.9.1 // indirect

20
go.sum
View File

@@ -139,8 +139,8 @@ cloud.google.com/go/bigquery v1.49.0/go.mod h1:Sv8hMmTFFYBlt/ftw2uN6dFdQPzBlREY9
cloud.google.com/go/bigquery v1.50.0/go.mod h1:YrleYEh2pSEbgTBZYMJ5SuSr0ML3ypjRB1zgf7pvQLU=
cloud.google.com/go/bigquery v1.71.0 h1:NvSZvXU1Hyb+YiRVKQPuQXGeZaw/0NP6M/WOrBqSx3g=
cloud.google.com/go/bigquery v1.71.0/go.mod h1:GUbRtmeCckOE85endLherHD9RsujY+gS7i++c1CqssQ=
cloud.google.com/go/bigtable v1.40.0 h1:iNeqGqkJvFdjg07Ku3F7KKfq5QZvBySisYHVsLB1RwE=
cloud.google.com/go/bigtable v1.40.0/go.mod h1:LtPzCcrAFaGRZ82Hs8xMueUeYW9Jw12AmNdUTMfDnh4=
cloud.google.com/go/bigtable v1.40.1 h1:k8HfpUOvn7sQwc6oNKqjvD/yjkwynf4qBuyKwh5cU08=
cloud.google.com/go/bigtable v1.40.1/go.mod h1:LtPzCcrAFaGRZ82Hs8xMueUeYW9Jw12AmNdUTMfDnh4=
cloud.google.com/go/billing v1.4.0/go.mod h1:g9IdKBEFlItS8bTtlrZdVLWSSdSyFUZKXNS02zKMOZY=
cloud.google.com/go/billing v1.5.0/go.mod h1:mztb1tBc3QekhjSgmpf/CV4LzWXLzCArwpLmP2Gm88s=
cloud.google.com/go/billing v1.6.0/go.mod h1:WoXzguj+BeHXPbKfNWkqVtDdzORazmCjraY+vrxcyvI=
@@ -544,8 +544,8 @@ cloud.google.com/go/shell v1.6.0/go.mod h1:oHO8QACS90luWgxP3N9iZVuEiSF84zNyLytb+
cloud.google.com/go/spanner v1.41.0/go.mod h1:MLYDBJR/dY4Wt7ZaMIQ7rXOTLjYrmxLE/5ve9vFfWos=
cloud.google.com/go/spanner v1.44.0/go.mod h1:G8XIgYdOK+Fbcpbs7p2fiprDw4CaZX63whnSMLVBxjk=
cloud.google.com/go/spanner v1.45.0/go.mod h1:FIws5LowYz8YAE1J8fOS7DJup8ff7xJeetWEo5REA2M=
cloud.google.com/go/spanner v1.85.1 h1:cJx1ZD//C2QIfFQl8hSTn4twL8amAXtnayyflRIjj40=
cloud.google.com/go/spanner v1.85.1/go.mod h1:bbwCXbM+zljwSPLZ44wZOdzcdmy89hbUGmM/r9sD0ws=
cloud.google.com/go/spanner v1.86.0 h1:jlNWusBol1Jxa9PmYGknUBzLwvD1cebuEenzqebZ9xs=
cloud.google.com/go/spanner v1.86.0/go.mod h1:bbwCXbM+zljwSPLZ44wZOdzcdmy89hbUGmM/r9sD0ws=
cloud.google.com/go/speech v1.6.0/go.mod h1:79tcr4FHCimOp56lwC01xnt/WPJZc4v3gzyT7FoBkCM=
cloud.google.com/go/speech v1.7.0/go.mod h1:KptqL+BAQIhMsj1kOP2la5DSEEerPDuOP/2mmkhHhZQ=
cloud.google.com/go/speech v1.8.0/go.mod h1:9bYIl1/tjsAnMgKGHKmBZzXKEkGgtU+MpdDPTE9f7y0=
@@ -1976,12 +1976,12 @@ google.golang.org/genproto v0.0.0-20230323212658-478b75c54725/go.mod h1:UUQDJDOl
google.golang.org/genproto v0.0.0-20230330154414-c0448cd141ea/go.mod h1:UUQDJDOlWu4KYeJZffbWgBkS1YFobzKbLVfK69pe0Ak=
google.golang.org/genproto v0.0.0-20230331144136-dcfb400f0633/go.mod h1:UUQDJDOlWu4KYeJZffbWgBkS1YFobzKbLVfK69pe0Ak=
google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1/go.mod h1:nKE/iIaLqn2bQwXBg8f1g2Ylh6r5MN5CmZvuzZCgsCU=
google.golang.org/genproto v0.0.0-20250922171735-9219d122eba9 h1:LvZVVaPE0JSqL+ZWb6ErZfnEOKIqqFWUJE2D0fObSmc=
google.golang.org/genproto v0.0.0-20250922171735-9219d122eba9/go.mod h1:QFOrLhdAe2PsTp3vQY4quuLKTi9j3XG3r6JPPaw7MSc=
google.golang.org/genproto/googleapis/api v0.0.0-20250908214217-97024824d090 h1:d8Nakh1G+ur7+P3GcMjpRDEkoLUcLW2iU92XVqR+XMQ=
google.golang.org/genproto/googleapis/api v0.0.0-20250908214217-97024824d090/go.mod h1:U8EXRNSd8sUYyDfs/It7KVWodQr+Hf9xtxyxWudSwEw=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250908214217-97024824d090 h1:/OQuEa4YWtDt7uQWHd3q3sUMb+QOLQUg1xa8CEsRv5w=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250908214217-97024824d090/go.mod h1:GmFNa4BdJZ2a8G+wCe9Bg3wwThLrJun751XstdJt5Og=
google.golang.org/genproto v0.0.0-20250929231259-57b25ae835d4 h1:HmI33/XNQ1jVwhb5ZUgot40oiwFHa2l5ZNkQpj8VaEg=
google.golang.org/genproto v0.0.0-20250929231259-57b25ae835d4/go.mod h1:OqVwZqqGV3h7k+YCVWXoTtwC2cs55RnDEUVMMadhxrc=
google.golang.org/genproto/googleapis/api v0.0.0-20250922171735-9219d122eba9 h1:jm6v6kMRpTYKxBRrDkYAitNJegUeO1Mf3Kt80obv0gg=
google.golang.org/genproto/googleapis/api v0.0.0-20250922171735-9219d122eba9/go.mod h1:LmwNphe5Afor5V3R5BppOULHOnt2mCIf+NxMd4XiygE=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250922171735-9219d122eba9 h1:V1jCN2HBa8sySkR5vLcCSqJSTMv093Rw9EJefhQGP7M=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250922171735-9219d122eba9/go.mod h1:HSkG/KdJWusxU1F6CNrwNDjBMgisKxGnc5dAZfT0mjQ=
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38=
google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM=

View File

@@ -25,7 +25,6 @@ import (
dataplexapi "cloud.google.com/go/dataplex/apiv1"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util"
"go.opentelemetry.io/otel/trace"
"golang.org/x/oauth2"
@@ -248,24 +247,6 @@ func (s *Source) lazyInitDataplexClient(ctx context.Context, tracer trace.Tracer
}
}
func (s *Source) RetrieveBQClient(accessToken tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) {
bqClient := s.Client
restService := s.RestService
// Initialize new client if using user OAuth token
if s.UseClientOAuth {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, restService, err = s.ClientCreator(tokenStr, true)
if err != nil {
return nil, nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
}
return bqClient, restService, nil
}
func initBigQueryConnection(
ctx context.Context,
tracer trace.Tracer,

View File

@@ -50,7 +50,6 @@ type compatibleSource interface {
BigQueryRestService() *bigqueryrestapi.Service
BigQueryClientCreator() bigqueryds.BigqueryClientCreator
UseClientAuthorization() bool
RetrieveBQClient(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error)
}
// validate compatible sources are still compatible
@@ -123,11 +122,16 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: parameters,
Source: s,
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Name: cfg.Name,
Kind: kind,
Parameters: parameters,
AuthRequired: cfg.AuthRequired,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
RestService: s.BigQueryRestService(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -136,16 +140,21 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
var _ tools.Tool = Tool{}
type Tool struct {
Config
Parameters tools.Parameters
Source compatibleSource
manifest tools.Manifest
mcpManifest tools.McpManifest
Name string `yaml:"name"`
Kind string `yaml:"kind"`
AuthRequired []string `yaml:"authRequired"`
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters tools.Parameters `yaml:"parameters"`
Client *bigqueryapi.Client
RestService *bigqueryrestapi.Service
ClientCreator bigqueryds.BigqueryClientCreator
manifest tools.Manifest
mcpManifest tools.McpManifest
}
// Invoke runs the contribution analysis.
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
s := t.Source
paramsMap := params.AsMap()
inputData, ok := paramsMap["input_data"].(string)
if !ok {
@@ -197,9 +206,19 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
inputDataSource,
)
bqClient, _, err := s.RetrieveBQClient(accessToken)
if err != nil {
return nil, err
bqClient := t.Client
var err error
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, _, err = t.ClientCreator(tokenStr, false)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
}
createModelQuery := bqClient.Query(createModelSQL)
@@ -280,5 +299,5 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
}
func (t Tool) RequiresClientAuthorization() bool {
return t.Source.UseClientAuthorization()
return t.UseClientOAuth
}

View File

@@ -19,9 +19,7 @@ import (
"fmt"
bigqueryapi "cloud.google.com/go/bigquery"
"github.com/googleapis/genai-toolbox/internal/util"
bigqueryrestapi "google.golang.org/api/bigquery/v2"
"google.golang.org/api/iterator"
)
// DryRunQuery performs a dry run of the SQL query to validate it and get metadata.
@@ -55,46 +53,3 @@ func DryRunQuery(ctx context.Context, restService *bigqueryrestapi.Service, proj
}
return insertResponse, nil
}
func RunQuery(ctx context.Context, statement string, query *bigqueryapi.Query) (any, error) {
// Log the query executed for debugging.
logger, err := util.LoggerFromContext(ctx)
if err != nil {
return nil, fmt.Errorf("error getting logger: %s", err)
}
logger.DebugContext(ctx, "executing big query execute sql query: %s", statement)
// This block handles SELECT statements, which return a row set.
// We iterate through the results, convert each row into a map of
// column names to values, and return the collection of rows.
var out []any
it, err := query.Read(ctx)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
for {
var row map[string]bigqueryapi.Value
err = it.Next(&row)
if err == iterator.Done {
break
}
if err != nil {
return nil, fmt.Errorf("unable to iterate through query results: %w", err)
}
vMap := make(map[string]any)
for key, value := range row {
vMap[key] = value
}
out = append(out, vMap)
}
// If the query returned any rows, return them directly.
if len(out) > 0 {
return out, nil
}
// This is the fallback for a successful query that doesn't return content.
// In most cases, this will be for DML/DDL statements like INSERT, UPDATE, CREATE, etc.
// However, it is also possible that this was a query that was expected to return rows
// but returned none, a case that we cannot distinguish here.
return "Query executed successfully and returned no content.", nil
}

View File

@@ -23,6 +23,7 @@ import (
"net/http"
"strings"
bigqueryapi "cloud.google.com/go/bigquery"
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
@@ -52,6 +53,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
}
type compatibleSource interface {
BigQueryClient() *bigqueryapi.Client
BigQueryTokenSourceWithScope(ctx context.Context, scope string) (oauth2.TokenSource, error)
BigQueryProject() string
BigQueryLocation() string
@@ -149,13 +151,33 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
parameters := tools.Parameters{userQueryParameter, tableRefsParameter}
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, parameters)
// Get cloud-platform token source for Gemini Data Analytics API during initialization
var bigQueryTokenSourceWithScope oauth2.TokenSource
if !s.UseClientAuthorization() {
ctx := context.Background()
ts, err := s.BigQueryTokenSourceWithScope(ctx, "https://www.googleapis.com/auth/cloud-platform")
if err != nil {
return nil, fmt.Errorf("failed to get cloud-platform token source: %w", err)
}
bigQueryTokenSourceWithScope = ts
}
// finish tool setup
t := Tool{
Config: cfg,
Source: s,
Parameters: parameters,
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Name: cfg.Name,
Kind: kind,
Project: s.BigQueryProject(),
Location: s.BigQueryLocation(),
Parameters: parameters,
AuthRequired: cfg.AuthRequired,
Client: s.BigQueryClient(),
UseClientOAuth: s.UseClientAuthorization(),
TokenSource: bigQueryTokenSourceWithScope,
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
MaxQueryResultRows: s.GetMaxQueryResultRows(),
IsDatasetAllowed: s.IsDatasetAllowed,
AllowedDatasets: allowedDatasets,
}
return t, nil
}
@@ -164,20 +186,29 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
var _ tools.Tool = Tool{}
type Tool struct {
Config
Parameters tools.Parameters
Source compatibleSource
manifest tools.Manifest
mcpManifest tools.McpManifest
Name string `yaml:"name"`
Kind string `yaml:"kind"`
AuthRequired []string `yaml:"authRequired"`
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters tools.Parameters `yaml:"parameters"`
Project string
Location string
Client *bigqueryapi.Client
TokenSource oauth2.TokenSource
manifest tools.Manifest
mcpManifest tools.McpManifest
MaxQueryResultRows int
IsDatasetAllowed func(projectID, datasetID string) bool
AllowedDatasets []string
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
s := t.Source
var tokenStr string
var err error
// Get credentials for the API call
if s.UseClientAuthorization() {
if t.UseClientOAuth {
// Use client-side access token
if accessToken == "" {
return nil, fmt.Errorf("tool is configured for client OAuth but no token was provided in the request header: %w", tools.ErrUnauthorized)
@@ -187,15 +218,11 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
return nil, fmt.Errorf("error parsing access token: %w", err)
}
} else {
tokenSource, err := s.BigQueryTokenSourceWithScope(ctx, "https://www.googleapis.com/auth/cloud-platform")
if err != nil {
return nil, fmt.Errorf("failed to get cloud-platform token source: %w", err)
}
// Use cloud-platform token source for Gemini Data Analytics API
if tokenSource == nil {
if t.TokenSource == nil {
return nil, fmt.Errorf("cloud-platform token source is missing")
}
token, err := tokenSource.Token()
token, err := t.TokenSource.Token()
if err != nil {
return nil, fmt.Errorf("failed to get token from cloud-platform token source: %w", err)
}
@@ -216,18 +243,17 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
}
}
allowedDataset := s.BigQueryAllowedDatasets()
if len(allowedDataset) > 0 {
if len(t.AllowedDatasets) > 0 {
for _, tableRef := range tableRefs {
if !s.IsDatasetAllowed(tableRef.ProjectID, tableRef.DatasetID) {
if !t.IsDatasetAllowed(tableRef.ProjectID, tableRef.DatasetID) {
return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", tableRef.ProjectID, tableRef.DatasetID, tableRef.TableID)
}
}
}
// Construct URL, headers, and payload
projectID := s.BigQueryProject()
location := s.BigQueryLocation()
projectID := t.Project
location := t.Location
if location == "" {
location = "us"
}
@@ -251,7 +277,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
}
// Call the streaming API
response, err := getStream(caURL, payload, headers, s.GetMaxQueryResultRows())
response, err := getStream(caURL, payload, headers, t.MaxQueryResultRows)
if err != nil {
return nil, fmt.Errorf("failed to get response from conversational analytics API: %w", err)
}
@@ -276,7 +302,7 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
}
func (t Tool) RequiresClientAuthorization() bool {
return t.Source.UseClientAuthorization()
return t.UseClientOAuth
}
// StreamMessage represents a single message object from the streaming API response.

View File

@@ -26,7 +26,9 @@ import (
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
"github.com/googleapis/genai-toolbox/internal/tools"
bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon"
"github.com/googleapis/genai-toolbox/internal/util"
bigqueryrestapi "google.golang.org/api/bigquery/v2"
"google.golang.org/api/iterator"
)
const kind string = "bigquery-execute-sql"
@@ -52,7 +54,6 @@ type compatibleSource interface {
UseClientAuthorization() bool
IsDatasetAllowed(projectID, datasetID string) bool
BigQueryAllowedDatasets() []string
RetrieveBQClient(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error)
}
// validate compatible sources are still compatible
@@ -121,11 +122,18 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: parameters,
Source: s,
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Name: cfg.Name,
Kind: kind,
Parameters: parameters,
AuthRequired: cfg.AuthRequired,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
RestService: s.BigQueryRestService(),
IsDatasetAllowed: s.IsDatasetAllowed,
AllowedDatasets: allowedDatasets,
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -134,15 +142,22 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
var _ tools.Tool = Tool{}
type Tool struct {
Config
Parameters tools.Parameters
Source compatibleSource
manifest tools.Manifest
mcpManifest tools.McpManifest
Name string `yaml:"name"`
Kind string `yaml:"kind"`
AuthRequired []string `yaml:"authRequired"`
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters tools.Parameters `yaml:"parameters"`
Client *bigqueryapi.Client
RestService *bigqueryrestapi.Service
ClientCreator bigqueryds.BigqueryClientCreator
IsDatasetAllowed func(projectID, datasetID string) bool
AllowedDatasets []string
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
s := t.Source
paramsMap := params.AsMap()
sql, ok := paramsMap["sql"].(string)
if !ok {
@@ -153,9 +168,20 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
return nil, fmt.Errorf("unable to cast dry_run parameter %s", paramsMap["dry_run"])
}
bqClient, restService, err := s.RetrieveBQClient(accessToken)
if err != nil {
return nil, err
bqClient := t.Client
restService := t.RestService
var err error
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, restService, err = t.ClientCreator(tokenStr, true)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
}
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, sql, nil, nil)
@@ -163,7 +189,8 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
return nil, fmt.Errorf("query validation failed during dry run: %w", err)
}
statementType := dryRunJob.Statistics.Query.StatementType
if len(s.BigQueryAllowedDatasets()) > 0 {
if len(t.AllowedDatasets) > 0 {
switch statementType {
case "CREATE_SCHEMA", "DROP_SCHEMA", "ALTER_SCHEMA":
return nil, fmt.Errorf("dataset-level operations like '%s' are not allowed when dataset restrictions are in place", statementType)
@@ -198,7 +225,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
} else if statementType != "SELECT" {
// If dry run yields no tables, fall back to the parser for non-SELECT statements
// to catch unsafe operations like EXECUTE IMMEDIATE.
parsedTables, parseErr := bqutil.TableParser(sql, t.Source.BigQueryClient().Project())
parsedTables, parseErr := bqutil.TableParser(sql, t.Client.Project())
if parseErr != nil {
// If parsing fails (e.g., EXECUTE IMMEDIATE), we cannot guarantee safety, so we must fail.
return nil, fmt.Errorf("could not parse tables from query to validate against allowed datasets: %w", parseErr)
@@ -210,7 +237,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
parts := strings.Split(tableID, ".")
if len(parts) == 3 {
projectID, datasetID := parts[0], parts[1]
if !s.IsDatasetAllowed(projectID, datasetID) {
if !t.IsDatasetAllowed(projectID, datasetID) {
return nil, fmt.Errorf("query accesses dataset '%s.%s', which is not in the allowed list", projectID, datasetID)
}
}
@@ -232,7 +259,51 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
query := bqClient.Query(sql)
query.Location = bqClient.Location
return bqutil.RunQuery(ctx, sql, query)
// Log the query executed for debugging.
logger, err := util.LoggerFromContext(ctx)
if err != nil {
return nil, fmt.Errorf("error getting logger: %s", err)
}
logger.DebugContext(ctx, "executing `%s` tool query: %s", kind, sql)
// This block handles SELECT statements, which return a row set.
// We iterate through the results, convert each row into a map of
// column names to values, and return the collection of rows.
var out []any
it, err := query.Read(ctx)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
for {
var row map[string]bigqueryapi.Value
err = it.Next(&row)
if err == iterator.Done {
break
}
if err != nil {
return nil, fmt.Errorf("unable to iterate through query results: %w", err)
}
vMap := make(map[string]any)
for key, value := range row {
vMap[key] = value
}
out = append(out, vMap)
}
// If the query returned any rows, return them directly.
if len(out) > 0 {
return out, nil
}
// This handles the standard case for a SELECT query that successfully
// executes but returns zero rows.
if statementType == "SELECT" {
return "The query returned 0 rows.", nil
}
// This is the fallback for a successful query that doesn't return content.
// In most cases, this will be for DML/DDL statements like INSERT, UPDATE, CREATE, etc.
// However, it is also possible that this was a query that was expected to return rows
// but returned none, a case that we cannot distinguish here.
return "Query executed successfully and returned no content.", nil
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
@@ -252,5 +323,5 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
}
func (t Tool) RequiresClientAuthorization() bool {
return t.Source.UseClientAuthorization()
return t.UseClientOAuth
}

View File

@@ -53,7 +53,6 @@ type compatibleSource interface {
UseClientAuthorization() bool
IsDatasetAllowed(projectID, datasetID string) bool
BigQueryAllowedDatasets() []string
RetrieveBQClient(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error)
}
// validate compatible sources are still compatible
@@ -115,10 +114,18 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: parameters,
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Name: cfg.Name,
Kind: kind,
Parameters: parameters,
AuthRequired: cfg.AuthRequired,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
RestService: s.BigQueryRestService(),
IsDatasetAllowed: s.IsDatasetAllowed,
AllowedDatasets: allowedDatasets,
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -127,15 +134,22 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
var _ tools.Tool = Tool{}
type Tool struct {
Config
Parameters tools.Parameters
Source compatibleSource
manifest tools.Manifest
mcpManifest tools.McpManifest
Name string `yaml:"name"`
Kind string `yaml:"kind"`
AuthRequired []string `yaml:"authRequired"`
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters tools.Parameters `yaml:"parameters"`
Client *bigqueryapi.Client
RestService *bigqueryrestapi.Service
ClientCreator bigqueryds.BigqueryClientCreator
IsDatasetAllowed func(projectID, datasetID string) bool
AllowedDatasets []string
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
s := t.Source
paramsMap := params.AsMap()
historyData, ok := paramsMap["history_data"].(string)
if !ok {
@@ -173,8 +187,8 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
var historyDataSource string
trimmedUpperHistoryData := strings.TrimSpace(strings.ToUpper(historyData))
if strings.HasPrefix(trimmedUpperHistoryData, "SELECT") || strings.HasPrefix(trimmedUpperHistoryData, "WITH") {
if len(s.BigQueryAllowedDatasets()) > 0 {
dryRunJob, err := bqutil.DryRunQuery(ctx, s.BigQueryRestService(), s.BigQueryClient().Project(), s.BigQueryClient().Location, historyData, nil, nil)
if len(t.AllowedDatasets) > 0 {
dryRunJob, err := bqutil.DryRunQuery(ctx, t.RestService, t.Client.Project(), t.Client.Location, historyData, nil, nil)
if err != nil {
return nil, fmt.Errorf("query validation failed during dry run: %w", err)
}
@@ -186,7 +200,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
queryStats := dryRunJob.Statistics.Query
if queryStats != nil {
for _, tableRef := range queryStats.ReferencedTables {
if !s.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) {
if !t.IsDatasetAllowed(tableRef.ProjectId, tableRef.DatasetId) {
return nil, fmt.Errorf("query in history_data accesses dataset '%s.%s', which is not in the allowed list", tableRef.ProjectId, tableRef.DatasetId)
}
}
@@ -196,7 +210,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
}
historyDataSource = fmt.Sprintf("(%s)", historyData)
} else {
if len(s.BigQueryAllowedDatasets()) > 0 {
if len(t.AllowedDatasets) > 0 {
parts := strings.Split(historyData, ".")
var projectID, datasetID string
@@ -205,13 +219,13 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
projectID = parts[0]
datasetID = parts[1]
case 2: // dataset.table
projectID = s.BigQueryClient().Project()
projectID = t.Client.Project()
datasetID = parts[0]
default:
return nil, fmt.Errorf("invalid table ID format for 'history_data': %q. Expected 'dataset.table' or 'project.dataset.table'", historyData)
}
if !s.IsDatasetAllowed(projectID, datasetID) {
if !t.IsDatasetAllowed(projectID, datasetID) {
return nil, fmt.Errorf("access to dataset '%s.%s' (from table '%s') is not allowed", projectID, datasetID, historyData)
}
}
@@ -232,9 +246,19 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
horizon => %d%s)`,
historyDataSource, dataCol, timestampCol, horizon, idColsArg)
bqClient, _, err := s.RetrieveBQClient(accessToken)
if err != nil {
return nil, err
bqClient := t.Client
var err error
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, _, err = t.ClientCreator(tokenStr, false)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
}
// JobStatistics.QueryStatistics.StatementType
@@ -297,5 +321,5 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
}
func (t Tool) RequiresClientAuthorization() bool {
return t.Source.UseClientAuthorization()
return t.UseClientOAuth
}

View File

@@ -23,7 +23,6 @@ import (
"github.com/googleapis/genai-toolbox/internal/sources"
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
"github.com/googleapis/genai-toolbox/internal/tools"
bigqueryrestapi "google.golang.org/api/bigquery/v2"
)
const kind string = "bigquery-get-dataset-info"
@@ -49,7 +48,6 @@ type compatibleSource interface {
BigQueryClient() *bigqueryapi.Client
BigQueryClientCreator() bigqueryds.BigqueryClientCreator
UseClientAuthorization() bool
RetrieveBQClient(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error)
}
// validate compatible sources are still compatible
@@ -93,11 +91,15 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: parameters,
Source: s,
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Name: cfg.Name,
Kind: kind,
Parameters: parameters,
AuthRequired: cfg.AuthRequired,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -106,16 +108,20 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
var _ tools.Tool = Tool{}
type Tool struct {
Config
Parameters tools.Parameters
Source compatibleSource
Statement string
manifest tools.Manifest
mcpManifest tools.McpManifest
Name string `yaml:"name"`
Kind string `yaml:"kind"`
AuthRequired []string `yaml:"authRequired"`
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters tools.Parameters `yaml:"parameters"`
Client *bigqueryapi.Client
ClientCreator bigqueryds.BigqueryClientCreator
Statement string
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
s := t.Source
mapParams := params.AsMap()
projectId, ok := mapParams[projectKey].(string)
if !ok {
@@ -127,9 +133,19 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", datasetKey)
}
bqClient, _, err := s.RetrieveBQClient(accessToken)
if err != nil {
return nil, err
bqClient := t.Client
var err error
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, _, err = t.ClientCreator(tokenStr, false)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
}
dsHandle := bqClient.DatasetInProject(projectId, datasetId)
@@ -158,5 +174,5 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
}
func (t Tool) RequiresClientAuthorization() bool {
return t.Source.UseClientAuthorization()
return t.UseClientOAuth
}

View File

@@ -23,7 +23,6 @@ import (
"github.com/googleapis/genai-toolbox/internal/sources"
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
"github.com/googleapis/genai-toolbox/internal/tools"
bigqueryrestapi "google.golang.org/api/bigquery/v2"
)
const kind string = "bigquery-get-table-info"
@@ -50,7 +49,6 @@ type compatibleSource interface {
BigQueryClient() *bigqueryapi.Client
BigQueryClientCreator() bigqueryds.BigqueryClientCreator
UseClientAuthorization() bool
RetrieveBQClient(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error)
}
// validate compatible sources are still compatible
@@ -95,10 +93,15 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: parameters,
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Name: cfg.Name,
Kind: kind,
Parameters: parameters,
AuthRequired: cfg.AuthRequired,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -107,13 +110,17 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
var _ tools.Tool = Tool{}
type Tool struct {
Config
Parameters tools.Parameters
Name string `yaml:"name"`
Kind string `yaml:"kind"`
AuthRequired []string `yaml:"authRequired"`
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters tools.Parameters `yaml:"parameters"`
Source compatibleSource
Statement string
manifest tools.Manifest
mcpManifest tools.McpManifest
Client *bigqueryapi.Client
ClientCreator bigqueryds.BigqueryClientCreator
Statement string
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
@@ -133,9 +140,19 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", tableKey)
}
bqClient, _, err := t.Source.RetrieveBQClient(accessToken)
if err != nil {
return nil, err
bqClient := t.Client
var err error
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, _, err = t.ClientCreator(tokenStr, false)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
}
dsHandle := bqClient.DatasetInProject(projectId, datasetId)
@@ -166,5 +183,5 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
}
func (t Tool) RequiresClientAuthorization() bool {
return t.Source.UseClientAuthorization()
return t.UseClientOAuth
}

View File

@@ -23,7 +23,6 @@ import (
"github.com/googleapis/genai-toolbox/internal/sources"
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
"github.com/googleapis/genai-toolbox/internal/tools"
bigqueryrestapi "google.golang.org/api/bigquery/v2"
"google.golang.org/api/iterator"
)
@@ -49,7 +48,6 @@ type compatibleSource interface {
BigQueryClient() *bigqueryapi.Client
BigQueryClientCreator() bigqueryds.BigqueryClientCreator
UseClientAuthorization() bool
RetrieveBQClient(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error)
}
// validate compatible sources are still compatible
@@ -93,10 +91,15 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: parameters,
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Name: cfg.Name,
Kind: kind,
Parameters: parameters,
AuthRequired: cfg.AuthRequired,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -105,13 +108,17 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
var _ tools.Tool = Tool{}
type Tool struct {
Config
Parameters tools.Parameters
Name string `yaml:"name"`
Kind string `yaml:"kind"`
AuthRequired []string `yaml:"authRequired"`
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters tools.Parameters `yaml:"parameters"`
Source compatibleSource
Statement string
manifest tools.Manifest
mcpManifest tools.McpManifest
Client *bigqueryapi.Client
ClientCreator bigqueryds.BigqueryClientCreator
Statement string
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
@@ -121,9 +128,17 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", projectKey)
}
bqClient, _, err := t.Source.RetrieveBQClient(accessToken)
if err != nil {
return nil, err
bqClient := t.Client
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, _, err = t.ClientCreator(tokenStr, false)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
}
datasetIterator := bqClient.Datasets(ctx)
datasetIterator.ProjectID = projectId
@@ -166,5 +181,5 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
}
func (t Tool) RequiresClientAuthorization() bool {
return t.Source.UseClientAuthorization()
return t.UseClientOAuth
}

View File

@@ -25,7 +25,6 @@ import (
"github.com/googleapis/genai-toolbox/internal/sources"
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
"github.com/googleapis/genai-toolbox/internal/tools"
bigqueryrestapi "google.golang.org/api/bigquery/v2"
"google.golang.org/api/iterator"
)
@@ -54,7 +53,6 @@ type compatibleSource interface {
UseClientAuthorization() bool
IsDatasetAllowed(projectID, datasetID string) bool
BigQueryAllowedDatasets() []string
RetrieveBQClient(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error)
}
// validate compatible sources are still compatible
@@ -134,10 +132,16 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
Parameters: parameters,
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Name: cfg.Name,
Kind: kind,
Parameters: parameters,
AuthRequired: cfg.AuthRequired,
UseClientOAuth: s.UseClientAuthorization(),
ClientCreator: s.BigQueryClientCreator(),
Client: s.BigQueryClient(),
IsDatasetAllowed: s.IsDatasetAllowed,
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -146,16 +150,21 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
var _ tools.Tool = Tool{}
type Tool struct {
Config
Parameters tools.Parameters
Source compatibleSource
Statement string
manifest tools.Manifest
mcpManifest tools.McpManifest
Name string `yaml:"name"`
Kind string `yaml:"kind"`
AuthRequired []string `yaml:"authRequired"`
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters tools.Parameters `yaml:"parameters"`
Client *bigqueryapi.Client
ClientCreator bigqueryds.BigqueryClientCreator
IsDatasetAllowed func(projectID, datasetID string) bool
Statement string
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
s := t.Source
mapParams := params.AsMap()
projectId, ok := mapParams[projectKey].(string)
if !ok {
@@ -167,13 +176,21 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", datasetKey)
}
if !s.IsDatasetAllowed(projectId, datasetId) {
if !t.IsDatasetAllowed(projectId, datasetId) {
return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId)
}
bqClient, _, err := s.RetrieveBQClient(accessToken)
if err != nil {
return nil, err
bqClient := t.Client
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, _, err = t.ClientCreator(tokenStr, false)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
}
dsHandle := bqClient.DatasetInProject(projectId, datasetId)
@@ -217,5 +234,5 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
}
func (t Tool) RequiresClientAuthorization() bool {
return t.Source.UseClientAuthorization()
return t.UseClientOAuth
}

View File

@@ -82,6 +82,9 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
// Get the Dataplex client using the method from the source
makeCatalogClient := s.MakeDataplexCatalogClient()
prompt := tools.NewStringParameter("prompt", "Prompt representing search intention. Do not rewrite the prompt.")
datasetIds := tools.NewArrayParameterWithDefault("datasetIds", []any{}, "Array of dataset IDs.", tools.NewStringParameter("datasetId", "The IDs of the bigquery dataset."))
projectIds := tools.NewArrayParameterWithDefault("projectIds", []any{}, "Array of project IDs.", tools.NewStringParameter("projectId", "The IDs of the bigquery project."))
@@ -96,9 +99,13 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, parameters)
t := Tool{
Config: cfg,
Parameters: parameters,
Source: s,
Name: cfg.Name,
Kind: kind,
Parameters: parameters,
AuthRequired: cfg.AuthRequired,
UseClientOAuth: s.UseClientAuthorization(),
MakeCatalogClient: makeCatalogClient,
ProjectID: s.BigQueryProject(),
manifest: tools.Manifest{
Description: cfg.Description,
Parameters: parameters.Manifest(),
@@ -110,11 +117,15 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
}
type Tool struct {
Config
Parameters tools.Parameters
Source compatibleSource
manifest tools.Manifest
mcpManifest tools.McpManifest
Name string
Kind string
Parameters tools.Parameters
AuthRequired []string
UseClientOAuth bool
MakeCatalogClient func() (*dataplexapi.CatalogClient, bigqueryds.DataplexClientCreator, error)
ProjectID string
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) Authorized(verifiedAuthServices []string) bool {
@@ -122,7 +133,7 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
}
func (t Tool) RequiresClientAuthorization() bool {
return t.Source.UseClientAuthorization()
return t.UseClientOAuth
}
func constructSearchQueryHelper(predicate string, operator string, items []string) string {
@@ -195,7 +206,6 @@ func ExtractType(resourceString string) string {
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
s := t.Source
paramsMap := params.AsMap()
pageSize := int32(paramsMap["pageSize"].(int))
prompt, _ := paramsMap["prompt"].(string)
@@ -217,14 +227,14 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
req := &dataplexpb.SearchEntriesRequest{
Query: fmt.Sprintf("%s %s", prompt, constructSearchQuery(projectIds, datasetIds, types)),
Name: fmt.Sprintf("projects/%s/locations/global", s.BigQueryProject()),
Name: fmt.Sprintf("projects/%s/locations/global", t.ProjectID),
PageSize: pageSize,
SemanticSearch: true,
}
catalogClient, dataplexClientCreator, _ := s.MakeDataplexCatalogClient()()
catalogClient, dataplexClientCreator, _ := t.MakeCatalogClient()
if s.UseClientAuthorization() {
if t.UseClientOAuth {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
@@ -237,7 +247,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
it := catalogClient.SearchEntries(ctx, req)
if it == nil {
return nil, fmt.Errorf("failed to create search entries iterator for project %q", s.BigQueryProject())
return nil, fmt.Errorf("failed to create search entries iterator for project %q", t.ProjectID)
}
var results []Response

View File

@@ -28,6 +28,7 @@ import (
"github.com/googleapis/genai-toolbox/internal/tools"
bqutil "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerycommon"
bigqueryrestapi "google.golang.org/api/bigquery/v2"
"google.golang.org/api/iterator"
)
const kind string = "bigquery-sql"
@@ -51,7 +52,6 @@ type compatibleSource interface {
BigQueryRestService() *bigqueryrestapi.Service
BigQueryClientCreator() bigqueryds.BigqueryClientCreator
UseClientAuthorization() bool
RetrieveBQClient(tools.AccessToken) (*bigqueryapi.Client, *bigqueryrestapi.Service, error)
}
// validate compatible sources are still compatible
@@ -99,11 +99,20 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
// finish tool setup
t := Tool{
Config: cfg,
AllParams: allParameters,
Source: s,
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
Name: cfg.Name,
Kind: kind,
AuthRequired: cfg.AuthRequired,
Parameters: cfg.Parameters,
TemplateParameters: cfg.TemplateParameters,
AllParams: allParameters,
Statement: cfg.Statement,
UseClientOAuth: s.UseClientAuthorization(),
Client: s.BigQueryClient(),
RestService: s.BigQueryRestService(),
ClientCreator: s.BigQueryClientCreator(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
@@ -112,24 +121,32 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
var _ tools.Tool = Tool{}
type Tool struct {
Config
AllParams tools.Parameters `yaml:"allParams"`
Name string `yaml:"name"`
Kind string `yaml:"kind"`
AuthRequired []string `yaml:"authRequired"`
UseClientOAuth bool `yaml:"useClientOAuth"`
Parameters tools.Parameters `yaml:"parameters"`
TemplateParameters tools.Parameters `yaml:"templateParameters"`
AllParams tools.Parameters `yaml:"allParams"`
Source compatibleSource
Statement string
manifest tools.Manifest
mcpManifest tools.McpManifest
Statement string
Client *bigqueryapi.Client
RestService *bigqueryrestapi.Service
ClientCreator bigqueryds.BigqueryClientCreator
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
highLevelParams := make([]bigqueryapi.QueryParameter, 0, len(t.Parameters))
lowLevelParams := make([]*bigqueryrestapi.QueryParameter, 0, len(t.Parameters))
paramsMap := params.AsMap()
newStatement, err := tools.ResolveTemplateParams(t.TemplateParameters, t.Config.Statement, paramsMap)
newStatement, err := tools.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
if err != nil {
return nil, fmt.Errorf("unable to extract template params %w", err)
}
highLevelParams := make([]bigqueryapi.QueryParameter, 0, len(t.Parameters))
lowLevelParams := make([]*bigqueryrestapi.QueryParameter, 0, len(t.Parameters))
for _, p := range t.Parameters {
name := p.GetName()
value := paramsMap[name]
@@ -197,23 +214,71 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
lowLevelParams = append(lowLevelParams, lowLevelParam)
}
bqClient, restService, err := t.Source.RetrieveBQClient(accessToken)
if err != nil {
return nil, err
bqClient := t.Client
restService := t.RestService
// Initialize new client if using user OAuth token
if t.UseClientOAuth {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
bqClient, restService, err = t.ClientCreator(tokenStr, true)
if err != nil {
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
}
}
query := bqClient.Query(newStatement)
query.Parameters = highLevelParams
query.Location = bqClient.Location
_, err = bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, newStatement, lowLevelParams, query.ConnectionProperties)
dryRunJob, err := bqutil.DryRunQuery(ctx, restService, bqClient.Project(), bqClient.Location, newStatement, lowLevelParams, query.ConnectionProperties)
if err != nil {
// This is a fallback check in case the switch logic was bypassed.
return nil, fmt.Errorf("final query validation failed: %w", err)
}
statementType := dryRunJob.Statistics.Query.StatementType
return bqutil.RunQuery(ctx, newStatement, query)
// This block handles SELECT statements, which return a row set.
// We iterate through the results, convert each row into a map of
// column names to values, and return the collection of rows.
it, err := query.Read(ctx)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
var out []any
for {
var row map[string]bigqueryapi.Value
err = it.Next(&row)
if err == iterator.Done {
break
}
if err != nil {
return nil, fmt.Errorf("unable to iterate through query results: %w", err)
}
vMap := make(map[string]any)
for key, value := range row {
vMap[key] = value
}
out = append(out, vMap)
}
// If the query returned any rows, return them directly.
if len(out) > 0 {
return out, nil
}
// This handles the standard case for a SELECT query that successfully
// executes but returns zero rows.
if statementType == "SELECT" {
return "The query returned 0 rows.", nil
}
// This is the fallback for a successful query that doesn't return content.
// In most cases, this will be for DML/DDL statements like INSERT, UPDATE, CREATE, etc.
// However, it is also possible that this was a query that was expected to return rows
// but returned none, a case that we cannot distinguish here.
return "Query executed successfully and returned no content.", nil
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
@@ -233,7 +298,7 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
}
func (t Tool) RequiresClientAuthorization() bool {
return t.Source.UseClientAuthorization()
return t.UseClientOAuth
}
func BQTypeStringFromToolType(toolType string) (string, error) {

View File

@@ -123,6 +123,9 @@ func TestCloudSQLMSSQLToolEndpoints(t *testing.T) {
t.Fatalf("unable to create Cloud SQL connection pool: %s", err)
}
// cleanup test environment
tests.CleanupMSSQLTables(t, ctx, db)
// create table name with UUID
tableNameParam := "param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
tableNameAuth := "auth_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")

View File

@@ -110,6 +110,9 @@ func TestCloudSQLMySQLToolEndpoints(t *testing.T) {
t.Fatalf("unable to create Cloud SQL connection pool: %s", err)
}
// cleanup test environment
tests.CleanupMySQLTables(t, ctx, pool)
// create table name with UUID
tableNameParam := "param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
tableNameAuth := "auth_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")

View File

@@ -97,6 +97,9 @@ func TestMSSQLToolEndpoints(t *testing.T) {
t.Fatalf("unable to create SQL Server connection pool: %s", err)
}
// cleanup test environment
tests.CleanupMSSQLTables(t, ctx, pool)
// create table name with UUID
tableNameParam := "param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
tableNameAuth := "auth_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")

View File

@@ -87,6 +87,9 @@ func TestMySQLToolEndpoints(t *testing.T) {
t.Fatalf("unable to create MySQL connection pool: %s", err)
}
// cleanup test environment
tests.CleanupMySQLTables(t, ctx, pool)
// create table name with UUID
tableNameParam := "param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
tableNameAuth := "auth_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")

View File

@@ -137,6 +137,9 @@ func TestPostgres(t *testing.T) {
t.Fatalf("unable to create postgres connection pool: %s", err)
}
// cleanup test environment
tests.CleanupPostgresTables(t, ctx, pool);
// create table name with UUID
tableNameParam := "param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
tableNameAuth := "auth_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
@@ -237,7 +240,24 @@ func runPostgresListTablesTest(t *testing.T, tableNameParam, tableNameAuth strin
requestBody io.Reader
wantStatusCode int
want string
isAllTables bool
}{
{
name: "invoke list_tables all tables detailed output",
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
requestBody: bytes.NewBuffer([]byte(`{"table_names": ""}`)),
wantStatusCode: http.StatusOK,
want: fmt.Sprintf("[%s,%s]", getDetailedWant(tableNameAuth, authTableColumns), getDetailedWant(tableNameParam, paramTableColumns)),
isAllTables: true,
},
{
name: "invoke list_tables all tables simple output",
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
requestBody: bytes.NewBuffer([]byte(`{"table_names": "", "output_format": "simple"}`)),
wantStatusCode: http.StatusOK,
want: fmt.Sprintf("[%s,%s]", getSimpleWant(tableNameAuth), getSimpleWant(tableNameParam)),
isAllTables: true,
},
{
name: "invoke list_tables detailed output",
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
@@ -334,6 +354,19 @@ func runPostgresListTablesTest(t *testing.T, tableNameParam, tableNameAuth strin
t.Fatalf("failed to unmarshal expected want string: %v", err)
}
// Checking only the default public schema where the test tables are created to avoid brittle tests.
if tc.isAllTables {
var filteredGot []any
for _, item := range got {
if tableMap, ok := item.(map[string]interface{}); ok {
if schema, ok := tableMap["schema_name"]; ok && schema == "public" {
filteredGot = append(filteredGot, item)
}
}
}
got = filteredGot
}
sort.SliceStable(got, func(i, j int) bool {
return fmt.Sprintf("%v", got[i]) < fmt.Sprintf("%v", got[j])
})

View File

@@ -1184,7 +1184,15 @@ func RunMySQLListTablesTest(t *testing.T, databaseName, tableNameParam, tableNam
wantStatusCode int
want any
isSimple bool
isAllTables bool
}{
{
name: "invoke list_tables for all tables detailed output",
requestBody: bytes.NewBufferString(`{"table_names":""}`),
wantStatusCode: http.StatusOK,
want: []objectDetails{authTableWant, paramTableWant},
isAllTables: true,
},
{
name: "invoke list_tables detailed output",
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"table_names": "%s"}`, tableNameAuth)),
@@ -1293,6 +1301,23 @@ func RunMySQLListTablesTest(t *testing.T, databaseName, tableNameParam, tableNam
cmpopts.SortSlices(func(a, b map[string]any) bool { return a["name"].(string) < b["name"].(string) }),
}
// Checking only the current database where the test tables are created to avoid brittle tests.
if tc.isAllTables {
var filteredGot []objectDetails
if got != nil {
for _, item := range got.([]objectDetails) {
if item.SchemaName == databaseName {
filteredGot = append(filteredGot, item)
}
}
}
if len(filteredGot) == 0 {
got = nil
} else {
got = filteredGot
}
}
if diff := cmp.Diff(tc.want, got, opts...); diff != "" {
t.Errorf("Unexpected result: got %#v, want: %#v", got, tc.want)
}
@@ -1858,7 +1883,24 @@ func RunMSSQLListTablesTest(t *testing.T, tableNameParam, tableNameAuth string)
requestBody string
wantStatusCode int
want string
isAllTables bool
}{
{
name: "invoke list_tables for all tables detailed output",
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
requestBody: `{"table_names": ""}`,
wantStatusCode: http.StatusOK,
want: fmt.Sprintf("[%s,%s]", getDetailedWant(tableNameAuth, authTableColumns), getDetailedWant(tableNameParam, paramTableColumns)),
isAllTables: true,
},
{
name: "invoke list_tables for all tables simple output",
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
requestBody: `{"table_names": "", "output_format": "simple"}`,
wantStatusCode: http.StatusOK,
want: fmt.Sprintf("[%s,%s]", getSimpleWant(tableNameAuth), getSimpleWant(tableNameParam)),
isAllTables: true,
},
{
name: "invoke list_tables detailed output",
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
@@ -1968,6 +2010,19 @@ func RunMSSQLListTablesTest(t *testing.T, tableNameParam, tableNameAuth string)
itemMap["object_details"] = detailsMap
}
// Checking only the default dbo schema where the test tables are created to avoid brittle tests.
if tc.isAllTables {
var filteredGot []any
for _, item := range got {
if tableMap, ok := item.(map[string]interface{}); ok {
if schema, ok := tableMap["schema_name"]; ok && schema == "dbo" {
filteredGot = append(filteredGot, item)
}
}
}
got = filteredGot
}
sort.SliceStable(got, func(i, j int) bool {
return fmt.Sprintf("%v", got[i]) < fmt.Sprintf("%v", got[j])
})