mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-12 08:58:28 -05:00
Compare commits
15 Commits
cloud-sql-
...
sample
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b7e4c12915 | ||
|
|
5530b08f34 | ||
|
|
e515d9254f | ||
|
|
d661f5343f | ||
|
|
6c8460b0e5 | ||
|
|
71f360d315 | ||
|
|
33beb7187d | ||
|
|
4f46782927 | ||
|
|
bf6831fdbe | ||
|
|
bd195d2fe2 | ||
|
|
ec0d3a6eb3 | ||
|
|
81d239b053 | ||
|
|
0cd3f16f87 | ||
|
|
aa3972470f | ||
|
|
36d79ef147 |
@@ -531,24 +531,6 @@ steps:
|
||||
utility \
|
||||
utility/alloydbwaitforoperation
|
||||
|
||||
- id: "cloud-sql"
|
||||
name: golang:1
|
||||
waitFor: ["compile-test-binary"]
|
||||
entrypoint: /bin/bash
|
||||
env:
|
||||
- "GOPATH=/gopath"
|
||||
secretEnv: ["CLIENT_ID"]
|
||||
volumes:
|
||||
- name: "go"
|
||||
path: "/gopath"
|
||||
args:
|
||||
- -c
|
||||
- |
|
||||
.ci/test_with_coverage.sh \
|
||||
"Cloud SQL Wait for Operation" \
|
||||
cloudsql \
|
||||
cloudsql
|
||||
|
||||
- id: "tidb"
|
||||
name: golang:1
|
||||
waitFor: ["compile-test-binary"]
|
||||
|
||||
4
.github/labels.yaml
vendored
4
.github/labels.yaml
vendored
@@ -88,6 +88,10 @@
|
||||
color: 8befd7
|
||||
description: 'Status: reviewer is awaiting feedback or responses from the author before proceeding.'
|
||||
|
||||
- name: 'release candidate'
|
||||
color: 32CD32
|
||||
description: 'Use label to signal PR should be included in the next release.'
|
||||
|
||||
# Product Labels
|
||||
- name: 'product: bigquery'
|
||||
color: 5065c7
|
||||
|
||||
@@ -89,6 +89,9 @@ implementation](https://github.com/googleapis/genai-toolbox/blob/main/internal/s
|
||||
|
||||
### Adding a New Tool
|
||||
|
||||
> [!NOTE]
|
||||
> Please follow the tool naming convention detailed [here](./DEVELOPER.md#tool-naming-conventions).
|
||||
|
||||
We recommend looking at an [example tool
|
||||
implementation](https://github.com/googleapis/genai-toolbox/tree/main/internal/tools/postgres/postgressql).
|
||||
|
||||
|
||||
41
DEVELOPER.md
41
DEVELOPER.md
@@ -44,6 +44,47 @@ Before you begin, ensure you have the following:
|
||||
curl http://127.0.0.1:5000
|
||||
```
|
||||
|
||||
### Tool Naming Conventions
|
||||
|
||||
This section details the purpose and conventions for MCP Toolbox's tools naming
|
||||
properties, **tool name** and **tool kind**.
|
||||
|
||||
```
|
||||
cancel_hotel: <- tool name
|
||||
kind: postgres-sql <- tool kind
|
||||
source: my_pg_source
|
||||
```
|
||||
|
||||
#### Tool Name
|
||||
|
||||
Tool name is the identifier used by a Large Language Model (LLM) to invoke a
|
||||
specific tool.
|
||||
* Custom tools: The user can define any name they want. The below guidelines
|
||||
do not apply.
|
||||
* Pre-built tools: The tool name is predefined and cannot be changed. It
|
||||
should follow the guidelines.
|
||||
|
||||
The following guidelines apply to tool names:
|
||||
* Should use underscores over hyphens (e.g., `list_collections` instead of
|
||||
`list-collections`).
|
||||
* Should not have the product name in the name (e.g., `list_collections` instead
|
||||
of `firestore_list_collections`).
|
||||
* Superficial changes are NOT considered as breaking (e.g., changing tool name).
|
||||
* Non-superficial changes MAY be considered breaking (e.g. adding new parameters
|
||||
to a function) until they can be validated through extensive testing to ensure
|
||||
they do not negatively impact agent's performances.
|
||||
|
||||
#### Tool Kind
|
||||
|
||||
Tool kind serves as a category or type that a user can assign to a tool.
|
||||
|
||||
The following guidelines apply to tool kinds:
|
||||
* Should user hyphens over underscores (e.g. `firestore-list-collections` or
|
||||
`firestore_list_colelctions`).
|
||||
* Should use product name in name (e.g. `firestore-list-collections` over
|
||||
`list-collections`).
|
||||
* Changes to tool kind are breaking changes and should be avoided.
|
||||
|
||||
## Testing
|
||||
|
||||
### Infrastructure
|
||||
|
||||
@@ -43,6 +43,7 @@ import (
|
||||
|
||||
// Import tool packages for side effect of registration
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/alloydbainl"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigqueryanalyzecontribution"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigqueryconversationalanalytics"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigqueryexecutesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigqueryforecast"
|
||||
@@ -53,8 +54,9 @@ import (
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerysql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/bigtable"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouseexecutesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouselistdatabases"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhousesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlwaitforoperation"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudmonitoring"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/couchbase"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/dataplex/dataplexlookupentry"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/dataplex/dataplexsearchaspecttypes"
|
||||
@@ -99,6 +101,7 @@ import (
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mssql/mssqlexecutesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mssql/mssqlsql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlexecutesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqllisttables"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlsql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jcypher"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jexecutecypher"
|
||||
@@ -106,6 +109,7 @@ import (
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/oceanbase/oceanbaseexecutesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/oceanbase/oceanbasesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgresexecutesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslisttables"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgressql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/redis"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/spanner/spannerexecutesql"
|
||||
@@ -122,10 +126,13 @@ import (
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/bigtable"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/clickhouse"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/cloudmonitoring"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmssql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmysql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg"
|
||||
|
||||
@@ -1359,7 +1359,7 @@ func TestPrebuiltTools(t *testing.T) {
|
||||
wantToolset: server.ToolsetConfigs{
|
||||
"bigquery-database-tools": tools.ToolsetConfig{
|
||||
Name: "bigquery-database-tools",
|
||||
ToolNames: []string{"ask_data_insights", "execute_sql", "forecast", "get_dataset_info", "get_table_info", "list_dataset_ids", "list_table_ids"},
|
||||
ToolNames: []string{"analyze_contribution", "ask_data_insights", "execute_sql", "forecast", "get_dataset_info", "get_table_info", "list_dataset_ids", "list_table_ids"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -1369,7 +1369,7 @@ func TestPrebuiltTools(t *testing.T) {
|
||||
wantToolset: server.ToolsetConfigs{
|
||||
"clickhouse-database-tools": tools.ToolsetConfig{
|
||||
Name: "clickhouse-database-tools",
|
||||
ToolNames: []string{"execute_sql"},
|
||||
ToolNames: []string{"execute_sql", "list_databases"},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -4,7 +4,7 @@ go 1.24.6
|
||||
|
||||
require (
|
||||
github.com/googleapis/mcp-toolbox-sdk-go v0.3.0
|
||||
google.golang.org/genai v1.21.0
|
||||
google.golang.org/genai v1.23.0
|
||||
)
|
||||
|
||||
require (
|
||||
|
||||
@@ -75,6 +75,8 @@ See guides, [Connect from your IDE](../how-to/connect-ide/_index.md), for detail
|
||||
* `CLOUD_SQL_MYSQL_DATABASE`: The name of the database to connect to.
|
||||
* `CLOUD_SQL_MYSQL_USER`: The database username.
|
||||
* `CLOUD_SQL_MYSQL_PASSWORD`: The password for the database user.
|
||||
* `CLOUD_SQL_MYSQL_IP_TYPE`: The IP type i.e. "Public
|
||||
or "Private" (Default: Public).
|
||||
* **Permissions:**
|
||||
* **Cloud SQL Client** (`roles/cloudsql.client`) to connect to the instance.
|
||||
* Database-level permissions (e.g., `SELECT`, `INSERT`) are required to execute queries.
|
||||
|
||||
36
docs/en/resources/sources/alloydb-admin.md
Normal file
36
docs/en/resources/sources/alloydb-admin.md
Normal file
@@ -0,0 +1,36 @@
|
||||
---
|
||||
title: "AlloyDB Admin"
|
||||
linkTitle: "AlloyDB Admin"
|
||||
type: docs
|
||||
weight: 2
|
||||
description: >
|
||||
The "alloydb-admin" source provides a client for the AlloyDB API.
|
||||
aliases:
|
||||
- /resources/sources/alloydb-admin
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
The `alloydb-admin` source provides a client to interact with the [Google AlloyDB API](https://cloud.google.com/alloydb/docs/reference/rest). This allows tools to perform administrative tasks on AlloyDB resources, such as managing clusters, instances, and users.
|
||||
|
||||
Authentication can be handled in two ways:
|
||||
1. **Application Default Credentials (ADC):** By default, the source uses ADC to authenticate with the API.
|
||||
2. **Client-side OAuth:** If `useClientOAuth` is set to `true`, the source will expect an OAuth 2.0 access token to be provided by the client (e.g., a web browser) for each request.
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
sources:
|
||||
my-alloydb-admin:
|
||||
kind: alloy-admin
|
||||
|
||||
my-oauth-alloydb-admin:
|
||||
kind: alloydb-admin
|
||||
useClientOAuth: true
|
||||
```
|
||||
|
||||
## Reference
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|----------------|:--------:|:------------:|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| kind | string | true | Must be "alloydb-admin". |
|
||||
| useClientOAuth | boolean | false | If true, the source will use client-side OAuth for authorization. Otherwise, it will use Application Default Credentials. Defaults to `false`. |
|
||||
@@ -33,6 +33,9 @@ cluster][alloydb-free-trial].
|
||||
- [`postgres-execute-sql`](../tools/postgres/postgres-execute-sql.md)
|
||||
Run parameterized SQL statements in AlloyDB Postgres.
|
||||
|
||||
- [`postgres-list-tables`](../tools/postgres/postgres-list-tables.md)
|
||||
List tables in an AlloyDB for PostgreSQL database.
|
||||
|
||||
### Pre-built Configurations
|
||||
|
||||
- [AlloyDB using MCP](https://googleapis.github.io/genai-toolbox/how-to/connect-ide/alloydb_pg_mcp/)
|
||||
|
||||
@@ -82,26 +82,7 @@ intend to run. Common roles include `roles/bigquery.user` (which includes
|
||||
permissions to run jobs and read data) or `roles/bigbigquery.dataViewer`.
|
||||
Follow this [guide][set-adc] to set up your ADC.
|
||||
|
||||
### Authentication via User's OAuth Access Token
|
||||
|
||||
If the `useClientOAuth` parameter is set to `true`, Toolbox will instead use the
|
||||
OAuth access token for authentication. This token is parsed from the
|
||||
`Authorization` header passed in with the tool invocation request. This method
|
||||
allows Toolbox to make queries to [BigQuery][bigquery-docs] on behalf of the
|
||||
client or the end-user.
|
||||
|
||||
When using this on-behalf-of authentication, you must ensure that the
|
||||
identity used has been granted the correct IAM permissions. Currently,
|
||||
this option is only supported by the following BigQuery tools:
|
||||
|
||||
- [`bigquery-sql`](../tools/bigquery/bigquery-sql.md)
|
||||
Run SQL queries directly against BigQuery datasets.
|
||||
|
||||
[iam-overview]: https://cloud.google.com/bigquery/docs/access-control
|
||||
[adc]: https://cloud.google.com/docs/authentication#adc
|
||||
[set-adc]: https://cloud.google.com/docs/authentication/provide-credentials-adc
|
||||
|
||||
## Example
|
||||
#### Example (ADC)
|
||||
|
||||
Initialize a BigQuery source that uses ADC:
|
||||
|
||||
@@ -111,8 +92,28 @@ sources:
|
||||
kind: "bigquery"
|
||||
project: "my-project-id"
|
||||
# location: "US" # Optional: Specifies the location for query jobs.
|
||||
# allowedDatasets: # Optional: Restricts tool access to a specific list of datasets.
|
||||
# - "my_dataset_1"
|
||||
# - "other_project.my_dataset_2"
|
||||
```
|
||||
|
||||
### Authentication via User's OAuth Access Token
|
||||
|
||||
If the `useClientOAuth` parameter is set to `true`, Toolbox will instead use the
|
||||
OAuth access token for authentication. This token is parsed from the
|
||||
`Authorization` header passed in with the tool invocation request. This method
|
||||
allows Toolbox to make queries to [BigQuery][bigquery-docs] on behalf of the
|
||||
client or the end-user.
|
||||
|
||||
When using this on-behalf-of authentication, you must ensure that the
|
||||
identity used has been granted the correct IAM permissions.
|
||||
|
||||
[iam-overview]: <https://cloud.google.com/bigquery/docs/access-control>
|
||||
[adc]: <https://cloud.google.com/docs/authentication#adc>
|
||||
[set-adc]: <https://cloud.google.com/docs/authentication/provide-credentials-adc>
|
||||
|
||||
#### Example (Client OAuth)
|
||||
|
||||
Initialize a BigQuery source that uses the client's access token:
|
||||
|
||||
```yaml
|
||||
@@ -122,8 +123,13 @@ sources:
|
||||
project: "my-project-id"
|
||||
useClientOAuth: true
|
||||
# location: "US" # Optional: Specifies the location for query jobs.
|
||||
# allowedDatasets: # Optional: Restricts tool access to a specific list of datasets.
|
||||
# - "my_dataset_1"
|
||||
# - "other_project.my_dataset_2"
|
||||
```
|
||||
|
||||
To connect to Gemini CLI using the client OAuth feature, you can follow this step-by-step [guide](../../samples/bigquery/bigquery_gemini_cli_client_oauth/_index.md).
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
@@ -131,4 +137,5 @@ sources:
|
||||
| kind | string | true | Must be "bigquery". |
|
||||
| project | string | true | Id of the Google Cloud project to use for billing and as the default project for BigQuery resources. |
|
||||
| location | string | false | Specifies the location (e.g., 'us', 'asia-northeast1') in which to run the query job. This location must match the location of any tables referenced in the query. Defaults to the table's location or 'US' if the location cannot be determined. [Learn More](https://cloud.google.com/bigquery/docs/locations) |
|
||||
| allowedDatasets | []string | false | An optional list of dataset IDs that tools using this source are allowed to access. If provided, any tool operation attempting to access a dataset not in this list will be rejected. To enforce this, two types of operations are also disallowed: 1) Dataset-level operations (e.g., `CREATE SCHEMA`), and 2) operations where table access cannot be statically analyzed (e.g., `EXECUTE IMMEDIATE`, `CREATE PROCEDURE`). If a single dataset is provided, it will be treated as the default for prebuilt tools. |
|
||||
| useClientOAuth | bool | false | If true, forwards the client's OAuth access token from the "Authorization" header to downstream queries. |
|
||||
|
||||
36
docs/en/resources/sources/cloud-monitoring.md
Normal file
36
docs/en/resources/sources/cloud-monitoring.md
Normal file
@@ -0,0 +1,36 @@
|
||||
---
|
||||
title: "Cloud Monitoring"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
A "cloud-monitoring" source provides a client for the Cloud Monitoring API.
|
||||
aliases:
|
||||
- /resources/sources/cloud-monitoring
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
The `cloud-monitoring` source provides a client to interact with the [Google Cloud Monitoring API](https://cloud.google.com/monitoring/api). This allows tools to access cloud monitoring metrics explorer and run promql queries.
|
||||
|
||||
Authentication can be handled in two ways:
|
||||
1. **Application Default Credentials (ADC):** By default, the source uses ADC to authenticate with the API.
|
||||
2. **Client-side OAuth:** If `useClientOAuth` is set to `true`, the source will expect an OAuth 2.0 access token to be provided by the client (e.g., a web browser) for each request.
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
sources:
|
||||
my-cloud-monitoring:
|
||||
kind: cloud-monitoring
|
||||
|
||||
my-oauth-cloud-monitoring:
|
||||
kind: cloud-monitoring
|
||||
useClientOAuth: true
|
||||
```
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|----------------|:--------:|:------------:|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| kind | string | true | Must be "cloud-monitoring". |
|
||||
| useClientOAuth | boolean | false | If true, the source will use client-side OAuth for authorization. Otherwise, it will use Application Default Credentials. Defaults to `false`. |
|
||||
36
docs/en/resources/sources/cloud-sql-admin.md
Normal file
36
docs/en/resources/sources/cloud-sql-admin.md
Normal file
@@ -0,0 +1,36 @@
|
||||
---
|
||||
title: "Cloud SQL Admin"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
A "cloud-sql-admin" source provides a client for the Cloud SQL Admin API.
|
||||
aliases:
|
||||
- /resources/sources/cloud-sql-admin
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
The `cloud-sql-admin` source provides a client to interact with the [Google Cloud SQL Admin API](https://cloud.google.com/sql/docs/mysql/admin-api/v1). This allows tools to perform administrative tasks on Cloud SQL instances, such as creating users and databases.
|
||||
|
||||
Authentication can be handled in two ways:
|
||||
1. **Application Default Credentials (ADC):** By default, the source uses ADC to authenticate with the API.
|
||||
2. **Client-side OAuth:** If `useClientOAuth` is set to `true`, the source will expect an OAuth 2.0 access token to be provided by the client (e.g., a web browser) for each request.
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
sources:
|
||||
my-cloud-sql-admin:
|
||||
kind: cloud-sql-admin
|
||||
|
||||
my-oauth-cloud-sql-admin:
|
||||
kind: cloud-sql-admin
|
||||
useClientOAuth: true
|
||||
```
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|----------------|:--------:|:------------:|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| kind | string | true | Must be "cloud-sql-admin". |
|
||||
| useClientOAuth | boolean | false | If true, the source will use client-side OAuth for authorization. Otherwise, it will use Application Default Credentials. Defaults to `false`. |
|
||||
@@ -28,6 +28,9 @@ to a database by following these instructions][csql-mysql-quickstart].
|
||||
- [`mysql-execute-sql`](../tools/mysql/mysql-execute-sql.md)
|
||||
Run parameterized SQL queries in Cloud SQL for MySQL.
|
||||
|
||||
- [`mysql-list-tables`](../tools/mysql/mysql-list-tables.md)
|
||||
List tables in a Cloud SQL for MySQL database.
|
||||
|
||||
### Pre-built Configurations
|
||||
|
||||
- [Cloud SQL for MySQL using MCP](https://googleapis.github.io/genai-toolbox/how-to/connect-ide/cloud_sql_mysql_mcp/)
|
||||
|
||||
@@ -28,6 +28,9 @@ to a database by following these instructions][csql-pg-quickstart].
|
||||
- [`postgres-execute-sql`](../tools/postgres/postgres-execute-sql.md)
|
||||
Run parameterized SQL statements in PostgreSQL.
|
||||
|
||||
- [`postgres-list-tables`](../tools/postgres/postgres-list-tables.md)
|
||||
List tables in a PostgreSQL database.
|
||||
|
||||
### Pre-built Configurations
|
||||
|
||||
- [Cloud SQL for Postgres using MCP](https://googleapis.github.io/genai-toolbox/how-to/connect-ide/cloud_sql_pg_mcp/)
|
||||
|
||||
@@ -22,6 +22,9 @@ reliability, performance, and ease of use.
|
||||
- [`mysql-execute-sql`](../tools/mysql/mysql-execute-sql.md)
|
||||
Run parameterized SQL queries in MySQL.
|
||||
|
||||
- [`mysql-list-tables`](../tools/mysql/mysql-list-tables.md)
|
||||
List tables in a MySQL database.
|
||||
|
||||
## Requirements
|
||||
|
||||
### Database User
|
||||
|
||||
@@ -23,6 +23,9 @@ reputation for reliability, feature robustness, and performance.
|
||||
- [`postgres-execute-sql`](../tools/postgres/postgres-execute-sql.md)
|
||||
Run parameterized SQL statements in PostgreSQL.
|
||||
|
||||
- [`postgres-list-tables`](../tools/postgres/postgres-list-tables.md)
|
||||
List tables in a PostgreSQL database.
|
||||
|
||||
### Pre-built Configurations
|
||||
|
||||
- [PostgreSQL using MCP](https://googleapis.github.io/genai-toolbox/how-to/connect-ide/postgres_mcp/)
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
---
|
||||
title: "bigquery-analyze-contribution"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
A "bigquery-analyze-contribution" tool performs contribution analysis in BigQuery.
|
||||
aliases:
|
||||
- /resources/tools/bigquery-analyze-contribution
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
A `bigquery-analyze-contribution` tool performs contribution analysis in BigQuery by creating a temporary `CONTRIBUTION_ANALYSIS` model and then querying it with `ML.GET_INSIGHTS` to find top contributors for a given metric.
|
||||
|
||||
It's compatible with the following sources:
|
||||
|
||||
- [bigquery](../../sources/bigquery.md)
|
||||
|
||||
`bigquery-analyze-contribution` takes the following parameters:
|
||||
|
||||
- **input_data** (string, required): The data that contain the test and control data to analyze. This can be a fully qualified BigQuery table ID (e.g., `my-project.my_dataset.my_table`) or a SQL query that returns the data.
|
||||
- **contribution_metric** (string, required): The name of the column that contains the metric to analyze. This can be SUM(metric_column_name), SUM(numerator_metric_column_name)/SUM(denominator_metric_column_name) or SUM(metric_sum_column_name)/COUNT(DISTINCT categorical_column_name) depending the type of metric to analyze.
|
||||
- **is_test_col** (string, required): The name of the column that identifies whether a row is in the test or control group. The column must contain boolean values.
|
||||
- **dimension_id_cols** (array of strings, optional): An array of column names that uniquely identify each dimension.
|
||||
- **top_k_insights_by_apriori_support** (integer, optional): The number of top insights to return, ranked by apriori support. Default to '30'.
|
||||
- **pruning_method** (string, optional): The method to use for pruning redundant insights. Can be `'NO_PRUNING'` or `'PRUNE_REDUNDANT_INSIGHTS'`. Defaults to `'PRUNE_REDUNDANT_INSIGHTS'`.
|
||||
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
contribution_analyzer:
|
||||
kind: bigquery-analyze-contribution
|
||||
source: my-bigquery-source
|
||||
description: Use this tool to run contribution analysis on a dataset in BigQuery.
|
||||
```
|
||||
|
||||
## Sample Prompt
|
||||
You can prepare a sample table following https://cloud.google.com/bigquery/docs/get-contribution-analysis-insights.
|
||||
And use the following sample prompts to call this tool:
|
||||
|
||||
- What drives the changes in sales in the table `bqml_tutorial.iowa_liquor_sales_sum_data`? Use the project id myproject.
|
||||
- Analyze the contribution for the `total_sales` metric in the table `bqml_tutorial.iowa_liquor_sales_sum_data`. The test group is identified by the `is_test` column. The dimensions are `store_name`, `city`, `vendor_name`, `category_name` and `item_description`.
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|-------------|:--------:|:------------:|------------------------------------------------------------|
|
||||
| kind | string | true | Must be "bigquery-analyze-contribution". |
|
||||
| source | string | true | Name of the source the tool should execute on. |
|
||||
| description | string | true | Description of the tool that is passed to the LLM. |
|
||||
@@ -15,10 +15,19 @@ It's compatible with the following sources:
|
||||
|
||||
- [bigquery](../../sources/bigquery.md)
|
||||
|
||||
`bigquery-list-table-ids` takes a required `dataset` parameter to specify the dataset
|
||||
from which to list table IDs. It also optionally accepts a `project` parameter to
|
||||
define the Google Cloud project ID. If the `project` parameter is not provided, the
|
||||
tool defaults to using the project defined in the source configuration.
|
||||
`bigquery-list-table-ids` accepts the following parameters:
|
||||
- **`dataset`** (required): Specifies the dataset from which to list table IDs.
|
||||
- **`project`** (optional): Defines the Google Cloud project ID. If not provided,
|
||||
the tool defaults to the project from the source configuration.
|
||||
|
||||
The tool's behavior regarding these parameters is influenced by the
|
||||
`allowedDatasets` restriction on the `bigquery` source:
|
||||
- **Without `allowedDatasets` restriction:** The tool can list tables from any
|
||||
dataset specified by the `dataset` and `project` parameters.
|
||||
- **With `allowedDatasets` restriction:** Before listing tables, the tool verifies
|
||||
that the requested dataset is in the allowed list. If it is not, the request is
|
||||
denied. If only one dataset is specified in the `allowedDatasets` list, it
|
||||
will be used as the default value for the `dataset` parameter.
|
||||
|
||||
## Example
|
||||
|
||||
|
||||
@@ -0,0 +1,53 @@
|
||||
---
|
||||
title: "clickhouse-list-databases"
|
||||
type: docs
|
||||
weight: 3
|
||||
description: >
|
||||
A "clickhouse-list-databases" tool lists all databases in a ClickHouse instance.
|
||||
aliases:
|
||||
- /resources/tools/clickhouse-list-databases
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
A `clickhouse-list-databases` tool lists all available databases in a
|
||||
ClickHouse instance. It's compatible with the [clickhouse](../../sources/clickhouse.md) source.
|
||||
|
||||
This tool executes the `SHOW DATABASES` command and returns a list of all
|
||||
databases accessible to the configured user, making it useful for database
|
||||
discovery and exploration tasks.
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
list_clickhouse_databases:
|
||||
kind: clickhouse-list-databases
|
||||
source: my-clickhouse-instance
|
||||
description: List all available databases in the ClickHouse instance
|
||||
```
|
||||
|
||||
## Return Value
|
||||
|
||||
The tool returns an array of objects, where each object contains:
|
||||
- `name`: The name of the database
|
||||
|
||||
Example response:
|
||||
```json
|
||||
[
|
||||
{"name": "default"},
|
||||
{"name": "system"},
|
||||
{"name": "analytics"},
|
||||
{"name": "user_data"}
|
||||
]
|
||||
```
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|--------------------|:------------------:|:------------:|-----------------------------------------------------------|
|
||||
| kind | string | true | Must be "clickhouse-list-databases". |
|
||||
| source | string | true | Name of the ClickHouse source to list databases from. |
|
||||
| description | string | true | Description of the tool that is passed to the LLM. |
|
||||
| authRequired | array of string | false | Authentication services required to use this tool. |
|
||||
| parameters | array of Parameter | false | Parameters for the tool (typically not used). |
|
||||
@@ -1,7 +0,0 @@
|
||||
---
|
||||
title: "Cloud SQL"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
Tools that work with Cloud SQL Control Plane.
|
||||
---
|
||||
@@ -1,43 +0,0 @@
|
||||
---
|
||||
title: "cloud-sql-wait-for-operation"
|
||||
type: docs
|
||||
weight: 10
|
||||
description: >
|
||||
Wait for a long-running Cloud SQL operation to complete.
|
||||
---
|
||||
|
||||
The `cloud-sql-wait-for-operation` tool is a utility tool that waits for a
|
||||
long-running Cloud SQL operation to complete. It does this by polling the Cloud
|
||||
SQL Admin API operation status endpoint until the operation is finished, using
|
||||
exponential backoff.
|
||||
|
||||
{{< notice info >}}
|
||||
This tool is intended for developer assistant workflows with human-in-the-loop
|
||||
and shouldn't be used for production agents.
|
||||
{{< /notice >}}
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
cloudsql-operations-get:
|
||||
kind: cloud-sql-wait-for-operation
|
||||
source: some-http-source
|
||||
description: "This will poll on operations API until the operation is done. For checking operation status we need projectId and operationId. Once instance is created give follow up steps on how to use the variables to bring data plane MCP server up in local and remote setup."
|
||||
delay: 1s
|
||||
maxDelay: 4m
|
||||
multiplier: 2
|
||||
maxRetries: 10
|
||||
```
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
| ----------- | :------: | :----------: | ---------------------------------------------------------------------------------------------------------------- |
|
||||
| kind | string | true | Must be "cloud-sql-wait-for-operation". |
|
||||
| source | string | true | The name of an `http` source to use for authentication. |
|
||||
| description | string | true | A description of the tool. |
|
||||
| delay | duration | false | The initial delay between polling requests (e.g., `3s`). Defaults to 3 seconds. |
|
||||
| maxDelay | duration | false | The maximum delay between polling requests (e.g., `4m`). Defaults to 4 minutes. |
|
||||
| multiplier | float | false | The multiplier for the polling delay. The delay is multiplied by this value after each request. Defaults to 2.0. |
|
||||
| maxRetries | int | false | The maximum number of polling attempts before giving up. Defaults to 10. |
|
||||
43
docs/en/resources/tools/mysql/mysql-list-tables.md
Normal file
43
docs/en/resources/tools/mysql/mysql-list-tables.md
Normal file
@@ -0,0 +1,43 @@
|
||||
---
|
||||
title: "mysql-list-tables"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
The "mysql-list-tables" tool lists schema information for all or specified tables in a MySQL database.
|
||||
aliases:
|
||||
- /resources/tools/mysql-list-tables
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
The `mysql-list-tables` tool retrieves schema information for all or specified tables in a MySQL database. It is compatible with any of the following sources:
|
||||
|
||||
- [cloud-sql-mysql](../../sources/cloud-sql-mysql.md)
|
||||
- [mysql](../../sources/mysql.md)
|
||||
|
||||
`mysql-list-tables` lists detailed schema information (object type, columns, constraints, indexes, triggers, owner, comment) as JSON for user-created tables (ordinary or partitioned). Filters by a comma-separated list of names. If names are omitted, it lists all tables in user schemas. The output format can be set to `simple` which will return only the table names or `detailed` which is the default.
|
||||
|
||||
The tool takes the following input parameters:
|
||||
|
||||
| Parameter | Type | Description | Required |
|
||||
| :--------- | :----- | :--------------------------------------------------------------------------------------- | :------- |
|
||||
| `table_names` | string | Filters by a comma-separated list of names. By default, it lists all tables in user schemas. Default: `""` | No |
|
||||
| `output_format` | string | Indicate the output format of table schema. `simple` will return only the table names, `detailed` will return the full table information. Default: `detailed`. | No |
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
mysql_list_tables:
|
||||
kind: mysql-list-tables
|
||||
source: mysql-source
|
||||
description: Use this tool to retrieve schema information for all or specified tables. Output format can be simple (only table names) or detailed.
|
||||
```
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|-------------|:------------------------------------------:|:------------:|--------------------------------------------------------------------------------------------------|
|
||||
| kind | string | true | Must be "mysql-list-tables". |
|
||||
| source | string | true | Name of the source the SQL should execute on. |
|
||||
| description | string | true | Description of the tool that is passed to the agent. |
|
||||
39
docs/en/resources/tools/postgres/postgres-list-tables.md
Normal file
39
docs/en/resources/tools/postgres/postgres-list-tables.md
Normal file
@@ -0,0 +1,39 @@
|
||||
---
|
||||
title: "postgres-list-tables"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
The "postgres-list-tables" tool lists schema information for all or specified tables in a Postgres database.
|
||||
aliases:
|
||||
- /resources/tools/postgres-list-tables
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
The `postgres-list-tables` tool retrieves schema information for all or specified tables in a Postgres database. It's compatible with any of the following sources:
|
||||
|
||||
- [alloydb-postgres](../../sources/alloydb-pg.md)
|
||||
- [cloud-sql-postgres](../../sources/cloud-sql-pg.md)
|
||||
- [postgres](../../sources/postgres.md)
|
||||
|
||||
`postgres-list-tables` lists detailed schema information (object type, columns, constraints, indexes, triggers, owner, comment) as JSON for user-created tables (ordinary or partitioned). The tool takes the following input parameters:
|
||||
* `table_names` (optional): Filters by a comma-separated list of names. By default, it lists all tables in user schemas.
|
||||
* `output_format` (optional): Indicate the output format of table schema. `simple` will return only the table names, `detailed` will return the full table information. Default: `detailed`.
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
postgres_list_tables:
|
||||
kind: postgres-list-tables
|
||||
source: postgres-source
|
||||
description: Use this tool to retrieve schema information for all or specified tables. Output format can be simple (only table names) or detailed.
|
||||
```
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|-------------|:------------------------------------------:|:------------:|--------------------------------------------------------------------------------------------------|
|
||||
| kind | string | true | Must be "postgres-list-tables". |
|
||||
| source | string | true | Name of the source the SQL should execute on. |
|
||||
| description | string | true | Description of the tool that is passed to the agent. |
|
||||
@@ -0,0 +1,130 @@
|
||||
---
|
||||
title: "Access BigQuery from Gemini-CLI with End-User Credentials"
|
||||
type: docs
|
||||
weight: 2
|
||||
description: >
|
||||
How to connect to BigQuery from Gemini-CLI with end-user credentials
|
||||
---
|
||||
|
||||
## Overview
|
||||
|
||||
Gemini-CLI can be configured to get an OAuth access token from the Google OAuth endpoint, then send this token to MCP Toolbox as part of the request. MCP Toolbox can then use this token to authentincate with BigQuery. This enables each user to access Toolbox with thier own IAM identity for a Toolbox multi-tenancy use case.
|
||||
|
||||
{{< notice note >}}
|
||||
This feature requires Toolbox v0.14.0 or later.
|
||||
{{< /notice >}}
|
||||
|
||||
## Step 1: Register the OAuth on GCP
|
||||
|
||||
You first need to register the OAuth application following this [guide](register-oauth) to get a client ID and client secret.
|
||||
|
||||
## Step 2: Install and configure Toolbox
|
||||
|
||||
In this section, we will download Toolbox and run the Toolbox server.
|
||||
|
||||
1. Download the latest version of Toolbox as a binary:
|
||||
|
||||
{{< notice tip >}}
|
||||
Select the
|
||||
[correct binary](https://github.com/googleapis/genai-toolbox/releases)
|
||||
corresponding to your OS and CPU architecture.
|
||||
{{< /notice >}}
|
||||
<!-- {x-release-please-start-version} -->
|
||||
```bash
|
||||
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.14.0/$OS/toolbox
|
||||
```
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
1. Create a `tools.yaml` file and include the following BigQuery source configuration:
|
||||
|
||||
```yaml
|
||||
sources:
|
||||
my-bigquery-client-auth-source:
|
||||
kind: "bigquery"
|
||||
project: "my-project-id"
|
||||
useClientOAuth: true
|
||||
# location: "US" # Optional: Specifies the location for query jobs.
|
||||
# allowedDatasets: # Optional: Restricts tool access to a specific list of datasets.
|
||||
# - "my_dataset_1"
|
||||
# - "other_project.my_dataset_2"
|
||||
```
|
||||
|
||||
1. Continue to configure one or more BigQuery tools. Here is a naive example to get started:
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
naive-bq-tool:
|
||||
kind: "bigquery-sql"
|
||||
source: "my-bigquery-client-auth-source"
|
||||
description: Naive BQ Tool that returns 1.
|
||||
statement: |
|
||||
SELECT 1;
|
||||
```
|
||||
|
||||
1. Run the Toolbox server:
|
||||
|
||||
```bash
|
||||
./toolbox --tools-file "tools.yaml"
|
||||
|
||||
```
|
||||
|
||||
The toolbox server will begin listening on localhost port 5000. Leave it
|
||||
running and continue in another terminal.
|
||||
|
||||
Later, when it is time to shut everything down, you can quit the toolbox
|
||||
server with Ctrl-C in this terminal window.
|
||||
|
||||
## Step 3: Configure Gemini-CLI
|
||||
|
||||
1. Edit the file `~/.gemini/settings.json` to include the following configuration:
|
||||
|
||||
```json
|
||||
"mcpServers": {
|
||||
"toolbox": {
|
||||
"httpUrl": "<http://localhost:5000/mcp>", // Replace this with your Toolbox URL if deployed somewhere else.
|
||||
"oauth": {
|
||||
"enabled": true,
|
||||
"clientId": <YOUR_CLIENT_ID>,
|
||||
"clientSecret": <YOUR_CLIENT_SECRET>,
|
||||
"authorizationUrl": "<https://accounts.google.com/o/oauth2/v2/auth>",
|
||||
"tokenUrl": "<https://oauth2.googleapis.com/token>",
|
||||
"scopes": ["https://www.googleapis.com/auth/cloud-platform"]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Make sure to substitue your client ID and client secret received from step 1.
|
||||
|
||||
1. Start Gemini-CLI:
|
||||
|
||||
```shell
|
||||
gemini-cli
|
||||
```
|
||||
|
||||
1. Authenticate with the command `/mcp auth toolbox`. Gemini-CLI will open up a
|
||||
browser where you will log in to your Google account.
|
||||
|
||||

|
||||
|
||||
1. Use Gemini-CLI with your tools. To test the naive Tool we configured previously, ask Gemini to run this Tool:
|
||||
|
||||
```text
|
||||
Call naive-bq-tool
|
||||
```
|
||||
|
||||
## Using Toolbox as a Shared Service
|
||||
|
||||
Toolbox can be run on another server as a shared service accessed by multiple
|
||||
users. We strongly recommend running toolbox behind a web proxy such as `nginx`
|
||||
which will provide SSL encryption. Google Cloud Run is another good way to run
|
||||
toolbox. You will connect to a service like `https://toolbox.example.com/mcp`.
|
||||
The proxy server will handle the SSL encryption and certificates. Then it will
|
||||
foward the requests to `http://localhost:5000/mcp` running in that environment.
|
||||
The details of the config are beyond the scope of this document, but will be
|
||||
familiar to your system administrators.
|
||||
|
||||
To use the shared service, just change the `localhost:5000` in the `httpUrl` in
|
||||
`~/.gemini/settings.json` to the host name and possibly the port of the shared
|
||||
service.
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 26 KiB |
9
go.mod
9
go.mod
@@ -50,7 +50,7 @@ require (
|
||||
go.opentelemetry.io/otel/sdk/metric v1.37.0
|
||||
go.opentelemetry.io/otel/trace v1.37.0
|
||||
golang.org/x/oauth2 v0.31.0
|
||||
google.golang.org/api v0.248.0
|
||||
google.golang.org/api v0.249.0
|
||||
google.golang.org/genproto v0.0.0-20250826171959-ef028d996bc1
|
||||
modernc.org/sqlite v1.38.2
|
||||
)
|
||||
@@ -65,7 +65,6 @@ require (
|
||||
github.com/segmentio/asm v1.2.0 // indirect
|
||||
github.com/shopspring/decimal v1.4.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b // indirect
|
||||
gonum.org/v1/gonum v0.16.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
@@ -82,7 +81,7 @@ require (
|
||||
cloud.google.com/go/trace v1.11.6 // indirect
|
||||
filippo.io/edwards25519 v1.1.0 // indirect
|
||||
github.com/GoogleCloudPlatform/grpc-gcp-go/grpcgcp v1.5.3 // indirect
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.27.0 // indirect
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.29.0 // indirect
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.53.0 // indirect
|
||||
github.com/PuerkitoBio/goquery v1.10.3 // indirect
|
||||
github.com/ajg/form v1.5.1 // indirect
|
||||
@@ -100,7 +99,7 @@ require (
|
||||
github.com/envoyproxy/protoc-gen-validate v1.2.1 // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.8 // indirect
|
||||
github.com/go-jose/go-jose/v4 v4.0.5 // indirect
|
||||
github.com/go-jose/go-jose/v4 v4.1.1 // indirect
|
||||
github.com/go-logr/logr v1.4.3 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
@@ -175,7 +174,7 @@ require (
|
||||
golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250818200422-3122310a409c // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250818200422-3122310a409c // indirect
|
||||
google.golang.org/grpc v1.74.2 // indirect
|
||||
google.golang.org/grpc v1.75.0 // indirect
|
||||
google.golang.org/protobuf v1.36.8 // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
modernc.org/libc v1.66.3 // indirect
|
||||
|
||||
16
go.sum
16
go.sum
@@ -661,8 +661,8 @@ github.com/ClickHouse/clickhouse-go/v2 v2.40.1 h1:PbwsHBgqXRydU7jKULD1C8CHmifczf
|
||||
github.com/ClickHouse/clickhouse-go/v2 v2.40.1/go.mod h1:GDzSBLVhladVm8V01aEB36IoBOVLLICfyeuiIp/8Ezc=
|
||||
github.com/GoogleCloudPlatform/grpc-gcp-go/grpcgcp v1.5.3 h1:2afWGsMzkIcN8Qm4mgPJKZWyroE5QBszMiDMYEBrnfw=
|
||||
github.com/GoogleCloudPlatform/grpc-gcp-go/grpcgcp v1.5.3/go.mod h1:dppbR7CwXD4pgtV9t3wD1812RaLDcBjtblcDF5f1vI0=
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.27.0 h1:ErKg/3iS1AKcTkf3yixlZ54f9U1rljCkQyEXWUnIUxc=
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.27.0/go.mod h1:yAZHSGnqScoU556rBOVkwLze6WP5N+U11RHuWaGVxwY=
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.29.0 h1:UQUsRi8WTzhZntp5313l+CHIAT95ojUI2lpP/ExlZa4=
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.29.0/go.mod h1:Cz6ft6Dkn3Et6l2v2a9/RpN7epQ1GtDlO6lj8bEcOvw=
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.53.0 h1:owcC2UnmsZycprQ5RfRgjydWhuoxg71LUfyiQdijZuM=
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.53.0/go.mod h1:ZPpqegjbE99EPKsu3iUWV22A04wzGPcAY/ziSIQEEgs=
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/trace v1.29.0 h1:YVtMlmfRUTaWs3+1acwMBp7rBUo6zrxl6Kn13/R9YW4=
|
||||
@@ -866,8 +866,8 @@ github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2
|
||||
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
|
||||
github.com/go-goquery/goquery v1.0.1 h1:kpchVA1LdOFWdRpkDPESVdlb1JQI6ixsJ5MiNUITO7U=
|
||||
github.com/go-goquery/goquery v1.0.1/go.mod h1:W5s8OWbqWf6lG0LkXWBeh7U1Y/X5XTI0Br65MHF8uJk=
|
||||
github.com/go-jose/go-jose/v4 v4.0.5 h1:M6T8+mKZl/+fNNuFHvGIzDz7BTLQPIounk/b9dw3AaE=
|
||||
github.com/go-jose/go-jose/v4 v4.0.5/go.mod h1:s3P1lRrkT8igV8D9OjyL4WRyHvjB6a4JSllnOrmmBOA=
|
||||
github.com/go-jose/go-jose/v4 v4.1.1 h1:JYhSgy4mXXzAdF3nUx3ygx347LRXJRrpgyU3adRmkAI=
|
||||
github.com/go-jose/go-jose/v4 v4.1.1/go.mod h1:BdsZGqgdO3b6tTc6LSE56wcDbMMLuPsw5d4ZD5f94kA=
|
||||
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
|
||||
github.com/go-latex/latex v0.0.0-20210118124228-b3d85cf34e07/go.mod h1:CO1AlKB2CSIqUrmQPqA0gdRIlnLEY0gK5JGjh37zN5U=
|
||||
github.com/go-latex/latex v0.0.0-20210823091927-c0d11ff05a81/go.mod h1:SX0U8uGpxhq9o2S/CELCSUxEWWAuoCUcVCQWv7G2OCk=
|
||||
@@ -1823,8 +1823,8 @@ google.golang.org/api v0.108.0/go.mod h1:2Ts0XTHNVWxypznxWOYUeI4g3WdP9Pk2Qk58+a/
|
||||
google.golang.org/api v0.110.0/go.mod h1:7FC4Vvx1Mooxh8C5HWjzZHcavuS2f6pmJpZx60ca7iI=
|
||||
google.golang.org/api v0.111.0/go.mod h1:qtFHvU9mhgTJegR31csQ+rwxyUTHOKFqCKWp1J0fdw0=
|
||||
google.golang.org/api v0.114.0/go.mod h1:ifYI2ZsFK6/uGddGfAD5BMxlnkBqCmqHSDUVi45N5Yg=
|
||||
google.golang.org/api v0.248.0 h1:hUotakSkcwGdYUqzCRc5yGYsg4wXxpkKlW5ryVqvC1Y=
|
||||
google.golang.org/api v0.248.0/go.mod h1:yAFUAF56Li7IuIQbTFoLwXTCI6XCFKueOlS7S9e4F9k=
|
||||
google.golang.org/api v0.249.0 h1:0VrsWAKzIZi058aeq+I86uIXbNhm9GxSHpbmZ92a38w=
|
||||
google.golang.org/api v0.249.0/go.mod h1:dGk9qyI0UYPwO/cjt2q06LG/EhUpwZGdAbYF14wHHrQ=
|
||||
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
|
||||
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
||||
google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
|
||||
@@ -2012,8 +2012,8 @@ google.golang.org/grpc v1.52.3/go.mod h1:pu6fVzoFb+NBYNAvQL08ic+lvB2IojljRYuun5v
|
||||
google.golang.org/grpc v1.53.0/go.mod h1:OnIrk0ipVdj4N5d9IUoFUx72/VlD7+jUsHwZgwSMQpw=
|
||||
google.golang.org/grpc v1.54.0/go.mod h1:PUSEXI6iWghWaB6lXM4knEgpJNu2qUcKfDtNci3EC2g=
|
||||
google.golang.org/grpc v1.56.3/go.mod h1:I9bI3vqKfayGqPUAwGdOSu7kt6oIJLixfffKrpXqQ9s=
|
||||
google.golang.org/grpc v1.74.2 h1:WoosgB65DlWVC9FqI82dGsZhWFNBSLjQ84bjROOpMu4=
|
||||
google.golang.org/grpc v1.74.2/go.mod h1:CtQ+BGjaAIXHs/5YS3i473GqwBBa1zGQNevxdeBEXrM=
|
||||
google.golang.org/grpc v1.75.0 h1:+TW+dqTd2Biwe6KKfhE5JpiYIBWq865PhKGSXiivqt4=
|
||||
google.golang.org/grpc v1.75.0/go.mod h1:JtPAzKiq4v1xcAB2hydNlWI2RnF85XXcV0mhKXr2ecQ=
|
||||
google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.1.0/go.mod h1:6Kw0yEErY5E/yWrBtf03jp27GLLJujG4z/JK95pnjjw=
|
||||
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8=
|
||||
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0=
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
|
||||
var expectedToolSources = []string{
|
||||
"alloydb-postgres-admin",
|
||||
"alloydb-postgres-observability",
|
||||
"alloydb-postgres",
|
||||
"bigquery",
|
||||
"clickhouse",
|
||||
@@ -85,6 +86,7 @@ func TestLoadPrebuiltToolYAMLs(t *testing.T) {
|
||||
|
||||
func TestGetPrebuiltTool(t *testing.T) {
|
||||
alloydb_admin_config, _ := Get("alloydb-postgres-admin")
|
||||
alloydb_observability_config, _ := Get("alloydb-postgres-observability")
|
||||
alloydb_config, _ := Get("alloydb-postgres")
|
||||
bigquery_config, _ := Get("bigquery")
|
||||
clickhouse_config, _ := Get("clickhouse")
|
||||
@@ -106,6 +108,9 @@ func TestGetPrebuiltTool(t *testing.T) {
|
||||
if len(alloydb_config) <= 0 {
|
||||
t.Fatalf("unexpected error: could not fetch alloydb prebuilt tools yaml")
|
||||
}
|
||||
if len(alloydb_observability_config) <= 0 {
|
||||
t.Fatalf("unexpected error: could not fetch alloydb-observability prebuilt tools yaml")
|
||||
}
|
||||
if len(bigquery_config) <= 0 {
|
||||
t.Fatalf("unexpected error: could not fetch bigquery prebuilt tools yaml")
|
||||
}
|
||||
|
||||
@@ -0,0 +1,121 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
sources:
|
||||
cloud-monitoring-source:
|
||||
kind: cloud-monitoring
|
||||
tools:
|
||||
get_system_metrics:
|
||||
kind: cloud-monitoring-query-prometheus
|
||||
source: cloud-monitoring-source
|
||||
description: |
|
||||
Fetches system level cloudmonitoring data (timeseries metrics) for an AlloyDB cluster, instance.
|
||||
To use this tool, you must provide the Google Cloud `projectID` and a PromQL `query`.
|
||||
|
||||
Generate the PromQL `query` for AlloyDB system metrics using the provided metrics and rules. Get labels like `cluster_id` and `instance_id` from the user's intent.
|
||||
|
||||
Defaults:
|
||||
1. Interval: Use a default interval of `5m` for `_over_time` aggregation functions unless a different window is specified by the user.
|
||||
|
||||
PromQL Query Examples:
|
||||
1. Basic Time Series: `avg_over_time({"__name__"="alloydb.googleapis.com/instance/cpu/average_utilization","monitored_resource"="alloydb.googleapis.com/Instance","instance_id"="alloydb-instance"}[5m])`
|
||||
2. Top K: `topk(30, avg_over_time({"__name__"="alloydb.googleapis.com/instance/cpu/average_utilization","monitored_resource"="alloydb.googleapis.com/Instance","instance_id"="alloydb-instance"}[5m]))`
|
||||
3. Mean: `avg(avg_over_time({"__name__"="alloydb.googleapis.com/instance/cpu/average_utilization","monitored_resource"="alloydb.googleapis.com/Instance","instance_id"="my-instance","cluster_id"="my-cluster"}[5m]))`
|
||||
4. Minimum: `min(min_over_time({"__name__"="alloydb.googleapis.com/instance/cpu/average_utilization","monitored_resource"="alloydb.googleapis.com/Instance","instance_id"="alloydb-instance","cluster_id"="alloydb-cluster"}[5m]))`
|
||||
5. Maximum: `max(max_over_time({"__name__"="alloydb.googleapis.com/instance/cpu/average_utilization","monitored_resource"="alloydb.googleapis.com/Instance","instance_id"="alloydb-instance","cluster_id"="alloydb-cluster"}[5m]))`
|
||||
6. Sum: `sum(avg_over_time({"__name__"="alloydb.googleapis.com/instance/cpu/average_utilization","monitored_resource"="alloydb.googleapis.com/Instance","instance_id"="alloydb-instance","cluster_id"="alloydb-cluster"}[5m]))`
|
||||
7. Count streams: `count(avg_over_time({"__name__"="alloydb.googleapis.com/instance/cpu/average_utilization","monitored_resource"="alloydb.googleapis.com/Instance","instance_id"="alloydb-instance","cluster_id"="alloydb-cluster"}[5m]))`
|
||||
8. Percentile with groupby on instanceid, clusterid: `quantile by ("instance_id","cluster_id")(0.99,avg_over_time({"__name__"="alloydb.googleapis.com/instance/cpu/average_utilization","monitored_resource"="alloydb.googleapis.com/Instance","cluster_id"="my-cluster","instance_id"="my-instance"}[5m]))`
|
||||
|
||||
Available Metrics List: metricname. description. monitored resource. labels
|
||||
1. `alloydb.googleapis.com/instance/cpu/average_utilization`: The percentage of CPU being used on an instance. `alloydb.googleapis.com/Instance`. `cluster_id`, `instance_id`.
|
||||
2. `alloydb.googleapis.com/instance/cpu/maximum_utilization`: Maximum CPU utilization across all currently serving nodes of the instance from 0 to 100. `alloydb.googleapis.com/Instance`. `cluster_id`, `instance_id`.
|
||||
3. `alloydb.googleapis.com/cluster/storage/usage`: The total AlloyDB storage in bytes across the entire cluster. `alloydb.googleapis.com/Cluster`. `cluster_id`.
|
||||
4. `alloydb.googleapis.com/instance/postgres/replication/replicas`: The number of read replicas connected to the primary instance. `alloydb.googleapis.com/Instance`. `cluster_id`, `instance_id`, `state`, `replica_instance_id`.
|
||||
5. `alloydb.googleapis.com/instance/postgres/replication/maximum_lag`: The maximum replication time lag calculated across all serving read replicas of the instance. `alloydb.googleapis.com/Instance`. `cluster_id`, `instance_id`, `replica_instance_id`.
|
||||
6. `alloydb.googleapis.com/instance/memory/min_available_memory`: The minimum available memory across all currently serving nodes of the instance. `alloydb.googleapis.com/Instance`. `cluster_id`, `instance_id`.
|
||||
7. `alloydb.googleapis.com/instance/postgres/instances`: The number of nodes in the instance, along with their status, which can be either up or down. `alloydb.googleapis.com/Instance`. `cluster_id`, `instance_id`, `status`.
|
||||
8. `alloydb.googleapis.com/database/postgresql/tuples`: Number of tuples (rows) by state per database in the instance. `alloydb.googleapis.com/Database`. `cluster_id`, `instance_id`, `database`, `state`.
|
||||
9. `alloydb.googleapis.com/database/postgresql/temp_bytes_written_for_top_databases`: The total amount of data(in bytes) written to temporary files by the queries per database for top 500 dbs. `alloydb.googleapis.com/Database`. `cluster_id`, `instance_id`, `database`.
|
||||
10. `alloydb.googleapis.com/database/postgresql/temp_files_written_for_top_databases`: The number of temporary files used for writing data per database while performing internal algorithms like join, sort etc for top 500 dbs. `alloydb.googleapis.com/Database`. `cluster_id`, `instance_id`, `database`.
|
||||
11. `alloydb.googleapis.com/database/postgresql/inserted_tuples_count_for_top_databases`: The total number of rows inserted per db for top 500 dbs as a result of the queries in the instance. `alloydb.googleapis.com/Database`. `cluster_id`, `instance_id`, `database`.
|
||||
12. `alloydb.googleapis.com/database/postgresql/updated_tuples_count_for_top_databases`: The total number of rows updated per db for top 500 dbs as a result of the queries in the instance. `alloydb.googleapis.com/Database`. `cluster_id`, `instance_id`, `database`.
|
||||
13. `alloydb.googleapis.com/database/postgresql/deleted_tuples_count_for_top_databases`: The total number of rows deleted per db for top 500 dbs as a result of the queries in the instance. `alloydb.googleapis.com/Database`. `cluster_id`, `instance_id`, `database`.
|
||||
14. `alloydb.googleapis.com/database/postgresql/backends_for_top_databases`: The current number of connections per database to the instance for top 500 dbs. `alloydb.googleapis.com/Database`. `cluster_id`, `instance_id`, `database`.
|
||||
15. `alloydb.googleapis.com/instance/postgresql/backends_by_state`: The current number of connections to the instance grouped by the state like idle, active, idle_in_transaction, idle_in_transaction_aborted, disabled, and fastpath_function_call. `alloydb.googleapis.com/Instance`. `cluster_id`, `instance_id`, `state`.
|
||||
16. `alloydb.googleapis.com/instance/postgresql/backends_for_top_applications`: The current number of connections to the AlloyDB instance, grouped by applications for top 500 applications. `alloydb.googleapis.com/Instance`. `cluster_id`, `instance_id`, `application_name`.
|
||||
17. `alloydb.googleapis.com/database/postgresql/new_connections_for_top_databases`: Total number of new connections added per database for top 500 databases to the instance. `alloydb.googleapis.com/Database`. `cluster_id`, `instance_id`, `database`.
|
||||
18. `alloydb.googleapis.com/database/postgresql/deadlock_count_for_top_databases`: Total number of deadlocks detected in the instance per database for top 500 dbs. `alloydb.googleapis.com/Database`. `cluster_id`, `instance_id`, `database`.
|
||||
19. `alloydb.googleapis.com/database/postgresql/statements_executed_count`: Total count of statements executed in the instance per database per operation_type. `alloydb.googleapis.com/Database`. `cluster_id`, `instance_id`, `database`, `operation_type`.
|
||||
20. `alloydb.googleapis.com/instance/postgresql/returned_tuples_count`: Number of rows scanned while processing the queries in the instance since the last sample. `alloydb.googleapis.com/Instance`. `cluster_id`, `instance_id`.
|
||||
21. `alloydb.googleapis.com/instance/postgresql/fetched_tuples_count`: Number of rows fetched while processing the queries in the instance since the last sample. `alloydb.googleapis.com/Instance`. `cluster_id`, `instance_id`.
|
||||
22. `alloydb.googleapis.com/instance/postgresql/updated_tuples_count`: Number of rows updated while processing the queries in the instance since the last sample. `alloydb.googleapis.com/Instance`. `cluster_id`, `instance_id`.
|
||||
23. `alloydb.googleapis.com/instance/postgresql/inserted_tuples_count`: Number of rows inserted while processing the queries in the instance since the last sample. `alloydb.googleapis.com/Instance`. `cluster_id`, `instance_id`.
|
||||
24. `alloydb.googleapis.com/instance/postgresql/deleted_tuples_count`: Number of rows deleted while processing the queries in the instance since the last sample. `alloydb.googleapis.com/Instance`. `cluster_id`, `instance_id`.
|
||||
25. `alloydb.googleapis.com/instance/postgresql/written_tuples_count`: Number of rows written while processing the queries in the instance since the last sample. `alloydb.googleapis.com/Instance`. `cluster_id`, `instance_id`.
|
||||
26. `alloydb.googleapis.com/instance/postgresql/deadlock_count`: Number of deadlocks detected in the instance. `alloydb.googleapis.com/Instance`. `cluster_id`, `instance_id`.
|
||||
27. `alloydb.googleapis.com/instance/postgresql/blks_read`: Number of blocks read by Postgres that were not in the buffer cache. `alloydb.googleapis.com/Instance`. `cluster_id`, `instance_id`.
|
||||
28. `alloydb.googleapis.com/instance/postgresql/blks_hit`: Number of times Postgres found the requested block in the buffer cache. `alloydb.googleapis.com/Instance`. `cluster_id`, `instance_id`.
|
||||
29. `alloydb.googleapis.com/instance/postgresql/temp_bytes_written_count`: The total amount of data(in bytes) written to temporary files by the queries while performing internal algorithms like join, sort etc. `alloydb.googleapis.com/Instance`. `cluster_id`, `instance_id`.
|
||||
30. `alloydb.googleapis.com/instance/postgresql/temp_files_written_count`: The number of temporary files used for writing data in the instance while performing internal algorithms like join, sort etc. `alloydb.googleapis.com/Instance`. `cluster_id`, `instance_id`.
|
||||
31. `alloydb.googleapis.com/instance/postgresql/new_connections_count`: The number new connections added to the instance. `alloydb.googleapis.com/Instance`. `cluster_id`, `instance_id`.
|
||||
32. `alloydb.googleapis.com/instance/postgresql/wait_count`: Total number of times processes waited for each wait event in the instance. `alloydb.googleapis.com/Instance`. `cluster_id`, `instance_id`, `wait_event_type`, `wait_event_name`.
|
||||
33. `alloydb.googleapis.com/instance/postgresql/wait_time`: Total elapsed wait time for each wait event in the instance. `alloydb.googleapis.com/Instance`. `cluster_id`, `instance_id`, `wait_event_type`, `wait_event_name`.
|
||||
34. `alloydb.googleapis.com/instance/postgres/transaction_count`: The number of committed and rolled back transactions across all serving nodes of the instance. `alloydb.googleapis.com/Instance`. `cluster_id`, `instance_id`.
|
||||
|
||||
get_query_metrics:
|
||||
kind: cloud-monitoring-query-prometheus
|
||||
source: cloud-monitoring-source
|
||||
description: |
|
||||
Fetches query level cloudmonitoring data (timeseries metrics) for queries running in an AlloyDB instance.
|
||||
To use this tool, you must provide the Google Cloud `projectId` and a PromQL `query`.
|
||||
|
||||
Generate the PromQL `query` for AlloyDB query metrics using the provided metrics and rules. Get labels like `cluster_id`, `instance_id`, and `query_hash` from the user's intent. If `query_hash` is provided, use the per-query metrics.
|
||||
|
||||
Defaults:
|
||||
1. Interval: Use a default interval of `5m` for `_over_time` aggregation functions unless a different window is specified by the user.
|
||||
|
||||
PromQL Query Examples:
|
||||
1. Basic Time Series: `avg_over_time({"__name__"="alloydb.googleapis.com/instance/cpu/average_utilization","monitored_resource"="alloydb.googleapis.com/Instance","instance_id"="alloydb-instance"}[5m])`
|
||||
2. Top K: `topk(30, avg_over_time({"__name__"="alloydb.googleapis.com/instance/cpu/average_utilization","monitored_resource"="alloydb.googleapis.com/Instance","instance_id"="alloydb-instance"}[5m]))`
|
||||
3. Mean: `avg(avg_over_time({"__name__"="alloydb.googleapis.com/instance/cpu/average_utilization","monitored_resource"="alloydb.googleapis.com/Instance","instance_id"="my-instance","cluster_id"="my-cluster"}[5m]))`
|
||||
4. Minimum: `min(min_over_time({"__name__"="alloydb.googleapis.com/instance/cpu/average_utilization","monitored_resource"="alloydb.googleapis.com/Instance","instance_id"="alloydb-instance","cluster_id"="alloydb-cluster"}[5m]))`
|
||||
5. Maximum: `max(max_over_time({"__name__"="alloydb.googleapis.com/instance/cpu/average_utilization","monitored_resource"="alloydb.googleapis.com/Instance","instance_id"="alloydb-instance","cluster_id"="alloydb-cluster"}[5m]))`
|
||||
6. Sum: `sum(avg_over_time({"__name__"="alloydb.googleapis.com/instance/cpu/average_utilization","monitored_resource"="alloydb.googleapis.com/Instance","instance_id"="alloydb-instance","cluster_id"="alloydb-cluster"}[5m]))`
|
||||
7. Count streams: `count(avg_over_time({"__name__"="alloydb.googleapis.com/instance/cpu/average_utilization","monitored_resource"="alloydb.googleapis.com/Instance","instance_id"="alloydb-instance","cluster_id"="alloydb-cluster"}[5m]))`
|
||||
8. Percentile with groupby on instanceid, clusterid: `quantile by ("instance_id","cluster_id")(0.99,avg_over_time({"__name__"="alloydb.googleapis.com/instance/cpu/average_utilization","monitored_resource"="alloydb.googleapis.com/Instance","cluster_id"="my-cluster","instance_id"="my-instance"}[5m]))`
|
||||
|
||||
Available Metrics List: metricname. description. monitored resource. labels. aggregate is the aggregated values for all query stats, Use aggregate metrics if query id is not provided. For perquery metrics do not fetch querystring unless specified by user specifically. Have the aggregation on query hash to avoid fetching the querystring. Do not use latency metrics for anything.
|
||||
1. `alloydb.googleapis.com/database/postgresql/insights/aggregate/latencies`: Aggregated query latency distribution. `alloydb.googleapis.com/Database`. `user`, `client_addr`.
|
||||
2. `alloydb.googleapis.com/database/postgresql/insights/aggregate/execution_time`: Accumulated aggregated query execution time since the last sample. `alloydb.googleapis.com/Database`. `user`, `client_addr`.
|
||||
3. `alloydb.googleapis.com/database/postgresql/insights/aggregate/io_time`: Accumulated aggregated IO time since the last sample. `alloydb.googleapis.com/Database`. `user`, `client_addr`, `io_type`.
|
||||
4. `alloydb.googleapis.com/database/postgresql/insights/aggregate/lock_time`: Accumulated aggregated lock wait time since the last sample. `alloydb.googleapis.com/Database`. `user`, `client_addr`, `lock_type`.
|
||||
5. `alloydb.googleapis.com/database/postgresql/insights/aggregate/row_count`: Aggregated number of retrieved or affected rows since the last sample. `alloydb.googleapis.com/Database`. `user`, `client_addr`.
|
||||
6. `alloydb.googleapis.com/database/postgresql/insights/aggregate/shared_blk_access_count`: Aggregated shared blocks accessed by statement execution. `alloydb.googleapis.com/Database`. `user`, `client_addr`, `access_type`.
|
||||
7. `alloydb.googleapis.com/database/postgresql/insights/perquery/latencies`: Per query latency distribution. `alloydb.googleapis.com/Database`. `user`, `client_addr`, `querystring`, `query_hash`.
|
||||
8. `alloydb.googleapis.com/database/postgresql/insights/perquery/execution_time`: Accumulated execution times per user per database per query. `alloydb.googleapis.com/Database`. `user`, `client_addr`, `querystring`, `query_hash`.
|
||||
9. `alloydb.googleapis.com/database/postgresql/insights/perquery/io_time`: Accumulated IO time since the last sample per query. `alloydb.googleapis.com/Database`. `user`, `client_addr`, `io_type`, `querystring`, `query_hash`.
|
||||
10. `alloydb.googleapis.com/database/postgresql/insights/perquery/lock_time`: Accumulated lock wait time since the last sample per query. `alloydb.googleapis.com/Database`. `user`, `client_addr`, `lock_type`, `querystring`, `query_hash`.
|
||||
11. `alloydb.googleapis.com/database/postgresql/insights/perquery/row_count`: The number of retrieved or affected rows since the last sample per query. `alloydb.googleapis.com/Database`. `user`, `client_addr`, `querystring`, `query_hash`.
|
||||
12. `alloydb.googleapis.com/database/postgresql/insights/perquery/shared_blk_access_count`: Shared blocks accessed by statement execution per query. `alloydb.googleapis.com/Database`. `user`, `client_addr`, `access_type`, `querystring`, `query_hash`.
|
||||
13. `alloydb.googleapis.com/database/postgresql/insights/pertag/latencies`: Query latency distribution. `alloydb.googleapis.com/Database`. `user`, `client_addr`, `action`, `application`, `controller`, `db_driver`, `framework`, `route`, `tag_hash`.
|
||||
14. `alloydb.googleapis.com/database/postgresql/insights/pertag/execution_time`: Accumulated execution times since the last sample. `alloydb.googleapis.com/Database`. `user`, `client_addr`, `action`, `application`, `controller`, `db_driver`, `framework`, `route`, `tag_hash`.
|
||||
15. `alloydb.googleapis.com/database/postgresql/insights/pertag/io_time`: Accumulated IO time since the last sample per tag. `alloydb.googleapis.com/Database`. `user`, `client_addr`, `action`, `application`, `controller`, `db_driver`, `framework`, `route`, `io_type`, `tag_hash`.
|
||||
16. `alloydb.googleapis.com/database/postgresql/insights/pertag/lock_time`: Accumulated lock wait time since the last sample per tag. `alloydb.googleapis.com/Database`. `user`, `client_addr`, `action`, `application`, `controller`, `db_driver`, `framework`, `route`, `lock_type`, `tag_hash`.
|
||||
17. `alloydb.googleapis.com/database/postgresql/insights/pertag/shared_blk_access_count`: Shared blocks accessed by statement execution per tag. `alloydb.googleapis.com/Database`. `user`, `client_addr`, `action`, `application`, `controller`, `db_driver`, `framework`, `route`, `access_type`, `tag_hash`.
|
||||
18. `alloydb.googleapis.com/database/postgresql/insights/pertag/row_count`: The number of retrieved or affected rows since the last sample per tag. `alloydb.googleapis.com/Database`. `user`, `client_addr`, `action`, `application`, `controller`, `db_driver`, `framework`, `route`, `tag_hash`.
|
||||
|
||||
toolsets:
|
||||
alloydb-postgres-cloud-monitoring-tools:
|
||||
- get_system_metrics
|
||||
- get_query_metrics
|
||||
@@ -11,6 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
sources:
|
||||
alloydb-pg-source:
|
||||
kind: "alloydb-postgres"
|
||||
@@ -30,93 +31,9 @@ tools:
|
||||
description: Use this tool to execute sql.
|
||||
|
||||
list_tables:
|
||||
kind: postgres-sql
|
||||
kind: postgres-list-tables
|
||||
source: alloydb-pg-source
|
||||
description: "Lists detailed schema information (object type, columns, constraints, indexes, triggers, owner, comment) as JSON for user-created tables (ordinary or partitioned). Filters by a comma-separated list of names. If names are omitted, lists all tables in user schemas."
|
||||
statement: |
|
||||
WITH desired_relkinds AS (
|
||||
SELECT ARRAY['r', 'p']::char[] AS kinds -- Always consider both 'TABLE' and 'PARTITIONED TABLE'
|
||||
),
|
||||
table_info AS (
|
||||
SELECT
|
||||
t.oid AS table_oid,
|
||||
ns.nspname AS schema_name,
|
||||
t.relname AS table_name,
|
||||
pg_get_userbyid(t.relowner) AS table_owner,
|
||||
obj_description(t.oid, 'pg_class') AS table_comment,
|
||||
t.relkind AS object_kind
|
||||
FROM
|
||||
pg_class t
|
||||
JOIN
|
||||
pg_namespace ns ON ns.oid = t.relnamespace
|
||||
CROSS JOIN desired_relkinds dk
|
||||
WHERE
|
||||
t.relkind = ANY(dk.kinds) -- Filter by selected table relkinds ('r', 'p')
|
||||
AND (NULLIF(TRIM($1), '') IS NULL OR t.relname = ANY(string_to_array($1,','))) -- $1 is object_names
|
||||
AND ns.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
|
||||
AND ns.nspname NOT LIKE 'pg_temp_%' AND ns.nspname NOT LIKE 'pg_toast_temp_%'
|
||||
),
|
||||
columns_info AS (
|
||||
SELECT
|
||||
att.attrelid AS table_oid, att.attname AS column_name, format_type(att.atttypid, att.atttypmod) AS data_type,
|
||||
att.attnum AS column_ordinal_position, att.attnotnull AS is_not_nullable,
|
||||
pg_get_expr(ad.adbin, ad.adrelid) AS column_default, col_description(att.attrelid, att.attnum) AS column_comment
|
||||
FROM pg_attribute att LEFT JOIN pg_attrdef ad ON att.attrelid = ad.adrelid AND att.attnum = ad.adnum
|
||||
JOIN table_info ti ON att.attrelid = ti.table_oid WHERE att.attnum > 0 AND NOT att.attisdropped
|
||||
),
|
||||
constraints_info AS (
|
||||
SELECT
|
||||
con.conrelid AS table_oid, con.conname AS constraint_name, pg_get_constraintdef(con.oid) AS constraint_definition,
|
||||
CASE con.contype WHEN 'p' THEN 'PRIMARY KEY' WHEN 'f' THEN 'FOREIGN KEY' WHEN 'u' THEN 'UNIQUE' WHEN 'c' THEN 'CHECK' ELSE con.contype::text END AS constraint_type,
|
||||
(SELECT array_agg(att.attname ORDER BY u.attposition) FROM unnest(con.conkey) WITH ORDINALITY AS u(attnum, attposition) JOIN pg_attribute att ON att.attrelid = con.conrelid AND att.attnum = u.attnum) AS constraint_columns,
|
||||
NULLIF(con.confrelid, 0)::regclass AS foreign_key_referenced_table,
|
||||
(SELECT array_agg(att.attname ORDER BY u.attposition) FROM unnest(con.confkey) WITH ORDINALITY AS u(attnum, attposition) JOIN pg_attribute att ON att.attrelid = con.confrelid AND att.attnum = u.attnum WHERE con.contype = 'f') AS foreign_key_referenced_columns
|
||||
FROM pg_constraint con JOIN table_info ti ON con.conrelid = ti.table_oid
|
||||
),
|
||||
indexes_info AS (
|
||||
SELECT
|
||||
idx.indrelid AS table_oid, ic.relname AS index_name, pg_get_indexdef(idx.indexrelid) AS index_definition,
|
||||
idx.indisunique AS is_unique, idx.indisprimary AS is_primary, am.amname AS index_method,
|
||||
(SELECT array_agg(att.attname ORDER BY u.ord) FROM unnest(idx.indkey::int[]) WITH ORDINALITY AS u(colidx, ord) LEFT JOIN pg_attribute att ON att.attrelid = idx.indrelid AND att.attnum = u.colidx WHERE u.colidx <> 0) AS index_columns
|
||||
FROM pg_index idx JOIN pg_class ic ON ic.oid = idx.indexrelid JOIN pg_am am ON am.oid = ic.relam JOIN table_info ti ON idx.indrelid = ti.table_oid
|
||||
),
|
||||
triggers_info AS (
|
||||
SELECT tg.tgrelid AS table_oid, tg.tgname AS trigger_name, pg_get_triggerdef(tg.oid) AS trigger_definition, tg.tgenabled AS trigger_enabled_state
|
||||
FROM pg_trigger tg JOIN table_info ti ON tg.tgrelid = ti.table_oid WHERE NOT tg.tgisinternal
|
||||
)
|
||||
SELECT
|
||||
ti.schema_name,
|
||||
ti.table_name AS object_name,
|
||||
CASE
|
||||
WHEN $2 = 'simple' THEN
|
||||
-- IF format is 'simple', return basic JSON
|
||||
json_build_object('name', ti.table_name)
|
||||
ELSE
|
||||
json_build_object(
|
||||
'schema_name', ti.schema_name,
|
||||
'object_name', ti.table_name,
|
||||
'object_type', CASE ti.object_kind
|
||||
WHEN 'r' THEN 'TABLE'
|
||||
WHEN 'p' THEN 'PARTITIONED TABLE'
|
||||
ELSE ti.object_kind::text -- Should not happen due to WHERE clause
|
||||
END,
|
||||
'owner', ti.table_owner,
|
||||
'comment', ti.table_comment,
|
||||
'columns', COALESCE((SELECT json_agg(json_build_object('column_name',ci.column_name,'data_type',ci.data_type,'ordinal_position',ci.column_ordinal_position,'is_not_nullable',ci.is_not_nullable,'column_default',ci.column_default,'column_comment',ci.column_comment) ORDER BY ci.column_ordinal_position) FROM columns_info ci WHERE ci.table_oid = ti.table_oid), '[]'::json),
|
||||
'constraints', COALESCE((SELECT json_agg(json_build_object('constraint_name',cons.constraint_name,'constraint_type',cons.constraint_type,'constraint_definition',cons.constraint_definition,'constraint_columns',cons.constraint_columns,'foreign_key_referenced_table',cons.foreign_key_referenced_table,'foreign_key_referenced_columns',cons.foreign_key_referenced_columns)) FROM constraints_info cons WHERE cons.table_oid = ti.table_oid), '[]'::json),
|
||||
'indexes', COALESCE((SELECT json_agg(json_build_object('index_name',ii.index_name,'index_definition',ii.index_definition,'is_unique',ii.is_unique,'is_primary',ii.is_primary,'index_method',ii.index_method,'index_columns',ii.index_columns)) FROM indexes_info ii WHERE ii.table_oid = ti.table_oid), '[]'::json),
|
||||
'triggers', COALESCE((SELECT json_agg(json_build_object('trigger_name',tri.trigger_name,'trigger_definition',tri.trigger_definition,'trigger_enabled_state',tri.trigger_enabled_state)) FROM triggers_info tri WHERE tri.table_oid = ti.table_oid), '[]'::json)
|
||||
)
|
||||
END AS object_details
|
||||
FROM table_info ti ORDER BY ti.schema_name, ti.table_name;
|
||||
parameters:
|
||||
- name: table_names
|
||||
type: string
|
||||
description: "Optional: A comma-separated list of table names. If empty, details for all tables in user-accessible schemas will be listed."
|
||||
- name: output_format
|
||||
type: string
|
||||
description: "Optional: Use 'simple' to return table names only or use 'detailed' to return the full information schema."
|
||||
default: "detailed"
|
||||
|
||||
toolsets:
|
||||
alloydb-postgres-database-tools:
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
sources:
|
||||
bigquery-source:
|
||||
kind: "bigquery"
|
||||
@@ -18,6 +19,11 @@ sources:
|
||||
location: ${BIGQUERY_LOCATION:}
|
||||
|
||||
tools:
|
||||
analyze_contribution:
|
||||
kind: bigquery-analyze-contribution
|
||||
source: bigquery-source
|
||||
description: Use this tool to analyze the contribution about changes to key metrics in multi-dimensional data.
|
||||
|
||||
ask_data_insights:
|
||||
kind: bigquery-conversational-analytics
|
||||
source: bigquery-source
|
||||
@@ -58,6 +64,7 @@ tools:
|
||||
|
||||
toolsets:
|
||||
bigquery-database-tools:
|
||||
- analyze_contribution
|
||||
- ask_data_insights
|
||||
- execute_sql
|
||||
- forecast
|
||||
|
||||
@@ -1,3 +1,16 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
sources:
|
||||
clickhouse-source:
|
||||
kind: clickhouse
|
||||
@@ -14,6 +27,12 @@ tools:
|
||||
source: clickhouse-source
|
||||
description: Use this tool to execute SQL.
|
||||
|
||||
list_databases:
|
||||
kind: clickhouse-list-databases
|
||||
source: clickhouse-source
|
||||
description: Use this tool to list all databases in ClickHouse.
|
||||
|
||||
toolsets:
|
||||
clickhouse-database-tools:
|
||||
- execute_sql
|
||||
- list_databases
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
sources:
|
||||
cloud-sql-mysql-source:
|
||||
kind: cloud-sql-mysql
|
||||
@@ -7,174 +21,16 @@ sources:
|
||||
database: ${CLOUD_SQL_MYSQL_DATABASE}
|
||||
user: ${CLOUD_SQL_MYSQL_USER}
|
||||
password: ${CLOUD_SQL_MYSQL_PASSWORD}
|
||||
ipType: ${CLOUD_SQL_MYSQL_IP_TYPE:PUBLIC}
|
||||
tools:
|
||||
execute_sql:
|
||||
kind: mysql-execute-sql
|
||||
source: cloud-sql-mysql-source
|
||||
description: Use this tool to execute SQL.
|
||||
list_tables:
|
||||
kind: mysql-sql
|
||||
kind: mysql-list-tables
|
||||
source: cloud-sql-mysql-source
|
||||
description: "Lists detailed schema information (object type, columns, constraints, indexes, triggers, comment) as JSON for user-created tables (ordinary or partitioned). Filters by a comma-separated list of names. If names are omitted, lists all tables in user schemas."
|
||||
statement: |
|
||||
SELECT
|
||||
T.TABLE_SCHEMA AS schema_name,
|
||||
T.TABLE_NAME AS object_name,
|
||||
CASE
|
||||
WHEN @output_format = 'simple' THEN
|
||||
JSON_OBJECT('name', T.TABLE_NAME)
|
||||
ELSE
|
||||
CONVERT( JSON_OBJECT(
|
||||
'schema_name', T.TABLE_SCHEMA,
|
||||
'object_name', T.TABLE_NAME,
|
||||
'object_type', 'TABLE',
|
||||
'owner', (
|
||||
SELECT
|
||||
IFNULL(U.GRANTEE, 'N/A')
|
||||
FROM
|
||||
INFORMATION_SCHEMA.SCHEMA_PRIVILEGES U
|
||||
WHERE
|
||||
U.TABLE_SCHEMA = T.TABLE_SCHEMA
|
||||
LIMIT 1
|
||||
),
|
||||
'comment', IFNULL(T.TABLE_COMMENT, ''),
|
||||
'columns', (
|
||||
SELECT
|
||||
IFNULL(
|
||||
JSON_ARRAYAGG(
|
||||
JSON_OBJECT(
|
||||
'column_name', C.COLUMN_NAME,
|
||||
'data_type', C.COLUMN_TYPE,
|
||||
'ordinal_position', C.ORDINAL_POSITION,
|
||||
'is_not_nullable', IF(C.IS_NULLABLE = 'NO', TRUE, FALSE),
|
||||
'column_default', C.COLUMN_DEFAULT,
|
||||
'column_comment', IFNULL(C.COLUMN_COMMENT, '')
|
||||
)
|
||||
),
|
||||
JSON_ARRAY()
|
||||
)
|
||||
FROM
|
||||
INFORMATION_SCHEMA.COLUMNS C
|
||||
WHERE
|
||||
C.TABLE_SCHEMA = T.TABLE_SCHEMA AND C.TABLE_NAME = T.TABLE_NAME
|
||||
ORDER BY C.ORDINAL_POSITION
|
||||
),
|
||||
'constraints', (
|
||||
SELECT
|
||||
IFNULL(
|
||||
JSON_ARRAYAGG(
|
||||
JSON_OBJECT(
|
||||
'constraint_name', TC.CONSTRAINT_NAME,
|
||||
'constraint_type',
|
||||
CASE TC.CONSTRAINT_TYPE
|
||||
WHEN 'PRIMARY KEY' THEN 'PRIMARY KEY'
|
||||
WHEN 'FOREIGN KEY' THEN 'FOREIGN KEY'
|
||||
WHEN 'UNIQUE' THEN 'UNIQUE'
|
||||
ELSE TC.CONSTRAINT_TYPE
|
||||
END,
|
||||
'constraint_definition', '',
|
||||
'constraint_columns', (
|
||||
SELECT
|
||||
IFNULL(JSON_ARRAYAGG(KCU.COLUMN_NAME), JSON_ARRAY())
|
||||
FROM
|
||||
INFORMATION_SCHEMA.KEY_COLUMN_USAGE KCU
|
||||
WHERE
|
||||
KCU.CONSTRAINT_SCHEMA = TC.CONSTRAINT_SCHEMA
|
||||
AND KCU.CONSTRAINT_NAME = TC.CONSTRAINT_NAME
|
||||
AND KCU.TABLE_NAME = TC.TABLE_NAME
|
||||
ORDER BY KCU.ORDINAL_POSITION
|
||||
),
|
||||
'foreign_key_referenced_table', IF(TC.CONSTRAINT_TYPE = 'FOREIGN KEY', RC.REFERENCED_TABLE_NAME, NULL),
|
||||
'foreign_key_referenced_columns', IF(TC.CONSTRAINT_TYPE = 'FOREIGN KEY',
|
||||
(SELECT IFNULL(JSON_ARRAYAGG(FKCU.REFERENCED_COLUMN_NAME), JSON_ARRAY())
|
||||
FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE FKCU
|
||||
WHERE FKCU.CONSTRAINT_SCHEMA = TC.CONSTRAINT_SCHEMA
|
||||
AND FKCU.CONSTRAINT_NAME = TC.CONSTRAINT_NAME
|
||||
AND FKCU.TABLE_NAME = TC.TABLE_NAME
|
||||
AND FKCU.REFERENCED_TABLE_NAME IS NOT NULL
|
||||
ORDER BY FKCU.ORDINAL_POSITION),
|
||||
NULL
|
||||
)
|
||||
)
|
||||
),
|
||||
JSON_ARRAY()
|
||||
)
|
||||
FROM
|
||||
INFORMATION_SCHEMA.TABLE_CONSTRAINTS TC
|
||||
LEFT JOIN
|
||||
INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS RC
|
||||
ON TC.CONSTRAINT_SCHEMA = RC.CONSTRAINT_SCHEMA
|
||||
AND TC.CONSTRAINT_NAME = RC.CONSTRAINT_NAME
|
||||
AND TC.TABLE_NAME = RC.TABLE_NAME
|
||||
WHERE
|
||||
TC.TABLE_SCHEMA = T.TABLE_SCHEMA AND TC.TABLE_NAME = T.TABLE_NAME
|
||||
),
|
||||
'indexes', (
|
||||
SELECT
|
||||
IFNULL(
|
||||
JSON_ARRAYAGG(
|
||||
JSON_OBJECT(
|
||||
'index_name', IndexData.INDEX_NAME,
|
||||
'is_unique', IF(IndexData.NON_UNIQUE = 0, TRUE, FALSE),
|
||||
'is_primary', IF(IndexData.INDEX_NAME = 'PRIMARY', TRUE, FALSE),
|
||||
'index_columns', IFNULL(IndexData.INDEX_COLUMNS_ARRAY, JSON_ARRAY())
|
||||
)
|
||||
),
|
||||
JSON_ARRAY()
|
||||
)
|
||||
FROM (
|
||||
SELECT
|
||||
S.TABLE_SCHEMA,
|
||||
S.TABLE_NAME,
|
||||
S.INDEX_NAME,
|
||||
MIN(S.NON_UNIQUE) AS NON_UNIQUE, -- Aggregate NON_UNIQUE here to get unique status for the index
|
||||
JSON_ARRAYAGG(S.COLUMN_NAME) AS INDEX_COLUMNS_ARRAY -- Aggregate columns into an array for this index
|
||||
FROM
|
||||
INFORMATION_SCHEMA.STATISTICS S
|
||||
WHERE
|
||||
S.TABLE_SCHEMA = T.TABLE_SCHEMA AND S.TABLE_NAME = T.TABLE_NAME
|
||||
GROUP BY
|
||||
S.TABLE_SCHEMA, S.TABLE_NAME, S.INDEX_NAME
|
||||
) AS IndexData
|
||||
ORDER BY IndexData.INDEX_NAME
|
||||
),
|
||||
'triggers', (
|
||||
SELECT
|
||||
IFNULL(
|
||||
JSON_ARRAYAGG(
|
||||
JSON_OBJECT(
|
||||
'trigger_name', TR.TRIGGER_NAME,
|
||||
'trigger_definition', TR.ACTION_STATEMENT
|
||||
)
|
||||
),
|
||||
JSON_ARRAY()
|
||||
)
|
||||
FROM
|
||||
INFORMATION_SCHEMA.TRIGGERS TR
|
||||
WHERE
|
||||
TR.EVENT_OBJECT_SCHEMA = T.TABLE_SCHEMA AND TR.EVENT_OBJECT_TABLE = T.TABLE_NAME
|
||||
ORDER BY TR.TRIGGER_NAME
|
||||
)
|
||||
) USING utf8mb4)
|
||||
END AS object_details
|
||||
FROM
|
||||
INFORMATION_SCHEMA.TABLES T
|
||||
CROSS JOIN (SELECT @table_names := ?, @output_format := ?) AS variables
|
||||
WHERE
|
||||
T.TABLE_SCHEMA NOT IN ('mysql', 'information_schema', 'performance_schema', 'sys')
|
||||
AND (NULLIF(TRIM(@table_names), '') IS NULL OR FIND_IN_SET(T.TABLE_NAME, @table_names))
|
||||
AND T.TABLE_TYPE = 'BASE TABLE'
|
||||
ORDER BY
|
||||
T.TABLE_SCHEMA, T.TABLE_NAME;
|
||||
parameters:
|
||||
- name: table_names
|
||||
type: string
|
||||
description: "Optional: A comma-separated list of table names. If empty, details for all tables in user-accessible schemas will be listed."
|
||||
default: ""
|
||||
- name: output_format
|
||||
type: string
|
||||
description: "Optional: Use 'simple' to return table names only or use 'detailed' to return the full information schema."
|
||||
default: "detailed"
|
||||
|
||||
toolsets:
|
||||
cloud-sql-mysql-database-tools:
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
sources:
|
||||
cloudsql-pg-source:
|
||||
kind: cloud-sql-postgres
|
||||
@@ -29,93 +30,9 @@ tools:
|
||||
description: Use this tool to execute sql.
|
||||
|
||||
list_tables:
|
||||
kind: postgres-sql
|
||||
kind: postgres-list-tables
|
||||
source: cloudsql-pg-source
|
||||
description: "Lists detailed schema information (object type, columns, constraints, indexes, triggers, owner, comment) as JSON for user-created tables (ordinary or partitioned). Filters by a comma-separated list of names. If names are omitted, lists all tables in user schemas."
|
||||
statement: |
|
||||
WITH desired_relkinds AS (
|
||||
SELECT ARRAY['r', 'p']::char[] AS kinds -- Always consider both 'TABLE' and 'PARTITIONED TABLE'
|
||||
),
|
||||
table_info AS (
|
||||
SELECT
|
||||
t.oid AS table_oid,
|
||||
ns.nspname AS schema_name,
|
||||
t.relname AS table_name,
|
||||
pg_get_userbyid(t.relowner) AS table_owner,
|
||||
obj_description(t.oid, 'pg_class') AS table_comment,
|
||||
t.relkind AS object_kind
|
||||
FROM
|
||||
pg_class t
|
||||
JOIN
|
||||
pg_namespace ns ON ns.oid = t.relnamespace
|
||||
CROSS JOIN desired_relkinds dk
|
||||
WHERE
|
||||
t.relkind = ANY(dk.kinds) -- Filter by selected table relkinds ('r', 'p')
|
||||
AND (NULLIF(TRIM($1), '') IS NULL OR t.relname = ANY(string_to_array($1,','))) -- $1 is object_names
|
||||
AND ns.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
|
||||
AND ns.nspname NOT LIKE 'pg_temp_%' AND ns.nspname NOT LIKE 'pg_toast_temp_%'
|
||||
),
|
||||
columns_info AS (
|
||||
SELECT
|
||||
att.attrelid AS table_oid, att.attname AS column_name, format_type(att.atttypid, att.atttypmod) AS data_type,
|
||||
att.attnum AS column_ordinal_position, att.attnotnull AS is_not_nullable,
|
||||
pg_get_expr(ad.adbin, ad.adrelid) AS column_default, col_description(att.attrelid, att.attnum) AS column_comment
|
||||
FROM pg_attribute att LEFT JOIN pg_attrdef ad ON att.attrelid = ad.adrelid AND att.attnum = ad.adnum
|
||||
JOIN table_info ti ON att.attrelid = ti.table_oid WHERE att.attnum > 0 AND NOT att.attisdropped
|
||||
),
|
||||
constraints_info AS (
|
||||
SELECT
|
||||
con.conrelid AS table_oid, con.conname AS constraint_name, pg_get_constraintdef(con.oid) AS constraint_definition,
|
||||
CASE con.contype WHEN 'p' THEN 'PRIMARY KEY' WHEN 'f' THEN 'FOREIGN KEY' WHEN 'u' THEN 'UNIQUE' WHEN 'c' THEN 'CHECK' ELSE con.contype::text END AS constraint_type,
|
||||
(SELECT array_agg(att.attname ORDER BY u.attposition) FROM unnest(con.conkey) WITH ORDINALITY AS u(attnum, attposition) JOIN pg_attribute att ON att.attrelid = con.conrelid AND att.attnum = u.attnum) AS constraint_columns,
|
||||
NULLIF(con.confrelid, 0)::regclass AS foreign_key_referenced_table,
|
||||
(SELECT array_agg(att.attname ORDER BY u.attposition) FROM unnest(con.confkey) WITH ORDINALITY AS u(attnum, attposition) JOIN pg_attribute att ON att.attrelid = con.confrelid AND att.attnum = u.attnum WHERE con.contype = 'f') AS foreign_key_referenced_columns
|
||||
FROM pg_constraint con JOIN table_info ti ON con.conrelid = ti.table_oid
|
||||
),
|
||||
indexes_info AS (
|
||||
SELECT
|
||||
idx.indrelid AS table_oid, ic.relname AS index_name, pg_get_indexdef(idx.indexrelid) AS index_definition,
|
||||
idx.indisunique AS is_unique, idx.indisprimary AS is_primary, am.amname AS index_method,
|
||||
(SELECT array_agg(att.attname ORDER BY u.ord) FROM unnest(idx.indkey::int[]) WITH ORDINALITY AS u(colidx, ord) LEFT JOIN pg_attribute att ON att.attrelid = idx.indrelid AND att.attnum = u.colidx WHERE u.colidx <> 0) AS index_columns
|
||||
FROM pg_index idx JOIN pg_class ic ON ic.oid = idx.indexrelid JOIN pg_am am ON am.oid = ic.relam JOIN table_info ti ON idx.indrelid = ti.table_oid
|
||||
),
|
||||
triggers_info AS (
|
||||
SELECT tg.tgrelid AS table_oid, tg.tgname AS trigger_name, pg_get_triggerdef(tg.oid) AS trigger_definition, tg.tgenabled AS trigger_enabled_state
|
||||
FROM pg_trigger tg JOIN table_info ti ON tg.tgrelid = ti.table_oid WHERE NOT tg.tgisinternal
|
||||
)
|
||||
SELECT
|
||||
ti.schema_name,
|
||||
ti.table_name AS object_name,
|
||||
CASE
|
||||
WHEN $2 = 'simple' THEN
|
||||
-- IF format is 'simple', return basic JSON
|
||||
json_build_object('name', ti.table_name)
|
||||
ELSE
|
||||
json_build_object(
|
||||
'schema_name', ti.schema_name,
|
||||
'object_name', ti.table_name,
|
||||
'object_type', CASE ti.object_kind
|
||||
WHEN 'r' THEN 'TABLE'
|
||||
WHEN 'p' THEN 'PARTITIONED TABLE'
|
||||
ELSE ti.object_kind::text -- Should not happen due to WHERE clause
|
||||
END,
|
||||
'owner', ti.table_owner,
|
||||
'comment', ti.table_comment,
|
||||
'columns', COALESCE((SELECT json_agg(json_build_object('column_name',ci.column_name,'data_type',ci.data_type,'ordinal_position',ci.column_ordinal_position,'is_not_nullable',ci.is_not_nullable,'column_default',ci.column_default,'column_comment',ci.column_comment) ORDER BY ci.column_ordinal_position) FROM columns_info ci WHERE ci.table_oid = ti.table_oid), '[]'::json),
|
||||
'constraints', COALESCE((SELECT json_agg(json_build_object('constraint_name',cons.constraint_name,'constraint_type',cons.constraint_type,'constraint_definition',cons.constraint_definition,'constraint_columns',cons.constraint_columns,'foreign_key_referenced_table',cons.foreign_key_referenced_table,'foreign_key_referenced_columns',cons.foreign_key_referenced_columns)) FROM constraints_info cons WHERE cons.table_oid = ti.table_oid), '[]'::json),
|
||||
'indexes', COALESCE((SELECT json_agg(json_build_object('index_name',ii.index_name,'index_definition',ii.index_definition,'is_unique',ii.is_unique,'is_primary',ii.is_primary,'index_method',ii.index_method,'index_columns',ii.index_columns)) FROM indexes_info ii WHERE ii.table_oid = ti.table_oid), '[]'::json),
|
||||
'triggers', COALESCE((SELECT json_agg(json_build_object('trigger_name',tri.trigger_name,'trigger_definition',tri.trigger_definition,'trigger_enabled_state',tri.trigger_enabled_state)) FROM triggers_info tri WHERE tri.table_oid = ti.table_oid), '[]'::json)
|
||||
)
|
||||
END AS object_details
|
||||
FROM table_info ti ORDER BY ti.schema_name, ti.table_name;
|
||||
parameters:
|
||||
- name: table_names
|
||||
type: string
|
||||
description: "Optional: A comma-separated list of table names. If empty, details for all tables in user-accessible schemas will be listed."
|
||||
- name: output_format
|
||||
type: string
|
||||
description: "Optional: Use 'simple' to return table names only or use 'detailed' to return the full information schema."
|
||||
default: "detailed"
|
||||
|
||||
toolsets:
|
||||
cloud-sql-postgres-database-tools:
|
||||
|
||||
@@ -32,168 +32,9 @@ tools:
|
||||
source: mysql-source
|
||||
description: Use this tool to execute SQL.
|
||||
list_tables:
|
||||
kind: mysql-sql
|
||||
kind: mysql-list-tables
|
||||
source: mysql-source
|
||||
description: "Lists detailed schema information (object type, columns, constraints, indexes, triggers, comment) as JSON for user-created tables (ordinary or partitioned). Filters by a comma-separated list of names. If names are omitted, lists all tables in user schemas."
|
||||
statement: |
|
||||
SELECT
|
||||
T.TABLE_SCHEMA AS schema_name,
|
||||
T.TABLE_NAME AS object_name,
|
||||
CASE
|
||||
WHEN @output_format = 'simple' THEN
|
||||
JSON_OBJECT('name', T.TABLE_NAME)
|
||||
ELSE
|
||||
CONVERT( JSON_OBJECT(
|
||||
'schema_name', T.TABLE_SCHEMA,
|
||||
'object_name', T.TABLE_NAME,
|
||||
'object_type', 'TABLE',
|
||||
'owner', (
|
||||
SELECT
|
||||
IFNULL(U.GRANTEE, 'N/A')
|
||||
FROM
|
||||
INFORMATION_SCHEMA.SCHEMA_PRIVILEGES U
|
||||
WHERE
|
||||
U.TABLE_SCHEMA = T.TABLE_SCHEMA
|
||||
LIMIT 1
|
||||
),
|
||||
'comment', IFNULL(T.TABLE_COMMENT, ''),
|
||||
'columns', (
|
||||
SELECT
|
||||
IFNULL(
|
||||
JSON_ARRAYAGG(
|
||||
JSON_OBJECT(
|
||||
'column_name', C.COLUMN_NAME,
|
||||
'data_type', C.COLUMN_TYPE,
|
||||
'ordinal_position', C.ORDINAL_POSITION,
|
||||
'is_not_nullable', IF(C.IS_NULLABLE = 'NO', TRUE, FALSE),
|
||||
'column_default', C.COLUMN_DEFAULT,
|
||||
'column_comment', IFNULL(C.COLUMN_COMMENT, '')
|
||||
)
|
||||
),
|
||||
JSON_ARRAY()
|
||||
)
|
||||
FROM
|
||||
INFORMATION_SCHEMA.COLUMNS C
|
||||
WHERE
|
||||
C.TABLE_SCHEMA = T.TABLE_SCHEMA AND C.TABLE_NAME = T.TABLE_NAME
|
||||
ORDER BY C.ORDINAL_POSITION
|
||||
),
|
||||
'constraints', (
|
||||
SELECT
|
||||
IFNULL(
|
||||
JSON_ARRAYAGG(
|
||||
JSON_OBJECT(
|
||||
'constraint_name', TC.CONSTRAINT_NAME,
|
||||
'constraint_type',
|
||||
CASE TC.CONSTRAINT_TYPE
|
||||
WHEN 'PRIMARY KEY' THEN 'PRIMARY KEY'
|
||||
WHEN 'FOREIGN KEY' THEN 'FOREIGN KEY'
|
||||
WHEN 'UNIQUE' THEN 'UNIQUE'
|
||||
ELSE TC.CONSTRAINT_TYPE
|
||||
END,
|
||||
'constraint_definition', '',
|
||||
'constraint_columns', (
|
||||
SELECT
|
||||
IFNULL(JSON_ARRAYAGG(KCU.COLUMN_NAME), JSON_ARRAY())
|
||||
FROM
|
||||
INFORMATION_SCHEMA.KEY_COLUMN_USAGE KCU
|
||||
WHERE
|
||||
KCU.CONSTRAINT_SCHEMA = TC.CONSTRAINT_SCHEMA
|
||||
AND KCU.CONSTRAINT_NAME = TC.CONSTRAINT_NAME
|
||||
AND KCU.TABLE_NAME = TC.TABLE_NAME
|
||||
ORDER BY KCU.ORDINAL_POSITION
|
||||
),
|
||||
'foreign_key_referenced_table', IF(TC.CONSTRAINT_TYPE = 'FOREIGN KEY', RC.REFERENCED_TABLE_NAME, NULL),
|
||||
'foreign_key_referenced_columns', IF(TC.CONSTRAINT_TYPE = 'FOREIGN KEY',
|
||||
(SELECT IFNULL(JSON_ARRAYAGG(FKCU.REFERENCED_COLUMN_NAME), JSON_ARRAY())
|
||||
FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE FKCU
|
||||
WHERE FKCU.CONSTRAINT_SCHEMA = TC.CONSTRAINT_SCHEMA
|
||||
AND FKCU.CONSTRAINT_NAME = TC.CONSTRAINT_NAME
|
||||
AND FKCU.TABLE_NAME = TC.TABLE_NAME
|
||||
AND FKCU.REFERENCED_TABLE_NAME IS NOT NULL
|
||||
ORDER BY FKCU.ORDINAL_POSITION),
|
||||
NULL
|
||||
)
|
||||
)
|
||||
),
|
||||
JSON_ARRAY()
|
||||
)
|
||||
FROM
|
||||
INFORMATION_SCHEMA.TABLE_CONSTRAINTS TC
|
||||
LEFT JOIN
|
||||
INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS RC
|
||||
ON TC.CONSTRAINT_SCHEMA = RC.CONSTRAINT_SCHEMA
|
||||
AND TC.CONSTRAINT_NAME = RC.CONSTRAINT_NAME
|
||||
AND TC.TABLE_NAME = RC.TABLE_NAME
|
||||
WHERE
|
||||
TC.TABLE_SCHEMA = T.TABLE_SCHEMA AND TC.TABLE_NAME = T.TABLE_NAME
|
||||
),
|
||||
'indexes', (
|
||||
SELECT
|
||||
IFNULL(
|
||||
JSON_ARRAYAGG(
|
||||
JSON_OBJECT(
|
||||
'index_name', IndexData.INDEX_NAME,
|
||||
'is_unique', IF(IndexData.NON_UNIQUE = 0, TRUE, FALSE),
|
||||
'is_primary', IF(IndexData.INDEX_NAME = 'PRIMARY', TRUE, FALSE),
|
||||
'index_columns', IFNULL(IndexData.INDEX_COLUMNS_ARRAY, JSON_ARRAY())
|
||||
)
|
||||
),
|
||||
JSON_ARRAY()
|
||||
)
|
||||
FROM (
|
||||
SELECT
|
||||
S.TABLE_SCHEMA,
|
||||
S.TABLE_NAME,
|
||||
S.INDEX_NAME,
|
||||
MIN(S.NON_UNIQUE) AS NON_UNIQUE, -- Aggregate NON_UNIQUE here to get unique status for the index
|
||||
JSON_ARRAYAGG(S.COLUMN_NAME) AS INDEX_COLUMNS_ARRAY -- Aggregate columns into an array for this index
|
||||
FROM
|
||||
INFORMATION_SCHEMA.STATISTICS S
|
||||
WHERE
|
||||
S.TABLE_SCHEMA = T.TABLE_SCHEMA AND S.TABLE_NAME = T.TABLE_NAME
|
||||
GROUP BY
|
||||
S.TABLE_SCHEMA, S.TABLE_NAME, S.INDEX_NAME
|
||||
) AS IndexData
|
||||
ORDER BY IndexData.INDEX_NAME
|
||||
),
|
||||
'triggers', (
|
||||
SELECT
|
||||
IFNULL(
|
||||
JSON_ARRAYAGG(
|
||||
JSON_OBJECT(
|
||||
'trigger_name', TR.TRIGGER_NAME,
|
||||
'trigger_definition', TR.ACTION_STATEMENT
|
||||
)
|
||||
),
|
||||
JSON_ARRAY()
|
||||
)
|
||||
FROM
|
||||
INFORMATION_SCHEMA.TRIGGERS TR
|
||||
WHERE
|
||||
TR.EVENT_OBJECT_SCHEMA = T.TABLE_SCHEMA AND TR.EVENT_OBJECT_TABLE = T.TABLE_NAME
|
||||
ORDER BY TR.TRIGGER_NAME
|
||||
)
|
||||
) USING utf8mb4)
|
||||
END AS object_details
|
||||
FROM
|
||||
INFORMATION_SCHEMA.TABLES T
|
||||
CROSS JOIN (SELECT @table_names := ?, @output_format := ?) AS variables
|
||||
WHERE
|
||||
T.TABLE_SCHEMA NOT IN ('mysql', 'information_schema', 'performance_schema', 'sys')
|
||||
AND (NULLIF(TRIM(@table_names), '') IS NULL OR FIND_IN_SET(T.TABLE_NAME, @table_names))
|
||||
AND T.TABLE_TYPE = 'BASE TABLE'
|
||||
ORDER BY
|
||||
T.TABLE_SCHEMA, T.TABLE_NAME;
|
||||
parameters:
|
||||
- name: table_names
|
||||
type: string
|
||||
description: "Optional: A comma-separated list of table names. If empty, details for all tables in user-accessible schemas will be listed."
|
||||
default: ""
|
||||
- name: output_format
|
||||
type: string
|
||||
description: "Optional: Use 'simple' to return table names only or use 'detailed' to return the full information schema."
|
||||
default: "detailed"
|
||||
|
||||
toolsets:
|
||||
mysql-database-tools:
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
sources:
|
||||
postgresql-source:
|
||||
kind: postgres
|
||||
@@ -28,93 +29,9 @@ tools:
|
||||
description: Use this tool to execute SQL.
|
||||
|
||||
list_tables:
|
||||
kind: postgres-sql
|
||||
kind: postgres-list-tables
|
||||
source: postgresql-source
|
||||
description: "Lists detailed schema information (object type, columns, constraints, indexes, triggers, owner, comment) as JSON for user-created tables (ordinary or partitioned). Filters by a comma-separated list of names. If names are omitted, lists all tables in user schemas."
|
||||
statement: |
|
||||
WITH desired_relkinds AS (
|
||||
SELECT ARRAY['r', 'p']::char[] AS kinds -- Always consider both 'TABLE' and 'PARTITIONED TABLE'
|
||||
),
|
||||
table_info AS (
|
||||
SELECT
|
||||
t.oid AS table_oid,
|
||||
ns.nspname AS schema_name,
|
||||
t.relname AS table_name,
|
||||
pg_get_userbyid(t.relowner) AS table_owner,
|
||||
obj_description(t.oid, 'pg_class') AS table_comment,
|
||||
t.relkind AS object_kind
|
||||
FROM
|
||||
pg_class t
|
||||
JOIN
|
||||
pg_namespace ns ON ns.oid = t.relnamespace
|
||||
CROSS JOIN desired_relkinds dk
|
||||
WHERE
|
||||
t.relkind = ANY(dk.kinds) -- Filter by selected table relkinds ('r', 'p')
|
||||
AND (NULLIF(TRIM($1), '') IS NULL OR t.relname = ANY(string_to_array($1,','))) -- $1 is object_names
|
||||
AND ns.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
|
||||
AND ns.nspname NOT LIKE 'pg_temp_%' AND ns.nspname NOT LIKE 'pg_toast_temp_%'
|
||||
),
|
||||
columns_info AS (
|
||||
SELECT
|
||||
att.attrelid AS table_oid, att.attname AS column_name, format_type(att.atttypid, att.atttypmod) AS data_type,
|
||||
att.attnum AS column_ordinal_position, att.attnotnull AS is_not_nullable,
|
||||
pg_get_expr(ad.adbin, ad.adrelid) AS column_default, col_description(att.attrelid, att.attnum) AS column_comment
|
||||
FROM pg_attribute att LEFT JOIN pg_attrdef ad ON att.attrelid = ad.adrelid AND att.attnum = ad.adnum
|
||||
JOIN table_info ti ON att.attrelid = ti.table_oid WHERE att.attnum > 0 AND NOT att.attisdropped
|
||||
),
|
||||
constraints_info AS (
|
||||
SELECT
|
||||
con.conrelid AS table_oid, con.conname AS constraint_name, pg_get_constraintdef(con.oid) AS constraint_definition,
|
||||
CASE con.contype WHEN 'p' THEN 'PRIMARY KEY' WHEN 'f' THEN 'FOREIGN KEY' WHEN 'u' THEN 'UNIQUE' WHEN 'c' THEN 'CHECK' ELSE con.contype::text END AS constraint_type,
|
||||
(SELECT array_agg(att.attname ORDER BY u.attposition) FROM unnest(con.conkey) WITH ORDINALITY AS u(attnum, attposition) JOIN pg_attribute att ON att.attrelid = con.conrelid AND att.attnum = u.attnum) AS constraint_columns,
|
||||
NULLIF(con.confrelid, 0)::regclass AS foreign_key_referenced_table,
|
||||
(SELECT array_agg(att.attname ORDER BY u.attposition) FROM unnest(con.confkey) WITH ORDINALITY AS u(attnum, attposition) JOIN pg_attribute att ON att.attrelid = con.confrelid AND att.attnum = u.attnum WHERE con.contype = 'f') AS foreign_key_referenced_columns
|
||||
FROM pg_constraint con JOIN table_info ti ON con.conrelid = ti.table_oid
|
||||
),
|
||||
indexes_info AS (
|
||||
SELECT
|
||||
idx.indrelid AS table_oid, ic.relname AS index_name, pg_get_indexdef(idx.indexrelid) AS index_definition,
|
||||
idx.indisunique AS is_unique, idx.indisprimary AS is_primary, am.amname AS index_method,
|
||||
(SELECT array_agg(att.attname ORDER BY u.ord) FROM unnest(idx.indkey::int[]) WITH ORDINALITY AS u(colidx, ord) LEFT JOIN pg_attribute att ON att.attrelid = idx.indrelid AND att.attnum = u.colidx WHERE u.colidx <> 0) AS index_columns
|
||||
FROM pg_index idx JOIN pg_class ic ON ic.oid = idx.indexrelid JOIN pg_am am ON am.oid = ic.relam JOIN table_info ti ON idx.indrelid = ti.table_oid
|
||||
),
|
||||
triggers_info AS (
|
||||
SELECT tg.tgrelid AS table_oid, tg.tgname AS trigger_name, pg_get_triggerdef(tg.oid) AS trigger_definition, tg.tgenabled AS trigger_enabled_state
|
||||
FROM pg_trigger tg JOIN table_info ti ON tg.tgrelid = ti.table_oid WHERE NOT tg.tgisinternal
|
||||
)
|
||||
SELECT
|
||||
ti.schema_name,
|
||||
ti.table_name AS object_name,
|
||||
CASE
|
||||
WHEN $2 = 'simple' THEN
|
||||
-- IF format is 'simple', return basic JSON
|
||||
json_build_object('name', ti.table_name)
|
||||
ELSE
|
||||
json_build_object(
|
||||
'schema_name', ti.schema_name,
|
||||
'object_name', ti.table_name,
|
||||
'object_type', CASE ti.object_kind
|
||||
WHEN 'r' THEN 'TABLE'
|
||||
WHEN 'p' THEN 'PARTITIONED TABLE'
|
||||
ELSE ti.object_kind::text -- Should not happen due to WHERE clause
|
||||
END,
|
||||
'owner', ti.table_owner,
|
||||
'comment', ti.table_comment,
|
||||
'columns', COALESCE((SELECT json_agg(json_build_object('column_name',ci.column_name,'data_type',ci.data_type,'ordinal_position',ci.column_ordinal_position,'is_not_nullable',ci.is_not_nullable,'column_default',ci.column_default,'column_comment',ci.column_comment) ORDER BY ci.column_ordinal_position) FROM columns_info ci WHERE ci.table_oid = ti.table_oid), '[]'::json),
|
||||
'constraints', COALESCE((SELECT json_agg(json_build_object('constraint_name',cons.constraint_name,'constraint_type',cons.constraint_type,'constraint_definition',cons.constraint_definition,'constraint_columns',cons.constraint_columns,'foreign_key_referenced_table',cons.foreign_key_referenced_table,'foreign_key_referenced_columns',cons.foreign_key_referenced_columns)) FROM constraints_info cons WHERE cons.table_oid = ti.table_oid), '[]'::json),
|
||||
'indexes', COALESCE((SELECT json_agg(json_build_object('index_name',ii.index_name,'index_definition',ii.index_definition,'is_unique',ii.is_unique,'is_primary',ii.is_primary,'index_method',ii.index_method,'index_columns',ii.index_columns)) FROM indexes_info ii WHERE ii.table_oid = ti.table_oid), '[]'::json),
|
||||
'triggers', COALESCE((SELECT json_agg(json_build_object('trigger_name',tri.trigger_name,'trigger_definition',tri.trigger_definition,'trigger_enabled_state',tri.trigger_enabled_state)) FROM triggers_info tri WHERE tri.table_oid = ti.table_oid), '[]'::json)
|
||||
)
|
||||
END AS object_details
|
||||
FROM table_info ti ORDER BY ti.schema_name, ti.table_name;
|
||||
parameters:
|
||||
- name: table_names
|
||||
type: string
|
||||
description: "Optional: A comma-separated list of table names. If empty, details for all tables in user-accessible schemas will be listed."
|
||||
- name: output_format
|
||||
type: string
|
||||
description: "Optional: Use 'simple' to return table names only or use 'detailed' to return the full information schema."
|
||||
default: "detailed"
|
||||
|
||||
toolsets:
|
||||
postgres-database-tools:
|
||||
|
||||
117
internal/sources/alloydbadmin/alloydbadmin.go
Normal file
117
internal/sources/alloydbadmin/alloydbadmin.go
Normal file
@@ -0,0 +1,117 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
package alloydbadmin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/google"
|
||||
alloydbrestapi "google.golang.org/api/alloydb/v1"
|
||||
)
|
||||
|
||||
const SourceKind string = "alloydb-admin"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources.SourceConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
ua, err := util.UserAgentFromContext(ctx)
|
||||
if err != nil {
|
||||
fmt.Printf("Error in User Agent retrieval: %s", err)
|
||||
}
|
||||
|
||||
var client *http.Client
|
||||
if r.UseClientOAuth {
|
||||
client = nil
|
||||
} else {
|
||||
// Use Application Default Credentials
|
||||
creds, err := google.FindDefaultCredentials(ctx, alloydbrestapi.CloudPlatformScope)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to find default credentials: %w", err)
|
||||
}
|
||||
client = oauth2.NewClient(ctx, creds.TokenSource)
|
||||
}
|
||||
|
||||
s := &Source{
|
||||
Name: r.Name,
|
||||
Kind: SourceKind,
|
||||
BaseURL: "https://alloydb.googleapis.com",
|
||||
Client: client,
|
||||
UserAgent: ua,
|
||||
UseClientOAuth: r.UseClientOAuth,
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
var _ sources.Source = &Source{}
|
||||
|
||||
type Source struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
BaseURL string
|
||||
Client *http.Client
|
||||
UserAgent string
|
||||
UseClientOAuth bool
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Client, error) {
|
||||
if s.UseClientOAuth {
|
||||
if accessToken == "" {
|
||||
return nil, fmt.Errorf("client-side OAuth is enabled but no access token was provided")
|
||||
}
|
||||
token := &oauth2.Token{AccessToken: accessToken}
|
||||
return oauth2.NewClient(ctx, oauth2.StaticTokenSource(token)), nil
|
||||
}
|
||||
return s.Client, nil
|
||||
}
|
||||
|
||||
func (s *Source) UseClientAuthorization() bool {
|
||||
return s.UseClientOAuth
|
||||
}
|
||||
125
internal/sources/alloydbadmin/alloydbadmin_test.go
Normal file
125
internal/sources/alloydbadmin/alloydbadmin_test.go
Normal file
@@ -0,0 +1,125 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package alloydbadmin_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
|
||||
func TestParseFromYamlAlloyDBAdmin(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.SourceConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
sources:
|
||||
my-alloydb-admin-instance:
|
||||
kind: alloydb-admin
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
"my-alloydb-admin-instance": alloydbadmin.Config{
|
||||
Name: "my-alloydb-admin-instance",
|
||||
Kind: alloydbadmin.SourceKind,
|
||||
UseClientOAuth: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "use client auth example",
|
||||
in: `
|
||||
sources:
|
||||
my-alloydb-admin-instance:
|
||||
kind: alloydb-admin
|
||||
useClientOAuth: true
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
"my-alloydb-admin-instance": alloydbadmin.Config{
|
||||
Name: "my-alloydb-admin-instance",
|
||||
Kind: alloydbadmin.SourceKind,
|
||||
UseClientOAuth: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailParseFromYaml(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
err string
|
||||
}{
|
||||
{
|
||||
desc: "extra field",
|
||||
in: `
|
||||
sources:
|
||||
my-alloydb-admin-instance:
|
||||
kind: alloydb-admin
|
||||
project: test-project
|
||||
`,
|
||||
err: "unable to parse source \"my-alloydb-admin-instance\" as \"alloydb-admin\": [2:1] unknown field \"project\"\n 1 | kind: alloydb-admin\n> 2 | project: test-project\n ^\n",
|
||||
},
|
||||
{
|
||||
desc: "missing required field",
|
||||
in: `
|
||||
sources:
|
||||
my-alloydb-admin-instance:
|
||||
useClientOAuth: true
|
||||
`,
|
||||
err: "missing 'kind' field for source \"my-alloydb-admin-instance\"",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
errStr := err.Error()
|
||||
if errStr != tc.err {
|
||||
t.Fatalf("unexpected error: got %q, want %q", errStr, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -17,6 +17,8 @@ package bigquery
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
bigqueryapi "cloud.google.com/go/bigquery"
|
||||
"github.com/goccy/go-yaml"
|
||||
@@ -26,6 +28,7 @@ import (
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/google"
|
||||
bigqueryrestapi "google.golang.org/api/bigquery/v2"
|
||||
"google.golang.org/api/googleapi"
|
||||
"google.golang.org/api/option"
|
||||
)
|
||||
|
||||
@@ -52,11 +55,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
|
||||
type Config struct {
|
||||
// BigQuery configs
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Project string `yaml:"project" validate:"required"`
|
||||
Location string `yaml:"location"`
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Project string `yaml:"project" validate:"required"`
|
||||
Location string `yaml:"location"`
|
||||
AllowedDatasets []string `yaml:"allowedDatasets"`
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
@@ -84,6 +88,37 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
|
||||
}
|
||||
}
|
||||
|
||||
allowedDatasets := make(map[string]struct{})
|
||||
// Get full id of allowed datasets and verify they exist.
|
||||
if len(r.AllowedDatasets) > 0 {
|
||||
for _, allowed := range r.AllowedDatasets {
|
||||
var projectID, datasetID, allowedFullID string
|
||||
if strings.Contains(allowed, ".") {
|
||||
parts := strings.Split(allowed, ".")
|
||||
if len(parts) != 2 {
|
||||
return nil, fmt.Errorf("invalid allowedDataset format: %q, expected 'project.dataset' or 'dataset'", allowed)
|
||||
}
|
||||
projectID = parts[0]
|
||||
datasetID = parts[1]
|
||||
allowedFullID = allowed
|
||||
} else {
|
||||
projectID = client.Project()
|
||||
datasetID = allowed
|
||||
allowedFullID = fmt.Sprintf("%s.%s", projectID, datasetID)
|
||||
}
|
||||
|
||||
dataset := client.DatasetInProject(projectID, datasetID)
|
||||
_, err := dataset.Metadata(ctx)
|
||||
if err != nil {
|
||||
if gerr, ok := err.(*googleapi.Error); ok && gerr.Code == http.StatusNotFound {
|
||||
return nil, fmt.Errorf("allowedDataset '%s' not found in project '%s'", datasetID, projectID)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to verify allowedDataset '%s' in project '%s': %w", datasetID, projectID, err)
|
||||
}
|
||||
allowedDatasets[allowedFullID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
s := &Source{
|
||||
Name: r.Name,
|
||||
Kind: SourceKind,
|
||||
@@ -94,6 +129,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
|
||||
TokenSource: tokenSource,
|
||||
MaxQueryResultRows: 50,
|
||||
ClientCreator: clientCreator,
|
||||
AllowedDatasets: allowedDatasets,
|
||||
UseClientOAuth: r.UseClientOAuth,
|
||||
}
|
||||
return s, nil
|
||||
@@ -113,6 +149,7 @@ type Source struct {
|
||||
TokenSource oauth2.TokenSource
|
||||
MaxQueryResultRows int
|
||||
ClientCreator BigqueryClientCreator
|
||||
AllowedDatasets map[string]struct{}
|
||||
UseClientOAuth bool
|
||||
}
|
||||
|
||||
@@ -153,6 +190,29 @@ func (s *Source) BigQueryClientCreator() BigqueryClientCreator {
|
||||
return s.ClientCreator
|
||||
}
|
||||
|
||||
func (s *Source) BigQueryAllowedDatasets() []string {
|
||||
if len(s.AllowedDatasets) == 0 {
|
||||
return nil
|
||||
}
|
||||
datasets := make([]string, 0, len(s.AllowedDatasets))
|
||||
for d := range s.AllowedDatasets {
|
||||
datasets = append(datasets, d)
|
||||
}
|
||||
return datasets
|
||||
}
|
||||
|
||||
// IsDatasetAllowed checks if a given dataset is accessible based on the source's configuration.
|
||||
func (s *Source) IsDatasetAllowed(projectID, datasetID string) bool {
|
||||
// If the normalized map is empty, it means no restrictions were configured.
|
||||
if len(s.AllowedDatasets) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
targetDataset := fmt.Sprintf("%s.%s", projectID, datasetID)
|
||||
_, ok := s.AllowedDatasets[targetDataset]
|
||||
return ok
|
||||
}
|
||||
|
||||
func initBigQueryConnection(
|
||||
ctx context.Context,
|
||||
tracer trace.Tracer,
|
||||
|
||||
@@ -69,6 +69,27 @@ func TestParseFromYamlBigQuery(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "with allowed datasets example",
|
||||
in: `
|
||||
sources:
|
||||
my-instance:
|
||||
kind: bigquery
|
||||
project: my-project
|
||||
location: us
|
||||
allowedDatasets:
|
||||
- my_dataset
|
||||
`,
|
||||
want: server.SourceConfigs{
|
||||
"my-instance": bigquery.Config{
|
||||
Name: "my-instance",
|
||||
Kind: bigquery.SourceKind,
|
||||
Project: "my-project",
|
||||
Location: "us",
|
||||
AllowedDatasets: []string{"my_dataset"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
|
||||
117
internal/sources/cloudmonitoring/cloud_monitoring.go
Normal file
117
internal/sources/cloudmonitoring/cloud_monitoring.go
Normal file
@@ -0,0 +1,117 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
package cloudmonitoring
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/google"
|
||||
monitoring "google.golang.org/api/monitoring/v3"
|
||||
)
|
||||
|
||||
const SourceKind string = "cloud-monitoring"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources.SourceConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
// Initialize initializes a Cloud Monitoring Source instance.
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
ua, err := util.UserAgentFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error in User Agent retrieval: %s", err)
|
||||
}
|
||||
|
||||
var client *http.Client
|
||||
if r.UseClientOAuth {
|
||||
client = nil
|
||||
} else {
|
||||
// Use Application Default Credentials
|
||||
creds, err := google.FindDefaultCredentials(ctx, monitoring.MonitoringScope)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to find default credentials: %w", err)
|
||||
}
|
||||
client = oauth2.NewClient(ctx, creds.TokenSource)
|
||||
}
|
||||
|
||||
s := &Source{
|
||||
Name: r.Name,
|
||||
Kind: SourceKind,
|
||||
BaseURL: "https://monitoring.googleapis.com",
|
||||
Client: client,
|
||||
UserAgent: ua,
|
||||
UseClientOAuth: r.UseClientOAuth,
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
var _ sources.Source = &Source{}
|
||||
|
||||
type Source struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
BaseURL string `yaml:"baseUrl"`
|
||||
Client *http.Client
|
||||
UserAgent string
|
||||
UseClientOAuth bool
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Client, error) {
|
||||
if s.UseClientOAuth {
|
||||
if accessToken == "" {
|
||||
return nil, fmt.Errorf("client-side OAuth is enabled but no access token was provided")
|
||||
}
|
||||
token := &oauth2.Token{AccessToken: accessToken}
|
||||
return oauth2.NewClient(ctx, oauth2.StaticTokenSource(token)), nil
|
||||
}
|
||||
return s.Client, nil
|
||||
}
|
||||
|
||||
func (s *Source) UseClientAuthorization() bool {
|
||||
return s.UseClientOAuth
|
||||
}
|
||||
135
internal/sources/cloudmonitoring/cloud_monitoring_test.go
Normal file
135
internal/sources/cloudmonitoring/cloud_monitoring_test.go
Normal file
@@ -0,0 +1,135 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cloudmonitoring_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cloudmonitoring"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
|
||||
func TestParseFromYamlCloudMonitoring(t *testing.T) {
|
||||
t.Parallel()
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.SourceConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
sources:
|
||||
my-cloud-monitoring-instance:
|
||||
kind: cloud-monitoring
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
"my-cloud-monitoring-instance": cloudmonitoring.Config{
|
||||
Name: "my-cloud-monitoring-instance",
|
||||
Kind: cloudmonitoring.SourceKind,
|
||||
UseClientOAuth: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "use client auth example",
|
||||
in: `
|
||||
sources:
|
||||
my-cloud-monitoring-instance:
|
||||
kind: cloud-monitoring
|
||||
useClientOAuth: true
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
"my-cloud-monitoring-instance": cloudmonitoring.Config{
|
||||
Name: "my-cloud-monitoring-instance",
|
||||
Kind: cloudmonitoring.SourceKind,
|
||||
UseClientOAuth: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
tc := tc
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailParseFromYaml(t *testing.T) {
|
||||
t.Parallel()
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
err string
|
||||
}{
|
||||
{
|
||||
desc: "extra field",
|
||||
in: `
|
||||
sources:
|
||||
my-cloud-monitoring-instance:
|
||||
kind: cloud-monitoring
|
||||
project: test-project
|
||||
`,
|
||||
err: `unable to parse source "my-cloud-monitoring-instance" as "cloud-monitoring": [2:1] unknown field "project"
|
||||
1 | kind: cloud-monitoring
|
||||
> 2 | project: test-project
|
||||
^
|
||||
`,
|
||||
},
|
||||
{
|
||||
desc: "missing required field",
|
||||
in: `
|
||||
sources:
|
||||
my-cloud-monitoring-instance:
|
||||
useClientOAuth: true
|
||||
`,
|
||||
err: "missing 'kind' field for source \"my-cloud-monitoring-instance\"",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
tc := tc
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
errStr := err.Error()
|
||||
if errStr != tc.err {
|
||||
t.Fatalf("unexpected error: got %q, want %q", errStr, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
117
internal/sources/cloudsqladmin/cloud_sql_admin.go
Normal file
117
internal/sources/cloudsqladmin/cloud_sql_admin.go
Normal file
@@ -0,0 +1,117 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
package cloudsqladmin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/google"
|
||||
sqladmin "google.golang.org/api/sqladmin/v1"
|
||||
)
|
||||
|
||||
const SourceKind string = "cloud-sql-admin"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources.SourceConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
// Initialize initializes a CloudSQL Admin Source instance.
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
ua, err := util.UserAgentFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error in User Agent retrieval: %s", err)
|
||||
}
|
||||
|
||||
var client *http.Client
|
||||
if r.UseClientOAuth {
|
||||
client = nil
|
||||
} else {
|
||||
// Use Application Default Credentials
|
||||
creds, err := google.FindDefaultCredentials(ctx, sqladmin.SqlserviceAdminScope)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to find default credentials: %w", err)
|
||||
}
|
||||
client = oauth2.NewClient(ctx, creds.TokenSource)
|
||||
}
|
||||
|
||||
s := &Source{
|
||||
Name: r.Name,
|
||||
Kind: SourceKind,
|
||||
BaseURL: "https://sqladmin.googleapis.com",
|
||||
Client: client,
|
||||
UserAgent: ua,
|
||||
UseClientOAuth: r.UseClientOAuth,
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
var _ sources.Source = &Source{}
|
||||
|
||||
type Source struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
BaseURL string
|
||||
Client *http.Client
|
||||
UserAgent string
|
||||
UseClientOAuth bool
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Client, error) {
|
||||
if s.UseClientOAuth {
|
||||
if accessToken == "" {
|
||||
return nil, fmt.Errorf("client-side OAuth is enabled but no access token was provided")
|
||||
}
|
||||
token := &oauth2.Token{AccessToken: accessToken}
|
||||
return oauth2.NewClient(ctx, oauth2.StaticTokenSource(token)), nil
|
||||
}
|
||||
return s.Client, nil
|
||||
}
|
||||
|
||||
func (s *Source) UseClientAuthorization() bool {
|
||||
return s.UseClientOAuth
|
||||
}
|
||||
135
internal/sources/cloudsqladmin/cloud_sql_admin_test.go
Normal file
135
internal/sources/cloudsqladmin/cloud_sql_admin_test.go
Normal file
@@ -0,0 +1,135 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cloudsqladmin_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
|
||||
func TestParseFromYamlCloudSQLAdmin(t *testing.T) {
|
||||
t.Parallel()
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.SourceConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
sources:
|
||||
my-cloud-sql-admin-instance:
|
||||
kind: cloud-sql-admin
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
"my-cloud-sql-admin-instance": cloudsqladmin.Config{
|
||||
Name: "my-cloud-sql-admin-instance",
|
||||
Kind: cloudsqladmin.SourceKind,
|
||||
UseClientOAuth: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "use client auth example",
|
||||
in: `
|
||||
sources:
|
||||
my-cloud-sql-admin-instance:
|
||||
kind: cloud-sql-admin
|
||||
useClientOAuth: true
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
"my-cloud-sql-admin-instance": cloudsqladmin.Config{
|
||||
Name: "my-cloud-sql-admin-instance",
|
||||
Kind: cloudsqladmin.SourceKind,
|
||||
UseClientOAuth: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
tc := tc
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailParseFromYaml(t *testing.T) {
|
||||
t.Parallel()
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
err string
|
||||
}{
|
||||
{
|
||||
desc: "extra field",
|
||||
in: `
|
||||
sources:
|
||||
my-cloud-sql-admin-instance:
|
||||
kind: cloud-sql-admin
|
||||
project: test-project
|
||||
`,
|
||||
err: `unable to parse source "my-cloud-sql-admin-instance" as "cloud-sql-admin": [2:1] unknown field "project"
|
||||
1 | kind: cloud-sql-admin
|
||||
> 2 | project: test-project
|
||||
^
|
||||
`,
|
||||
},
|
||||
{
|
||||
desc: "missing required field",
|
||||
in: `
|
||||
sources:
|
||||
my-cloud-sql-admin-instance:
|
||||
useClientOAuth: true
|
||||
`,
|
||||
err: "missing 'kind' field for source \"my-cloud-sql-admin-instance\"",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
tc := tc
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
errStr := err.Error()
|
||||
if errStr != tc.err {
|
||||
t.Fatalf("unexpected error: got %q, want %q", errStr, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -52,7 +52,7 @@ type Config struct {
|
||||
Project string `yaml:"project" validate:"required"`
|
||||
Region string `yaml:"region" validate:"required"`
|
||||
Instance string `yaml:"instance" validate:"required"`
|
||||
IPType sources.IPType `yaml:"ipType" validate:"required"`
|
||||
IPType sources.IPType `yaml:"ipType"`
|
||||
User string `yaml:"user" validate:"required"`
|
||||
Password string `yaml:"password" validate:"required"`
|
||||
Database string `yaml:"database" validate:"required"`
|
||||
|
||||
@@ -0,0 +1,307 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package bigqueryanalyzecontribution
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
bigqueryapi "cloud.google.com/go/bigquery"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/uuid"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
bigqueryrestapi "google.golang.org/api/bigquery/v2"
|
||||
"google.golang.org/api/iterator"
|
||||
)
|
||||
|
||||
const kind string = "bigquery-analyze-contribution"
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
BigQueryClient() *bigqueryapi.Client
|
||||
BigQueryRestService() *bigqueryrestapi.Service
|
||||
BigQueryClientCreator() bigqueryds.BigqueryClientCreator
|
||||
UseClientAuthorization() bool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &bigqueryds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{bigqueryds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
func (cfg Config) ToolConfigKind() string {
|
||||
return kind
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
inputDataParameter := tools.NewStringParameter("input_data",
|
||||
"The data that contain the test and control data to analyze. Can be a fully qualified BigQuery table ID or a SQL query.")
|
||||
contributionMetricParameter := tools.NewStringParameter("contribution_metric",
|
||||
`The name of the column that contains the metric to analyze.
|
||||
Provides the expression to use to calculate the metric you are analyzing.
|
||||
To calculate a summable metric, the expression must be in the form SUM(metric_column_name),
|
||||
where metric_column_name is a numeric data type.
|
||||
|
||||
To calculate a summable ratio metric, the expression must be in the form
|
||||
SUM(numerator_metric_column_name)/SUM(denominator_metric_column_name),
|
||||
where numerator_metric_column_name and denominator_metric_column_name are numeric data types.
|
||||
|
||||
To calculate a summable by category metric, the expression must be in the form
|
||||
SUM(metric_sum_column_name)/COUNT(DISTINCT categorical_column_name). The summed column must be a numeric data type.
|
||||
The categorical column must have type BOOL, DATE, DATETIME, TIME, TIMESTAMP, STRING, or INT64.`)
|
||||
isTestColParameter := tools.NewStringParameter("is_test_col",
|
||||
"The name of the column that identifies whether a row is in the test or control group.")
|
||||
dimensionIDColsParameter := tools.NewArrayParameterWithRequired("dimension_id_cols",
|
||||
"An array of column names that uniquely identify each dimension.", false, tools.NewStringParameter("dimension_id_col", "A dimension column name."))
|
||||
topKInsightsParameter := tools.NewIntParameterWithDefault("top_k_insights_by_apriori_support", 30,
|
||||
"The number of top insights to return, ranked by apriori support.")
|
||||
pruningMethodParameter := tools.NewStringParameterWithDefault("pruning_method", "PRUNE_REDUNDANT_INSIGHTS",
|
||||
"The method to use for pruning redundant insights. Can be 'NO_PRUNING' or 'PRUNE_REDUNDANT_INSIGHTS'.")
|
||||
|
||||
parameters := tools.Parameters{
|
||||
inputDataParameter,
|
||||
contributionMetricParameter,
|
||||
isTestColParameter,
|
||||
dimensionIDColsParameter,
|
||||
topKInsightsParameter,
|
||||
pruningMethodParameter,
|
||||
}
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
Description: cfg.Description,
|
||||
InputSchema: parameters.McpManifest(),
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Parameters: parameters,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ClientCreator: s.BigQueryClientCreator(),
|
||||
Client: s.BigQueryClient(),
|
||||
RestService: s.BigQueryRestService(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
|
||||
Client *bigqueryapi.Client
|
||||
RestService *bigqueryrestapi.Service
|
||||
ClientCreator bigqueryds.BigqueryClientCreator
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
// Invoke runs the contribution analysis.
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
paramsMap := params.AsMap()
|
||||
inputData, ok := paramsMap["input_data"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unable to cast input_data parameter %s", paramsMap["input_data"])
|
||||
}
|
||||
|
||||
modelID := fmt.Sprintf("contribution_analysis_model_%s", strings.ReplaceAll(uuid.New().String(), "-", ""))
|
||||
|
||||
var options []string
|
||||
options = append(options, "MODEL_TYPE = 'CONTRIBUTION_ANALYSIS'")
|
||||
options = append(options, fmt.Sprintf("CONTRIBUTION_METRIC = '%s'", paramsMap["contribution_metric"]))
|
||||
options = append(options, fmt.Sprintf("IS_TEST_COL = '%s'", paramsMap["is_test_col"]))
|
||||
|
||||
if val, ok := paramsMap["dimension_id_cols"]; ok {
|
||||
if cols, ok := val.([]any); ok {
|
||||
var strCols []string
|
||||
for _, c := range cols {
|
||||
strCols = append(strCols, fmt.Sprintf("'%s'", c))
|
||||
}
|
||||
options = append(options, fmt.Sprintf("DIMENSION_ID_COLS = [%s]", strings.Join(strCols, ", ")))
|
||||
} else {
|
||||
return nil, fmt.Errorf("unable to cast dimension_id_cols parameter %s", paramsMap["dimension_id_cols"])
|
||||
}
|
||||
}
|
||||
if val, ok := paramsMap["top_k_insights_by_apriori_support"]; ok {
|
||||
options = append(options, fmt.Sprintf("TOP_K_INSIGHTS_BY_APRIORI_SUPPORT = %v", val))
|
||||
}
|
||||
if val, ok := paramsMap["pruning_method"].(string); ok {
|
||||
upperVal := strings.ToUpper(val)
|
||||
if upperVal != "NO_PRUNING" && upperVal != "PRUNE_REDUNDANT_INSIGHTS" {
|
||||
return nil, fmt.Errorf("invalid pruning_method: %s", val)
|
||||
}
|
||||
options = append(options, fmt.Sprintf("PRUNING_METHOD = '%s'", upperVal))
|
||||
}
|
||||
|
||||
var inputDataSource string
|
||||
trimmedUpperInputData := strings.TrimSpace(strings.ToUpper(inputData))
|
||||
if strings.HasPrefix(trimmedUpperInputData, "SELECT") || strings.HasPrefix(trimmedUpperInputData, "WITH") {
|
||||
inputDataSource = fmt.Sprintf("(%s)", inputData)
|
||||
} else {
|
||||
inputDataSource = fmt.Sprintf("SELECT * FROM `%s`", inputData)
|
||||
}
|
||||
|
||||
// Use temp model to skip the clean up at the end. To use TEMP MODEL, queries have to be
|
||||
// in the same BigQuery session.
|
||||
createModelSQL := fmt.Sprintf("CREATE TEMP MODEL %s OPTIONS(%s) AS %s",
|
||||
modelID,
|
||||
strings.Join(options, ", "),
|
||||
inputDataSource,
|
||||
)
|
||||
|
||||
bqClient := t.Client
|
||||
var err error
|
||||
|
||||
// Initialize new client if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
bqClient, _, err = t.ClientCreator(tokenStr, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating client from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
createModelQuery := bqClient.Query(createModelSQL)
|
||||
createModelQuery.CreateSession = true
|
||||
createModelJob, err := createModelQuery.Run(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to start create model job: %w", err)
|
||||
}
|
||||
|
||||
status, err := createModelJob.Wait(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to wait for create model job: %w", err)
|
||||
}
|
||||
if err := status.Err(); err != nil {
|
||||
return nil, fmt.Errorf("create model job failed: %w", err)
|
||||
}
|
||||
|
||||
if status.Statistics == nil || status.Statistics.SessionInfo == nil || status.Statistics.SessionInfo.SessionID == "" {
|
||||
return nil, fmt.Errorf("failed to create a BigQuery session")
|
||||
}
|
||||
sessionID := status.Statistics.SessionInfo.SessionID
|
||||
getInsightsSQL := fmt.Sprintf("SELECT * FROM ML.GET_INSIGHTS(MODEL %s)", modelID)
|
||||
|
||||
getInsightsQuery := bqClient.Query(getInsightsSQL)
|
||||
getInsightsQuery.QueryConfig.ConnectionProperties = []*bigqueryapi.ConnectionProperty{
|
||||
{Key: "session_id", Value: sessionID},
|
||||
}
|
||||
|
||||
job, err := getInsightsQuery.Run(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to execute get insights query: %w", err)
|
||||
}
|
||||
it, err := job.Read(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to read query results: %w", err)
|
||||
}
|
||||
|
||||
var out []any
|
||||
for {
|
||||
var row map[string]bigqueryapi.Value
|
||||
err := it.Next(&row)
|
||||
if err == iterator.Done {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to iterate through query results: %w", err)
|
||||
}
|
||||
vMap := make(map[string]any)
|
||||
for key, value := range row {
|
||||
vMap[key] = value
|
||||
}
|
||||
out = append(out, vMap)
|
||||
}
|
||||
|
||||
if len(out) > 0 {
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// This handles the standard case for a SELECT query that successfully
|
||||
// executes but returns zero rows.
|
||||
return "The query returned 0 rows.", nil
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
return tools.ParseParams(t.Parameters, data, claims)
|
||||
}
|
||||
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization() bool {
|
||||
return t.UseClientOAuth
|
||||
}
|
||||
@@ -4,7 +4,7 @@
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
@@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cloudsqlwaitforoperation_test
|
||||
package bigqueryanalyzecontribution_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
@@ -21,10 +21,10 @@ import (
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
cloudsqlwaitforoperation "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlwaitforoperation"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigqueryanalyzecontribution"
|
||||
)
|
||||
|
||||
func TestParseFromYaml(t *testing.T) {
|
||||
func TestParseFromYamlBigQueryAnalyzeContribution(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
@@ -38,26 +38,18 @@ func TestParseFromYaml(t *testing.T) {
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
wait-for-thing:
|
||||
kind: cloud-sql-wait-for-operation
|
||||
source: some-source
|
||||
example_tool:
|
||||
kind: bigquery-analyze-contribution
|
||||
source: my-instance
|
||||
description: some description
|
||||
delay: 1s
|
||||
maxDelay: 5s
|
||||
multiplier: 1.5
|
||||
maxRetries: 5
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"wait-for-thing": cloudsqlwaitforoperation.Config{
|
||||
Name: "wait-for-thing",
|
||||
Kind: "cloud-sql-wait-for-operation",
|
||||
Source: "some-source",
|
||||
"example_tool": bigqueryanalyzecontribution.Config{
|
||||
Name: "example_tool",
|
||||
Kind: "bigquery-analyze-contribution",
|
||||
Source: "my-instance",
|
||||
Description: "some description",
|
||||
AuthRequired: []string{},
|
||||
Delay: "1s",
|
||||
MaxDelay: "5s",
|
||||
Multiplier: 1.5,
|
||||
MaxRetries: 5,
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -17,6 +17,8 @@ package bigquerylisttableids
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
bigqueryapi "cloud.google.com/go/bigquery"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
@@ -49,6 +51,8 @@ type compatibleSource interface {
|
||||
BigQueryClientCreator() bigqueryds.BigqueryClientCreator
|
||||
BigQueryProject() string
|
||||
UseClientAuthorization() bool
|
||||
IsDatasetAllowed(projectID, datasetID string) bool
|
||||
BigQueryAllowedDatasets() []string
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
@@ -84,8 +88,44 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
projectParameter := tools.NewStringParameterWithDefault(projectKey, s.BigQueryProject(), "The Google Cloud project ID containing the dataset.")
|
||||
datasetParameter := tools.NewStringParameter(datasetKey, "The dataset to list table ids.")
|
||||
defaultProjectID := s.BigQueryProject()
|
||||
projectDescription := "The Google Cloud project ID containing the dataset."
|
||||
datasetDescription := "The dataset to list table ids."
|
||||
var datasetParameter tools.Parameter
|
||||
allowedDatasets := s.BigQueryAllowedDatasets()
|
||||
if len(allowedDatasets) > 0 {
|
||||
if len(allowedDatasets) == 1 {
|
||||
parts := strings.Split(allowedDatasets[0], ".")
|
||||
defaultProjectID = parts[0]
|
||||
datasetID := parts[1]
|
||||
projectDescription += fmt.Sprintf(" Must be `%s`.", defaultProjectID)
|
||||
datasetDescription += fmt.Sprintf(" Must be `%s`.", datasetID)
|
||||
datasetParameter = tools.NewStringParameterWithDefault(datasetKey, datasetID, datasetDescription)
|
||||
} else {
|
||||
datasetIDsByProject := make(map[string][]string)
|
||||
for _, ds := range allowedDatasets {
|
||||
parts := strings.Split(ds, ".")
|
||||
project := parts[0]
|
||||
dataset := parts[1]
|
||||
datasetIDsByProject[project] = append(datasetIDsByProject[project], fmt.Sprintf("`%s`", dataset))
|
||||
}
|
||||
|
||||
var datasetDescriptions, projectIDList []string
|
||||
for project, datasets := range datasetIDsByProject {
|
||||
sort.Strings(datasets)
|
||||
projectIDList = append(projectIDList, fmt.Sprintf("`%s`", project))
|
||||
datasetList := strings.Join(datasets, ", ")
|
||||
datasetDescriptions = append(datasetDescriptions, fmt.Sprintf("%s from project `%s`", datasetList, project))
|
||||
}
|
||||
projectDescription += fmt.Sprintf(" Must be one of the following: %s.", strings.Join(projectIDList, ", "))
|
||||
datasetDescription += fmt.Sprintf(" Must be one of the allowed datasets: %s.", strings.Join(datasetDescriptions, "; "))
|
||||
datasetParameter = tools.NewStringParameter(datasetKey, datasetDescription)
|
||||
}
|
||||
} else {
|
||||
datasetParameter = tools.NewStringParameter(datasetKey, datasetDescription)
|
||||
}
|
||||
projectParameter := tools.NewStringParameterWithDefault(projectKey, defaultProjectID, projectDescription)
|
||||
|
||||
parameters := tools.Parameters{projectParameter, datasetParameter}
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
@@ -96,15 +136,16 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Parameters: parameters,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ClientCreator: s.BigQueryClientCreator(),
|
||||
Client: s.BigQueryClient(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Parameters: parameters,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
ClientCreator: s.BigQueryClientCreator(),
|
||||
Client: s.BigQueryClient(),
|
||||
IsDatasetAllowed: s.IsDatasetAllowed,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -119,11 +160,12 @@ type Tool struct {
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
|
||||
Client *bigqueryapi.Client
|
||||
ClientCreator bigqueryds.BigqueryClientCreator
|
||||
Statement string
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Client *bigqueryapi.Client
|
||||
ClientCreator bigqueryds.BigqueryClientCreator
|
||||
IsDatasetAllowed func(projectID, datasetID string) bool
|
||||
Statement string
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
@@ -138,6 +180,10 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", datasetKey)
|
||||
}
|
||||
|
||||
if !t.IsDatasetAllowed(projectId, datasetId) {
|
||||
return nil, fmt.Errorf("access denied to dataset '%s' because it is not in the configured list of allowed datasets for project '%s'", datasetId, projectId)
|
||||
}
|
||||
|
||||
bqClient := t.Client
|
||||
// Initialize new client if using user OAuth token
|
||||
if t.UseClientOAuth {
|
||||
@@ -161,7 +207,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to iterate through tables in dataset %s.%s: %w", bqClient.Project(), datasetId, err)
|
||||
return nil, fmt.Errorf("failed to iterate through tables in dataset %s.%s: %w", projectId, datasetId, err)
|
||||
}
|
||||
|
||||
// Remove leading and trailing quotes
|
||||
|
||||
@@ -0,0 +1,157 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package clickhouse
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
)
|
||||
|
||||
type compatibleSource interface {
|
||||
ClickHousePool() *sql.DB
|
||||
}
|
||||
|
||||
var compatibleSources = []string{"clickhouse"}
|
||||
|
||||
const listDatabasesKind string = "clickhouse-list-databases"
|
||||
|
||||
func init() {
|
||||
if !tools.Register(listDatabasesKind, newListDatabasesConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", listDatabasesKind))
|
||||
}
|
||||
}
|
||||
|
||||
func newListDatabasesConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
}
|
||||
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
func (cfg Config) ToolConfigKind() string {
|
||||
return listDatabasesKind
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", listDatabasesKind, compatibleSources)
|
||||
}
|
||||
|
||||
allParameters, paramManifest, paramMcpManifest, _ := tools.ProcessParameters(nil, cfg.Parameters)
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
Description: cfg.Description,
|
||||
InputSchema: paramMcpManifest,
|
||||
}
|
||||
|
||||
t := Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: listDatabasesKind,
|
||||
Parameters: cfg.Parameters,
|
||||
AllParams: allParameters,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Pool: s.ClickHousePool(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
AllParams tools.Parameters `yaml:"allParams"`
|
||||
|
||||
Pool *sql.DB
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, token tools.AccessToken) (any, error) {
|
||||
// Query to list all databases
|
||||
query := "SHOW DATABASES"
|
||||
|
||||
results, err := t.Pool.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
defer results.Close()
|
||||
|
||||
var databases []map[string]any
|
||||
for results.Next() {
|
||||
var dbName string
|
||||
err := results.Scan(&dbName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse row: %w", err)
|
||||
}
|
||||
databases = append(databases, map[string]any{
|
||||
"name": dbName,
|
||||
})
|
||||
}
|
||||
|
||||
if err := results.Err(); err != nil {
|
||||
return nil, fmt.Errorf("errors encountered by results.Scan: %w", err)
|
||||
}
|
||||
|
||||
return databases, nil
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
return tools.ParseParams(t.AllParams, data, claims)
|
||||
}
|
||||
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization() bool {
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,109 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package clickhouse
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
)
|
||||
|
||||
func TestListDatabasesConfigToolConfigKind(t *testing.T) {
|
||||
cfg := Config{}
|
||||
if cfg.ToolConfigKind() != listDatabasesKind {
|
||||
t.Errorf("expected %q, got %q", listDatabasesKind, cfg.ToolConfigKind())
|
||||
}
|
||||
}
|
||||
|
||||
func TestListDatabasesConfigInitializeMissingSource(t *testing.T) {
|
||||
cfg := Config{
|
||||
Name: "test-list-databases",
|
||||
Kind: listDatabasesKind,
|
||||
Source: "missing-source",
|
||||
Description: "Test list databases tool",
|
||||
}
|
||||
|
||||
srcs := map[string]sources.Source{}
|
||||
_, err := cfg.Initialize(srcs)
|
||||
if err == nil {
|
||||
t.Error("expected error for missing source")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFromYamlClickHouseListDatabases(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: clickhouse-list-databases
|
||||
source: my-instance
|
||||
description: some description
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": Config{
|
||||
Name: "example_tool",
|
||||
Kind: "clickhouse-list-databases",
|
||||
Source: "my-instance",
|
||||
Description: "some description",
|
||||
AuthRequired: []string{},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestListDatabasesToolParseParams(t *testing.T) {
|
||||
tool := Tool{
|
||||
Parameters: tools.Parameters{},
|
||||
}
|
||||
|
||||
params, err := tool.ParseParams(map[string]any{}, map[string]map[string]any{})
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(params) != 0 {
|
||||
t.Errorf("expected 0 parameters, got %d", len(params))
|
||||
}
|
||||
}
|
||||
177
internal/tools/cloudmonitoring/cloudmonitoring.go
Normal file
177
internal/tools/cloudmonitoring/cloudmonitoring.go
Normal file
@@ -0,0 +1,177 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cloudmonitoring
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
cloudmonitoringsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudmonitoring"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
)
|
||||
|
||||
const kind string = "cloud-monitoring-query-prometheus"
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
func (cfg Config) ToolConfigKind() string {
|
||||
return kind
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(*cloudmonitoringsrc.Source)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloudmonitoring`", kind)
|
||||
}
|
||||
|
||||
// Define the parameters internally instead of from the config file.
|
||||
allParameters := tools.Parameters{
|
||||
tools.NewStringParameterWithRequired("projectId", "The Id of the Google Cloud project.", true),
|
||||
tools.NewStringParameterWithRequired("query", "The promql query to execute.", true),
|
||||
}
|
||||
|
||||
return Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Description: cfg.Description,
|
||||
AllParams: allParameters,
|
||||
BaseURL: s.BaseURL,
|
||||
UserAgent: s.UserAgent,
|
||||
Client: s.Client,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: allParameters.Manifest()},
|
||||
mcpManifest: tools.McpManifest{Name: cfg.Name, Description: cfg.Description, InputSchema: allParameters.McpManifest()},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Description string `yaml:"description"`
|
||||
AllParams tools.Parameters `yaml:"allParams"`
|
||||
BaseURL string `yaml:"baseURL"`
|
||||
UserAgent string
|
||||
Client *http.Client
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
paramsMap := params.AsMap()
|
||||
projectID, ok := paramsMap["projectId"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("projectId parameter not found or not a string")
|
||||
}
|
||||
query, ok := paramsMap["query"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("query parameter not found or not a string")
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/v1/projects/%s/location/global/prometheus/api/v1/query", t.BaseURL, projectID)
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
q := req.URL.Query()
|
||||
q.Add("query", query)
|
||||
req.URL.RawQuery = q.Encode()
|
||||
|
||||
req.Header.Set("User-Agent", t.UserAgent)
|
||||
|
||||
resp, err := t.Client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("request failed: %s, body: %s", resp.Status, string(body))
|
||||
}
|
||||
|
||||
if len(body) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal json: %w, body: %s", err, string(body))
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
return tools.ParseParams(t.AllParams, data, claims)
|
||||
}
|
||||
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization() bool {
|
||||
return false
|
||||
}
|
||||
155
internal/tools/cloudmonitoring/cloudmonitoring_test.go
Normal file
155
internal/tools/cloudmonitoring/cloudmonitoring_test.go
Normal file
@@ -0,0 +1,155 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cloudmonitoring_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/cloudmonitoring"
|
||||
)
|
||||
|
||||
func TestParseFromYamlCloudMonitoring(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: cloud-monitoring-query-prometheus
|
||||
source: my-instance
|
||||
description: some description
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": cloudmonitoring.Config{
|
||||
Name: "example_tool",
|
||||
Kind: "cloud-monitoring-query-prometheus",
|
||||
Source: "my-instance",
|
||||
Description: "some description",
|
||||
AuthRequired: []string{},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "advanced example",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: cloud-monitoring-query-prometheus
|
||||
source: my-instance
|
||||
description: some description
|
||||
authRequired:
|
||||
- my-google-auth-service
|
||||
- other-auth-service
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": cloudmonitoring.Config{
|
||||
Name: "example_tool",
|
||||
Kind: "cloud-monitoring-query-prometheus",
|
||||
Source: "my-instance",
|
||||
Description: "some description",
|
||||
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailParseFromYamlCloudMonitoring(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
err string
|
||||
}{
|
||||
{
|
||||
desc: "Invalid kind",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: invalid-kind
|
||||
source: my-instance
|
||||
description: some description
|
||||
`,
|
||||
err: `unknown tool kind: "invalid-kind"`,
|
||||
},
|
||||
{
|
||||
desc: "missing source",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: cloud-monitoring-query-prometheus
|
||||
description: some description
|
||||
`,
|
||||
err: `Key: 'Config.Source' Error:Field validation for 'Source' failed on the 'required' tag`,
|
||||
},
|
||||
{
|
||||
desc: "missing description",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: cloud-monitoring-query-prometheus
|
||||
source: my-instance
|
||||
`,
|
||||
err: `Key: 'Config.Description' Error:Field validation for 'Description' failed on the 'required' tag`,
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
errStr := err.Error()
|
||||
if !strings.Contains(errStr, tc.err) {
|
||||
t.Fatalf("unexpected error string: got %q, want substring %q", errStr, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,440 +0,0 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cloudsqlwaitforoperation
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"text/template"
|
||||
"time"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
httpsrc "github.com/googleapis/genai-toolbox/internal/sources/http"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"golang.org/x/oauth2/google"
|
||||
)
|
||||
|
||||
const kind string = "cloud-sql-wait-for-operation"
|
||||
|
||||
var cloudSQLConnectionMessageTemplate = `Your Cloud SQL resource is ready.
|
||||
|
||||
To connect, please configure your environment. The method depends on how you are running the toolbox:
|
||||
|
||||
**If running locally via stdio:**
|
||||
Update the MCP server configuration with the following environment variables:
|
||||
` + "```json" + `
|
||||
{
|
||||
"mcpServers": {
|
||||
"cloud-sql-{{.DBType}}": {
|
||||
"command": "./PATH/TO/toolbox",
|
||||
"args": ["--prebuilt","cloud-sql-{{.DBType}}","--stdio"],
|
||||
"env": {
|
||||
"CLOUD_SQL_{{.DBTypeUpper}}_PROJECT": "{{.Project}}",
|
||||
"CLOUD_SQL_{{.DBTypeUpper}}_REGION": "{{.Region}}",
|
||||
"CLOUD_SQL_{{.DBTypeUpper}}_INSTANCE": "{{.Instance}}",
|
||||
"CLOUD_SQL_{{.DBTypeUpper}}_DATABASE": "{{.Database}}",
|
||||
"CLOUD_SQL_{{.DBTypeUpper}}_USER": "<your-user>",
|
||||
"CLOUD_SQL_{{.DBTypeUpper}}_PASSWORD": "<your-password>"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
` + "```" + `
|
||||
|
||||
**If running remotely:**
|
||||
For remote deployments, you will need to set the following environment variables in your deployment configuration:
|
||||
` + "```" + `
|
||||
CLOUD_SQL_{{.DBTypeUpper}}_PROJECT={{.Project}}
|
||||
CLOUD_SQL_{{.DBTypeUpper}}_REGION={{.Region}}
|
||||
CLOUD_SQL_{{.DBTypeUpper}}_INSTANCE={{.Instance}}
|
||||
CLOUD_SQL_{{.DBTypeUpper}}_DATABASE={{.Database}}
|
||||
CLOUD_SQL_{{.DBTypeUpper}}_USER=<your-user>
|
||||
CLOUD_SQL_{{.DBTypeUpper}}_PASSWORD=<your-password>
|
||||
` + "```" + `
|
||||
|
||||
Please refer to the official documentation for guidance on deploying the toolbox:
|
||||
- Deploying the Toolbox: https://googleapis.github.io/genai-toolbox/how-to/deploy_toolbox/
|
||||
- Deploying on GKE: https://googleapis.github.io/genai-toolbox/how-to/deploy_gke/
|
||||
`
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
// Config defines the configuration for the wait-for-operation tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
BaseURL string `yaml:"baseURL"`
|
||||
|
||||
// Polling configuration
|
||||
Delay string `yaml:"delay"`
|
||||
MaxDelay string `yaml:"maxDelay"`
|
||||
Multiplier float64 `yaml:"multiplier"`
|
||||
MaxRetries int `yaml:"maxRetries"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
// ToolConfigKind returns the kind of the tool.
|
||||
func (cfg Config) ToolConfigKind() string {
|
||||
return kind
|
||||
}
|
||||
|
||||
// Initialize initializes the tool from the configuration.
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(*httpsrc.Source)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `http`", kind)
|
||||
}
|
||||
|
||||
if s.BaseURL != "https://sqladmin.googleapis.com" && !strings.HasPrefix(s.BaseURL, "http://127.0.0.1") {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: baseUrl must be `https://sqladmin.googleapis.com`", kind)
|
||||
}
|
||||
|
||||
allParameters := tools.Parameters{
|
||||
tools.NewStringParameter("project", "The project ID"),
|
||||
tools.NewStringParameter("operation", "The operation ID"),
|
||||
}
|
||||
paramManifest := allParameters.Manifest()
|
||||
|
||||
inputSchema := allParameters.McpManifest()
|
||||
inputSchema.Required = []string{"project", "operation"}
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
Description: cfg.Description,
|
||||
InputSchema: inputSchema,
|
||||
}
|
||||
|
||||
baseURL := cfg.BaseURL
|
||||
if baseURL == "" {
|
||||
baseURL = "https://sqladmin.googleapis.com"
|
||||
}
|
||||
|
||||
var delay time.Duration
|
||||
if cfg.Delay == "" {
|
||||
delay = 3 * time.Second
|
||||
} else {
|
||||
var err error
|
||||
delay, err = time.ParseDuration(cfg.Delay)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid value for delay: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
var maxDelay time.Duration
|
||||
if cfg.MaxDelay == "" {
|
||||
maxDelay = 4 * time.Minute
|
||||
} else {
|
||||
var err error
|
||||
maxDelay, err = time.ParseDuration(cfg.MaxDelay)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid value for maxDelay: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
multiplier := cfg.Multiplier
|
||||
if multiplier == 0 {
|
||||
multiplier = 2.0
|
||||
}
|
||||
|
||||
maxRetries := cfg.MaxRetries
|
||||
if maxRetries == 0 {
|
||||
maxRetries = 10
|
||||
}
|
||||
|
||||
return Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
BaseURL: baseURL,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Client: s.Client,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Delay: delay,
|
||||
MaxDelay: maxDelay,
|
||||
Multiplier: multiplier,
|
||||
MaxRetries: maxRetries,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Tool represents the wait-for-operation tool.
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Description string `yaml:"description"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
|
||||
BaseURL string `yaml:"baseURL"`
|
||||
AllParams tools.Parameters `yaml:"allParams"`
|
||||
|
||||
// Polling configuration
|
||||
Delay time.Duration
|
||||
MaxDelay time.Duration
|
||||
Multiplier float64
|
||||
MaxRetries int
|
||||
|
||||
Client *http.Client
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("missing 'project' parameter")
|
||||
}
|
||||
operationID, ok := paramsMap["operation"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("missing 'operation' parameter")
|
||||
}
|
||||
|
||||
urlString := fmt.Sprintf("%s/v1/projects/%s/operations/%s", t.BaseURL, project, operationID)
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
delay := t.Delay
|
||||
maxDelay := t.MaxDelay
|
||||
multiplier := t.Multiplier
|
||||
maxRetries := t.MaxRetries
|
||||
retries := 0
|
||||
|
||||
for retries < maxRetries {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, fmt.Errorf("timed out waiting for operation: %w", ctx.Err())
|
||||
default:
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, urlString, nil)
|
||||
|
||||
tokenSource, err := google.DefaultTokenSource(ctx, "https://www.googleapis.com/auth/sqlservice.admin")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating token source: %w", err)
|
||||
}
|
||||
token, err := tokenSource.Token()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error retrieving token: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token.AccessToken)
|
||||
|
||||
resp, err := t.Client.Do(req)
|
||||
if err != nil {
|
||||
fmt.Printf("error making HTTP request during polling: %s, retrying in %v\n", err, delay)
|
||||
} else {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading response body during polling: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("unexpected status code during polling: %d, response body: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var data map[string]any
|
||||
if err := json.Unmarshal(body, &data); err == nil {
|
||||
if val, ok := data["status"]; ok {
|
||||
if fmt.Sprintf("%v", val) == "DONE" {
|
||||
if _, ok := data["error"]; ok {
|
||||
return nil, fmt.Errorf("operation finished with error: %s", string(body))
|
||||
}
|
||||
|
||||
if msg, ok := t.generateCloudSQLConnectionMessage(data); ok {
|
||||
return msg, nil
|
||||
}
|
||||
return string(body), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
fmt.Printf("Operation not complete, retrying in %v\n", delay)
|
||||
}
|
||||
|
||||
time.Sleep(delay)
|
||||
delay = time.Duration(float64(delay) * multiplier)
|
||||
if delay > maxDelay {
|
||||
delay = maxDelay
|
||||
}
|
||||
retries++
|
||||
}
|
||||
return nil, fmt.Errorf("exceeded max retries waiting for operation")
|
||||
}
|
||||
|
||||
// ParseParams parses the parameters for the tool.
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
return tools.ParseParams(t.AllParams, data, claims)
|
||||
}
|
||||
|
||||
// Manifest returns the tool's manifest.
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
// McpManifest returns the tool's MCP manifest.
|
||||
func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
// Authorized checks if the tool is authorized.
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (t Tool) generateCloudSQLConnectionMessage(opResponse map[string]any) (string, bool) {
|
||||
operationType, ok := opResponse["operationType"].(string)
|
||||
if !ok || operationType != "CREATE_DATABASE" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
targetLink, ok := opResponse["targetLink"].(string)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
|
||||
r := regexp.MustCompile(`/projects/([^/]+)/instances/([^/]+)/databases/([^/]+)`)
|
||||
matches := r.FindStringSubmatch(targetLink)
|
||||
if len(matches) < 4 {
|
||||
return "", false
|
||||
}
|
||||
project := matches[1]
|
||||
instance := matches[2]
|
||||
database := matches[3]
|
||||
|
||||
instanceData, err := t.fetchInstanceData(context.Background(), project, instance)
|
||||
if err != nil {
|
||||
fmt.Printf("error fetching instance data: %v\n", err)
|
||||
return "", false
|
||||
}
|
||||
|
||||
region, ok := instanceData["region"].(string)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
|
||||
databaseVersion, ok := instanceData["databaseVersion"].(string)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
|
||||
var dbType string
|
||||
if strings.Contains(databaseVersion, "POSTGRES") {
|
||||
dbType = "postgres"
|
||||
} else if strings.Contains(databaseVersion, "MYSQL") {
|
||||
dbType = "mysql"
|
||||
} else if strings.Contains(databaseVersion, "SQLSERVER") {
|
||||
dbType = "mssql"
|
||||
} else {
|
||||
return "", false
|
||||
}
|
||||
|
||||
tmpl, err := template.New("cloud-sql-connection").Parse(cloudSQLConnectionMessageTemplate)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("template parsing error: %v", err), false
|
||||
}
|
||||
|
||||
data := struct {
|
||||
Project string
|
||||
Region string
|
||||
Instance string
|
||||
DBType string
|
||||
DBTypeUpper string
|
||||
Database string
|
||||
}{
|
||||
Project: project,
|
||||
Region: region,
|
||||
Instance: instance,
|
||||
DBType: dbType,
|
||||
DBTypeUpper: strings.ToUpper(dbType),
|
||||
Database: database,
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
if err := tmpl.Execute(&b, data); err != nil {
|
||||
return fmt.Sprintf("template execution error: %v", err), false
|
||||
}
|
||||
|
||||
return b.String(), true
|
||||
}
|
||||
|
||||
func (t Tool) fetchInstanceData(ctx context.Context, project, instance string) (map[string]any, error) {
|
||||
urlString := fmt.Sprintf("%s/v1/projects/%s/instances/%s", t.BaseURL, project, instance)
|
||||
req, _ := http.NewRequest(http.MethodGet, urlString, nil)
|
||||
|
||||
tokenSource, err := google.DefaultTokenSource(ctx, "https://www.googleapis.com/auth/sqlservice.admin")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating token source: %w", err)
|
||||
}
|
||||
token, err := tokenSource.Token()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error retrieving token: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token.AccessToken)
|
||||
|
||||
resp, err := t.Client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("unexpected status code fetching instance data: %d, response body: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var data map[string]any
|
||||
if err := json.Unmarshal(body, &data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
360
internal/tools/mysql/mysqllisttables/mysqllisttables.go
Normal file
360
internal/tools/mysql/mysqllisttables/mysqllisttables.go
Normal file
@@ -0,0 +1,360 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mysqllisttables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmysql"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/mysql"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon"
|
||||
)
|
||||
|
||||
const kind string = "mysql-list-tables"
|
||||
|
||||
const listTablesStatement = `
|
||||
SELECT
|
||||
T.TABLE_SCHEMA AS schema_name,
|
||||
T.TABLE_NAME AS object_name,
|
||||
CASE
|
||||
WHEN @output_format = 'simple' THEN
|
||||
JSON_OBJECT('name', T.TABLE_NAME)
|
||||
ELSE
|
||||
CONVERT(
|
||||
JSON_OBJECT(
|
||||
'schema_name', T.TABLE_SCHEMA,
|
||||
'object_name', T.TABLE_NAME,
|
||||
'object_type', 'TABLE',
|
||||
'owner', (
|
||||
SELECT
|
||||
IFNULL(U.GRANTEE, 'N/A')
|
||||
FROM
|
||||
INFORMATION_SCHEMA.SCHEMA_PRIVILEGES U
|
||||
WHERE
|
||||
U.TABLE_SCHEMA = T.TABLE_SCHEMA
|
||||
LIMIT 1
|
||||
),
|
||||
'comment', IFNULL(T.TABLE_COMMENT, ''),
|
||||
'columns', (
|
||||
SELECT
|
||||
IFNULL(
|
||||
JSON_ARRAYAGG(
|
||||
JSON_OBJECT(
|
||||
'column_name', C.COLUMN_NAME,
|
||||
'data_type', C.COLUMN_TYPE,
|
||||
'ordinal_position', C.ORDINAL_POSITION,
|
||||
'is_not_nullable', IF(C.IS_NULLABLE = 'NO', TRUE, FALSE),
|
||||
'column_default', C.COLUMN_DEFAULT,
|
||||
'column_comment', IFNULL(C.COLUMN_COMMENT, '')
|
||||
)
|
||||
),
|
||||
JSON_ARRAY()
|
||||
)
|
||||
FROM
|
||||
INFORMATION_SCHEMA.COLUMNS C
|
||||
WHERE
|
||||
C.TABLE_SCHEMA = T.TABLE_SCHEMA AND C.TABLE_NAME = T.TABLE_NAME
|
||||
ORDER BY C.ORDINAL_POSITION
|
||||
),
|
||||
'constraints', (
|
||||
SELECT
|
||||
IFNULL(
|
||||
JSON_ARRAYAGG(
|
||||
JSON_OBJECT(
|
||||
'constraint_name', TC.CONSTRAINT_NAME,
|
||||
'constraint_type',
|
||||
CASE TC.CONSTRAINT_TYPE
|
||||
WHEN 'PRIMARY KEY' THEN 'PRIMARY KEY'
|
||||
WHEN 'FOREIGN KEY' THEN 'FOREIGN KEY'
|
||||
WHEN 'UNIQUE' THEN 'UNIQUE'
|
||||
ELSE TC.CONSTRAINT_TYPE
|
||||
END,
|
||||
'constraint_definition', '',
|
||||
'constraint_columns', (
|
||||
SELECT
|
||||
IFNULL(JSON_ARRAYAGG(KCU.COLUMN_NAME), JSON_ARRAY())
|
||||
FROM
|
||||
INFORMATION_SCHEMA.KEY_COLUMN_USAGE KCU
|
||||
WHERE
|
||||
KCU.CONSTRAINT_SCHEMA = TC.CONSTRAINT_SCHEMA
|
||||
AND KCU.CONSTRAINT_NAME = TC.CONSTRAINT_NAME
|
||||
AND KCU.TABLE_NAME = TC.TABLE_NAME
|
||||
ORDER BY KCU.ORDINAL_POSITION
|
||||
),
|
||||
'foreign_key_referenced_table', IF(TC.CONSTRAINT_TYPE = 'FOREIGN KEY', RC.REFERENCED_TABLE_NAME, NULL),
|
||||
'foreign_key_referenced_columns', IF(TC.CONSTRAINT_TYPE = 'FOREIGN KEY',
|
||||
(SELECT IFNULL(JSON_ARRAYAGG(FKCU.REFERENCED_COLUMN_NAME), JSON_ARRAY())
|
||||
FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE FKCU
|
||||
WHERE FKCU.CONSTRAINT_SCHEMA = TC.CONSTRAINT_SCHEMA
|
||||
AND FKCU.CONSTRAINT_NAME = TC.CONSTRAINT_NAME
|
||||
AND FKCU.TABLE_NAME = TC.TABLE_NAME
|
||||
AND FKCU.REFERENCED_TABLE_NAME IS NOT NULL
|
||||
ORDER BY FKCU.ORDINAL_POSITION),
|
||||
NULL
|
||||
)
|
||||
)
|
||||
),
|
||||
JSON_ARRAY()
|
||||
)
|
||||
FROM
|
||||
INFORMATION_SCHEMA.TABLE_CONSTRAINTS TC
|
||||
LEFT JOIN
|
||||
INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS RC
|
||||
ON TC.CONSTRAINT_SCHEMA = RC.CONSTRAINT_SCHEMA
|
||||
AND TC.CONSTRAINT_NAME = RC.CONSTRAINT_NAME
|
||||
AND TC.TABLE_NAME = RC.TABLE_NAME
|
||||
WHERE
|
||||
TC.TABLE_SCHEMA = T.TABLE_SCHEMA AND TC.TABLE_NAME = T.TABLE_NAME
|
||||
),
|
||||
'indexes', (
|
||||
SELECT
|
||||
IFNULL(
|
||||
JSON_ARRAYAGG(
|
||||
JSON_OBJECT(
|
||||
'index_name', IndexData.INDEX_NAME,
|
||||
'is_unique', IF(IndexData.NON_UNIQUE = 0, TRUE, FALSE),
|
||||
'is_primary', IF(IndexData.INDEX_NAME = 'PRIMARY', TRUE, FALSE),
|
||||
'index_columns', IFNULL(IndexData.INDEX_COLUMNS_ARRAY, JSON_ARRAY())
|
||||
)
|
||||
),
|
||||
JSON_ARRAY()
|
||||
)
|
||||
FROM (
|
||||
SELECT
|
||||
S.TABLE_SCHEMA,
|
||||
S.TABLE_NAME,
|
||||
S.INDEX_NAME,
|
||||
MIN(S.NON_UNIQUE) AS NON_UNIQUE,
|
||||
JSON_ARRAYAGG(S.COLUMN_NAME) AS INDEX_COLUMNS_ARRAY
|
||||
FROM
|
||||
INFORMATION_SCHEMA.STATISTICS S
|
||||
WHERE
|
||||
S.TABLE_SCHEMA = T.TABLE_SCHEMA AND S.TABLE_NAME = T.TABLE_NAME
|
||||
GROUP BY
|
||||
S.TABLE_SCHEMA, S.TABLE_NAME, S.INDEX_NAME
|
||||
) AS IndexData
|
||||
ORDER BY IndexData.INDEX_NAME
|
||||
),
|
||||
'triggers', (
|
||||
SELECT
|
||||
IFNULL(
|
||||
JSON_ARRAYAGG(
|
||||
JSON_OBJECT(
|
||||
'trigger_name', TR.TRIGGER_NAME,
|
||||
'trigger_definition', TR.ACTION_STATEMENT
|
||||
)
|
||||
),
|
||||
JSON_ARRAY()
|
||||
)
|
||||
FROM
|
||||
INFORMATION_SCHEMA.TRIGGERS TR
|
||||
WHERE
|
||||
TR.EVENT_OBJECT_SCHEMA = T.TABLE_SCHEMA AND TR.EVENT_OBJECT_TABLE = T.TABLE_NAME
|
||||
ORDER BY TR.TRIGGER_NAME
|
||||
)
|
||||
)
|
||||
USING utf8mb4)
|
||||
END AS object_details
|
||||
FROM
|
||||
INFORMATION_SCHEMA.TABLES T
|
||||
CROSS JOIN (SELECT @table_names := ?, @output_format := ?) AS variables
|
||||
WHERE
|
||||
T.TABLE_SCHEMA NOT IN ('mysql', 'information_schema', 'performance_schema', 'sys')
|
||||
AND (NULLIF(TRIM(@table_names), '') IS NULL OR FIND_IN_SET(T.TABLE_NAME, @table_names))
|
||||
AND T.TABLE_TYPE = 'BASE TABLE'
|
||||
ORDER BY
|
||||
T.TABLE_SCHEMA, T.TABLE_NAME;
|
||||
`
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
MySQLPool() *sql.DB
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &cloudsqlmysql.Source{}
|
||||
var _ compatibleSource = &mysql.Source{}
|
||||
|
||||
var compatibleSources = [...]string{cloudsqlmysql.SourceKind, mysql.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
func (cfg Config) ToolConfigKind() string {
|
||||
return kind
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
allParameters := tools.Parameters{
|
||||
tools.NewStringParameter("table_names", "Optional: A comma-separated list of table names. If empty, details for all tables will be listed."),
|
||||
tools.NewStringParameterWithDefault("output_format", "detailed", "Optional: Use 'simple' for names only or 'detailed' for full info."),
|
||||
}
|
||||
paramManifest := allParameters.Manifest()
|
||||
inputSchema := allParameters.McpManifest()
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
Description: cfg.Description,
|
||||
InputSchema: inputSchema,
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
AllParams: allParameters,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Pool: s.MySQLPool(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
AllParams tools.Parameters `yaml:"allParams"`
|
||||
|
||||
Pool *sql.DB
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
tableNames, ok := paramsMap["table_names"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", tableNames)
|
||||
}
|
||||
outputFormat, _ := paramsMap["output_format"].(string)
|
||||
if outputFormat != "simple" && outputFormat != "detailed" {
|
||||
return nil, fmt.Errorf("invalid value for output_format: must be 'simple' or 'detailed', but got %q", outputFormat)
|
||||
}
|
||||
|
||||
results, err := t.Pool.QueryContext(ctx, listTablesStatement, tableNames, outputFormat)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
|
||||
cols, err := results.Columns()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to retrieve rows column name: %w", err)
|
||||
}
|
||||
|
||||
// create an array of values for each column, which can be re-used to scan each row
|
||||
rawValues := make([]any, len(cols))
|
||||
values := make([]any, len(cols))
|
||||
for i := range rawValues {
|
||||
values[i] = &rawValues[i]
|
||||
}
|
||||
defer results.Close()
|
||||
|
||||
colTypes, err := results.ColumnTypes()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to get column types: %w", err)
|
||||
}
|
||||
|
||||
var out []any
|
||||
for results.Next() {
|
||||
err := results.Scan(values...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse row: %w", err)
|
||||
}
|
||||
vMap := make(map[string]any)
|
||||
for i, name := range cols {
|
||||
val := rawValues[i]
|
||||
if val == nil {
|
||||
vMap[name] = nil
|
||||
continue
|
||||
}
|
||||
|
||||
vMap[name], err = mysqlcommon.ConvertToType(colTypes[i], val)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("errors encountered when converting values: %w", err)
|
||||
}
|
||||
}
|
||||
out = append(out, vMap)
|
||||
}
|
||||
|
||||
if err := results.Err(); err != nil {
|
||||
return nil, fmt.Errorf("errors encountered during row iteration: %w", err)
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
return tools.ParseParams(t.AllParams, data, claims)
|
||||
}
|
||||
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization() bool {
|
||||
return false
|
||||
}
|
||||
75
internal/tools/mysql/mysqllisttables/mysqllisttables_test.go
Normal file
75
internal/tools/mysql/mysqllisttables/mysqllisttables_test.go
Normal file
@@ -0,0 +1,75 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mysqllisttables_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
mysqllisttables "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqllisttables"
|
||||
)
|
||||
|
||||
func TestParseFromYamlMySQLListTables(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: mysql-list-tables
|
||||
source: my-mysql-instance
|
||||
description: some description
|
||||
authRequired:
|
||||
- my-google-auth-service
|
||||
- other-auth-service
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": mysqllisttables.Config{
|
||||
Name: "example_tool",
|
||||
Kind: "mysql-list-tables",
|
||||
Source: "my-mysql-instance",
|
||||
Description: "some description",
|
||||
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
261
internal/tools/postgres/postgreslisttables/postgreslisttables.go
Normal file
261
internal/tools/postgres/postgreslisttables/postgreslisttables.go
Normal file
@@ -0,0 +1,261 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgreslisttables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/alloydbpg"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/postgres"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
const kind string = "postgres-list-tables"
|
||||
|
||||
const listTablesStatement = `
|
||||
WITH desired_relkinds AS (
|
||||
SELECT ARRAY['r', 'p']::char[] AS kinds -- Always consider both 'TABLE' and 'PARTITIONED TABLE'
|
||||
),
|
||||
table_info AS (
|
||||
SELECT
|
||||
t.oid AS table_oid,
|
||||
ns.nspname AS schema_name,
|
||||
t.relname AS table_name,
|
||||
pg_get_userbyid(t.relowner) AS table_owner,
|
||||
obj_description(t.oid, 'pg_class') AS table_comment,
|
||||
t.relkind AS object_kind
|
||||
FROM
|
||||
pg_class t
|
||||
JOIN
|
||||
pg_namespace ns ON ns.oid = t.relnamespace
|
||||
CROSS JOIN desired_relkinds dk
|
||||
WHERE
|
||||
t.relkind = ANY(dk.kinds) -- Filter by selected table relkinds ('r', 'p')
|
||||
AND (NULLIF(TRIM($1), '') IS NULL OR t.relname = ANY(string_to_array($1,','))) -- $1 is object_names
|
||||
AND ns.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
|
||||
AND ns.nspname NOT LIKE 'pg_temp_%' AND ns.nspname NOT LIKE 'pg_toast_temp_%'
|
||||
),
|
||||
columns_info AS (
|
||||
SELECT
|
||||
att.attrelid AS table_oid, att.attname AS column_name, format_type(att.atttypid, att.atttypmod) AS data_type,
|
||||
att.attnum AS column_ordinal_position, att.attnotnull AS is_not_nullable,
|
||||
pg_get_expr(ad.adbin, ad.adrelid) AS column_default, col_description(att.attrelid, att.attnum) AS column_comment
|
||||
FROM pg_attribute att LEFT JOIN pg_attrdef ad ON att.attrelid = ad.adrelid AND att.attnum = ad.adnum
|
||||
JOIN table_info ti ON att.attrelid = ti.table_oid WHERE att.attnum > 0 AND NOT att.attisdropped
|
||||
),
|
||||
constraints_info AS (
|
||||
SELECT
|
||||
con.conrelid AS table_oid, con.conname AS constraint_name, pg_get_constraintdef(con.oid) AS constraint_definition,
|
||||
CASE con.contype WHEN 'p' THEN 'PRIMARY KEY' WHEN 'f' THEN 'FOREIGN KEY' WHEN 'u' THEN 'UNIQUE' WHEN 'c' THEN 'CHECK' ELSE con.contype::text END AS constraint_type,
|
||||
(SELECT array_agg(att.attname ORDER BY u.attposition) FROM unnest(con.conkey) WITH ORDINALITY AS u(attnum, attposition) JOIN pg_attribute att ON att.attrelid = con.conrelid AND att.attnum = u.attnum) AS constraint_columns,
|
||||
NULLIF(con.confrelid, 0)::regclass AS foreign_key_referenced_table,
|
||||
(SELECT array_agg(att.attname ORDER BY u.attposition) FROM unnest(con.confkey) WITH ORDINALITY AS u(attnum, attposition) JOIN pg_attribute att ON att.attrelid = con.confrelid AND att.attnum = u.attnum WHERE con.contype = 'f') AS foreign_key_referenced_columns
|
||||
FROM pg_constraint con JOIN table_info ti ON con.conrelid = ti.table_oid
|
||||
),
|
||||
indexes_info AS (
|
||||
SELECT
|
||||
idx.indrelid AS table_oid, ic.relname AS index_name, pg_get_indexdef(idx.indexrelid) AS index_definition,
|
||||
idx.indisunique AS is_unique, idx.indisprimary AS is_primary, am.amname AS index_method,
|
||||
(SELECT array_agg(att.attname ORDER BY u.ord) FROM unnest(idx.indkey::int[]) WITH ORDINALITY AS u(colidx, ord) LEFT JOIN pg_attribute att ON att.attrelid = idx.indrelid AND att.attnum = u.colidx WHERE u.colidx <> 0) AS index_columns
|
||||
FROM pg_index idx JOIN pg_class ic ON ic.oid = idx.indexrelid JOIN pg_am am ON am.oid = ic.relam JOIN table_info ti ON idx.indrelid = ti.table_oid
|
||||
),
|
||||
triggers_info AS (
|
||||
SELECT tg.tgrelid AS table_oid, tg.tgname AS trigger_name, pg_get_triggerdef(tg.oid) AS trigger_definition, tg.tgenabled AS trigger_enabled_state
|
||||
FROM pg_trigger tg JOIN table_info ti ON tg.tgrelid = ti.table_oid WHERE NOT tg.tgisinternal
|
||||
)
|
||||
SELECT
|
||||
ti.schema_name,
|
||||
ti.table_name AS object_name,
|
||||
CASE
|
||||
WHEN $2 = 'simple' THEN
|
||||
-- IF format is 'simple', return basic JSON
|
||||
json_build_object('name', ti.table_name)
|
||||
ELSE
|
||||
json_build_object(
|
||||
'schema_name', ti.schema_name,
|
||||
'object_name', ti.table_name,
|
||||
'object_type', CASE ti.object_kind
|
||||
WHEN 'r' THEN 'TABLE'
|
||||
WHEN 'p' THEN 'PARTITIONED TABLE'
|
||||
ELSE ti.object_kind::text -- Should not happen due to WHERE clause
|
||||
END,
|
||||
'owner', ti.table_owner,
|
||||
'comment', ti.table_comment,
|
||||
'columns', COALESCE((SELECT json_agg(json_build_object('column_name',ci.column_name,'data_type',ci.data_type,'ordinal_position',ci.column_ordinal_position,'is_not_nullable',ci.is_not_nullable,'column_default',ci.column_default,'column_comment',ci.column_comment) ORDER BY ci.column_ordinal_position) FROM columns_info ci WHERE ci.table_oid = ti.table_oid), '[]'::json),
|
||||
'constraints', COALESCE((SELECT json_agg(json_build_object('constraint_name',cons.constraint_name,'constraint_type',cons.constraint_type,'constraint_definition',cons.constraint_definition,'constraint_columns',cons.constraint_columns,'foreign_key_referenced_table',cons.foreign_key_referenced_table,'foreign_key_referenced_columns',cons.foreign_key_referenced_columns)) FROM constraints_info cons WHERE cons.table_oid = ti.table_oid), '[]'::json),
|
||||
'indexes', COALESCE((SELECT json_agg(json_build_object('index_name',ii.index_name,'index_definition',ii.index_definition,'is_unique',ii.is_unique,'is_primary',ii.is_primary,'index_method',ii.index_method,'index_columns',ii.index_columns)) FROM indexes_info ii WHERE ii.table_oid = ti.table_oid), '[]'::json),
|
||||
'triggers', COALESCE((SELECT json_agg(json_build_object('trigger_name',tri.trigger_name,'trigger_definition',tri.trigger_definition,'trigger_enabled_state',tri.trigger_enabled_state)) FROM triggers_info tri WHERE tri.table_oid = ti.table_oid), '[]'::json)
|
||||
)
|
||||
END AS object_details
|
||||
FROM table_info ti ORDER BY ti.schema_name, ti.table_name;
|
||||
`
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
PostgresPool() *pgxpool.Pool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &alloydbpg.Source{}
|
||||
var _ compatibleSource = &cloudsqlpg.Source{}
|
||||
var _ compatibleSource = &postgres.Source{}
|
||||
|
||||
var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
func (cfg Config) ToolConfigKind() string {
|
||||
return kind
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
allParameters := tools.Parameters{
|
||||
tools.NewStringParameterWithDefault("table_names", "", "Optional: A comma-separated list of table names. If empty, details for all tables will be listed."),
|
||||
tools.NewStringParameterWithDefault("output_format", "detailed", "Optional: Use 'simple' for names only or 'detailed' for full info."),
|
||||
}
|
||||
paramManifest := allParameters.Manifest()
|
||||
inputSchema := allParameters.McpManifest()
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
Description: cfg.Description,
|
||||
InputSchema: inputSchema,
|
||||
}
|
||||
|
||||
t := Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
AllParams: allParameters,
|
||||
Pool: s.PostgresPool(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
AllParams tools.Parameters `yaml:"allParams"`
|
||||
|
||||
Pool *pgxpool.Pool
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
tableNames, ok := paramsMap["table_names"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid 'table_names' parameter; expected a string")
|
||||
}
|
||||
outputFormat, _ := paramsMap["output_format"].(string)
|
||||
if outputFormat != "simple" && outputFormat != "detailed" {
|
||||
return nil, fmt.Errorf("invalid value for output_format: must be 'simple' or 'detailed', but got %q", outputFormat)
|
||||
}
|
||||
|
||||
results, err := t.Pool.Query(ctx, listTablesStatement, tableNames, outputFormat)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
defer results.Close()
|
||||
|
||||
fields := results.FieldDescriptions()
|
||||
var out []map[string]any
|
||||
|
||||
for results.Next() {
|
||||
values, err := results.Values()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse row: %w", err)
|
||||
}
|
||||
rowMap := make(map[string]any)
|
||||
for i, field := range fields {
|
||||
rowMap[string(field.Name)] = values[i]
|
||||
}
|
||||
out = append(out, rowMap)
|
||||
}
|
||||
|
||||
if err := results.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error reading query results: %w", err)
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
return tools.ParseParams(t.AllParams, data, claims)
|
||||
}
|
||||
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization() bool {
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,76 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgreslisttables_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
postgreslisttables "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslisttables"
|
||||
)
|
||||
|
||||
func TestParseFromYamlPostgresListTables(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: postgres-list-tables
|
||||
source: my-postgres-instance
|
||||
description: some description
|
||||
authRequired:
|
||||
- my-google-auth-service
|
||||
- other-auth-service
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": postgreslisttables.Config{
|
||||
Name: "example_tool",
|
||||
Kind: "postgres-list-tables",
|
||||
Source: "my-postgres-instance",
|
||||
Description: "some description",
|
||||
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -112,6 +113,12 @@ func TestBigQueryToolEndpoints(t *testing.T) {
|
||||
strings.ReplaceAll(uuid.New().String(), "-", ""),
|
||||
)
|
||||
|
||||
tableNameAnalyzeContribution := fmt.Sprintf("`%s.%s.analyze_contribution_table_%s`",
|
||||
BigqueryProject,
|
||||
datasetName,
|
||||
strings.ReplaceAll(uuid.New().String(), "-", ""),
|
||||
)
|
||||
|
||||
// set up data for param tool
|
||||
createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := getBigQueryParamToolInfo(tableNameParam)
|
||||
teardownTable1 := setupBigQueryTable(t, ctx, client, createParamTableStmt, insertParamTableStmt, datasetName, tableNameParam, paramTestParams)
|
||||
@@ -132,6 +139,11 @@ func TestBigQueryToolEndpoints(t *testing.T) {
|
||||
teardownTable4 := setupBigQueryTable(t, ctx, client, createForecastTableStmt, insertForecastTableStmt, datasetName, tableNameForecast, forecastTestParams)
|
||||
defer teardownTable4(t)
|
||||
|
||||
// set up data for analyze contribution tool
|
||||
createAnalyzeContributionTableStmt, insertAnalyzeContributionTableStmt, analyzeContributionTestParams := getBigQueryAnalyzeContributionToolInfo(tableNameAnalyzeContribution)
|
||||
teardownTable5 := setupBigQueryTable(t, ctx, client, createAnalyzeContributionTableStmt, insertAnalyzeContributionTableStmt, datasetName, tableNameAnalyzeContribution, analyzeContributionTestParams)
|
||||
defer teardownTable5(t)
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, BigqueryToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
|
||||
toolsFile = addClientAuthSourceConfig(t, toolsFile)
|
||||
@@ -181,6 +193,7 @@ func TestBigQueryToolEndpoints(t *testing.T) {
|
||||
runBigQueryExecuteSqlToolInvokeTest(t, select1Want, invokeParamWant, tableNameParam, ddlWant)
|
||||
runBigQueryExecuteSqlToolInvokeDryRunTest(t, datasetName)
|
||||
runBigQueryForecastToolInvokeTest(t, tableNameForecast)
|
||||
runBigQueryAnalyzeContributionToolInvokeTest(t, tableNameAnalyzeContribution)
|
||||
runBigQueryDataTypeTests(t)
|
||||
runBigQueryListDatasetToolInvokeTest(t, datasetName)
|
||||
runBigQueryGetDatasetInfoToolInvokeTest(t, datasetName, datasetInfoWant)
|
||||
@@ -189,6 +202,102 @@ func TestBigQueryToolEndpoints(t *testing.T) {
|
||||
runBigQueryConversationalAnalyticsInvokeTest(t, datasetName, tableName, dataInsightsWant)
|
||||
}
|
||||
|
||||
func TestBigQueryToolWithDatasetRestriction(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
|
||||
client, err := initBigQueryConnection(BigqueryProject)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create BigQuery client: %s", err)
|
||||
}
|
||||
|
||||
// Create two datasets, one allowed, one not.
|
||||
baseName := strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
allowedDatasetName1 := fmt.Sprintf("allowed_dataset_1_%s", baseName)
|
||||
allowedDatasetName2 := fmt.Sprintf("allowed_dataset_2_%s", baseName)
|
||||
disallowedDatasetName := fmt.Sprintf("disallowed_dataset_%s", baseName)
|
||||
allowedTableName1 := "allowed_table_1"
|
||||
allowedTableName2 := "allowed_table_2"
|
||||
disallowedTableName := "disallowed_table"
|
||||
allowedForecastTableName1 := "allowed_forecast_table_1"
|
||||
allowedForecastTableName2 := "allowed_forecast_table_2"
|
||||
disallowedForecastTableName := "disallowed_forecast_table"
|
||||
|
||||
// Setup allowed table
|
||||
allowedTableNameParam1 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName1, allowedTableName1)
|
||||
createAllowedTableStmt1 := fmt.Sprintf("CREATE TABLE %s (id INT64)", allowedTableNameParam1)
|
||||
teardownAllowed1 := setupBigQueryTable(t, ctx, client, createAllowedTableStmt1, "", allowedDatasetName1, allowedTableNameParam1, nil)
|
||||
defer teardownAllowed1(t)
|
||||
|
||||
allowedTableNameParam2 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName2, allowedTableName2)
|
||||
createAllowedTableStmt2 := fmt.Sprintf("CREATE TABLE %s (id INT64)", allowedTableNameParam2)
|
||||
teardownAllowed2 := setupBigQueryTable(t, ctx, client, createAllowedTableStmt2, "", allowedDatasetName2, allowedTableNameParam2, nil)
|
||||
defer teardownAllowed2(t)
|
||||
|
||||
// Setup allowed forecast table
|
||||
allowedForecastTableFullName1 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName1, allowedForecastTableName1)
|
||||
createForecastStmt1, insertForecastStmt1, forecastParams1 := getBigQueryForecastToolInfo(allowedForecastTableFullName1)
|
||||
teardownAllowedForecast1 := setupBigQueryTable(t, ctx, client, createForecastStmt1, insertForecastStmt1, allowedDatasetName1, allowedForecastTableFullName1, forecastParams1)
|
||||
defer teardownAllowedForecast1(t)
|
||||
|
||||
allowedForecastTableFullName2 := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, allowedDatasetName2, allowedForecastTableName2)
|
||||
createForecastStmt2, insertForecastStmt2, forecastParams2 := getBigQueryForecastToolInfo(allowedForecastTableFullName2)
|
||||
teardownAllowedForecast2 := setupBigQueryTable(t, ctx, client, createForecastStmt2, insertForecastStmt2, allowedDatasetName2, allowedForecastTableFullName2, forecastParams2)
|
||||
defer teardownAllowedForecast2(t)
|
||||
|
||||
// Setup disallowed table
|
||||
disallowedTableNameParam := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, disallowedDatasetName, disallowedTableName)
|
||||
createDisallowedTableStmt := fmt.Sprintf("CREATE TABLE %s (id INT64)", disallowedTableNameParam)
|
||||
teardownDisallowed := setupBigQueryTable(t, ctx, client, createDisallowedTableStmt, "", disallowedDatasetName, disallowedTableNameParam, nil)
|
||||
defer teardownDisallowed(t)
|
||||
|
||||
// Setup disallowed forecast table
|
||||
disallowedForecastTableFullName := fmt.Sprintf("`%s.%s.%s`", BigqueryProject, disallowedDatasetName, disallowedForecastTableName)
|
||||
createDisallowedForecastStmt, insertDisallowedForecastStmt, disallowedForecastParams := getBigQueryForecastToolInfo(disallowedForecastTableFullName)
|
||||
teardownDisallowedForecast := setupBigQueryTable(t, ctx, client, createDisallowedForecastStmt, insertDisallowedForecastStmt, disallowedDatasetName, disallowedForecastTableFullName, disallowedForecastParams)
|
||||
defer teardownDisallowedForecast(t)
|
||||
|
||||
// Configure source with dataset restriction.
|
||||
sourceConfig := getBigQueryVars(t)
|
||||
sourceConfig["allowedDatasets"] = []string{allowedDatasetName1, allowedDatasetName2}
|
||||
|
||||
// Configure tool
|
||||
toolsConfig := map[string]any{
|
||||
"list-table-ids-restricted": map[string]any{
|
||||
"kind": "bigquery-list-table-ids",
|
||||
"source": "my-instance",
|
||||
"description": "Tool to list table within a dataset",
|
||||
},
|
||||
}
|
||||
|
||||
// Create config file
|
||||
config := map[string]any{
|
||||
"sources": map[string]any{
|
||||
"my-instance": sourceConfig,
|
||||
},
|
||||
"tools": toolsConfig,
|
||||
}
|
||||
|
||||
// Start server
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, config)
|
||||
if err != nil {
|
||||
t.Fatalf("command initialization returned an error: %s", err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||
if err != nil {
|
||||
t.Logf("toolbox command logs: \n%s", out)
|
||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||
}
|
||||
|
||||
// Run tests
|
||||
runListTableIdsWithRestriction(t, allowedDatasetName1, disallowedDatasetName, allowedTableName1, allowedForecastTableName1)
|
||||
runListTableIdsWithRestriction(t, allowedDatasetName2, disallowedDatasetName, allowedTableName2, allowedForecastTableName2)
|
||||
}
|
||||
|
||||
// getBigQueryParamToolInfo returns statements and param for my-tool for bigquery kind
|
||||
func getBigQueryParamToolInfo(tableName string) (string, string, string, string, string, string, []bigqueryapi.QueryParameter) {
|
||||
createStatement := fmt.Sprintf(`
|
||||
@@ -244,7 +353,7 @@ func getBigQueryForecastToolInfo(tableName string) (string, string, []bigqueryap
|
||||
createStatement := fmt.Sprintf(`
|
||||
CREATE TABLE IF NOT EXISTS %s (ts TIMESTAMP, data FLOAT64, id STRING);`, tableName)
|
||||
insertStatement := fmt.Sprintf(`
|
||||
INSERT INTO %s (ts, data, id) VALUES
|
||||
INSERT INTO %s (ts, data, id) VALUES
|
||||
(?, ?, ?), (?, ?, ?), (?, ?, ?),
|
||||
(?, ?, ?), (?, ?, ?), (?, ?, ?);`, tableName)
|
||||
params := []bigqueryapi.QueryParameter{
|
||||
@@ -258,6 +367,26 @@ func getBigQueryForecastToolInfo(tableName string) (string, string, []bigqueryap
|
||||
return createStatement, insertStatement, params
|
||||
}
|
||||
|
||||
// getBigQueryAnalyzeContributionToolInfo returns statements and params for the analyze-contribution tool.
|
||||
func getBigQueryAnalyzeContributionToolInfo(tableName string) (string, string, []bigqueryapi.QueryParameter) {
|
||||
createStatement := fmt.Sprintf(`
|
||||
CREATE TABLE IF NOT EXISTS %s (dim1 STRING, dim2 STRING, is_test BOOL, metric FLOAT64);`, tableName)
|
||||
insertStatement := fmt.Sprintf(`
|
||||
INSERT INTO %s (dim1, dim2, is_test, metric) VALUES
|
||||
(?, ?, ?, ?), (?, ?, ?, ?), (?, ?, ?, ?), (?, ?, ?, ?);`, tableName)
|
||||
params := []bigqueryapi.QueryParameter{
|
||||
{Value: "a"}, {Value: "x"}, {Value: true}, {Value: 100.0},
|
||||
{Value: "a"}, {Value: "x"}, {Value: false}, {Value: 110.0},
|
||||
{Value: "a"}, {Value: "y"}, {Value: true}, {Value: 120.0},
|
||||
{Value: "a"}, {Value: "y"}, {Value: false}, {Value: 100.0},
|
||||
{Value: "b"}, {Value: "x"}, {Value: true}, {Value: 40.0},
|
||||
{Value: "b"}, {Value: "x"}, {Value: false}, {Value: 100.0},
|
||||
{Value: "b"}, {Value: "y"}, {Value: true}, {Value: 60.0},
|
||||
{Value: "b"}, {Value: "y"}, {Value: false}, {Value: 60.0},
|
||||
}
|
||||
return createStatement, insertStatement, params
|
||||
}
|
||||
|
||||
// getBigQueryTmplToolStatement returns statements for template parameter test cases for bigquery kind
|
||||
func getBigQueryTmplToolStatement() (string, string) {
|
||||
tmplSelectCombined := "SELECT * FROM {{.tableName}} WHERE id = ? ORDER BY id"
|
||||
@@ -295,19 +424,21 @@ func setupBigQueryTable(t *testing.T, ctx context.Context, client *bigqueryapi.C
|
||||
t.Fatalf("Create table job for %s failed: %v", tableName, err)
|
||||
}
|
||||
|
||||
// Insert test data
|
||||
insertQuery := client.Query(insertStatement)
|
||||
insertQuery.Parameters = params
|
||||
insertJob, err := insertQuery.Run(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to start insert job for %s: %v", tableName, err)
|
||||
}
|
||||
insertStatus, err := insertJob.Wait(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to wait for insert job for %s: %v", tableName, err)
|
||||
}
|
||||
if err := insertStatus.Err(); err != nil {
|
||||
t.Fatalf("Insert job for %s failed: %v", tableName, err)
|
||||
if len(params) > 0 {
|
||||
// Insert test data
|
||||
insertQuery := client.Query(insertStatement)
|
||||
insertQuery.Parameters = params
|
||||
insertJob, err := insertQuery.Run(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to start insert job for %s: %v", tableName, err)
|
||||
}
|
||||
insertStatus, err := insertJob.Wait(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to wait for insert job for %s: %v", tableName, err)
|
||||
}
|
||||
if err := insertStatus.Err(); err != nil {
|
||||
t.Fatalf("Insert job for %s failed: %v", tableName, err)
|
||||
}
|
||||
}
|
||||
|
||||
return func(t *testing.T) {
|
||||
@@ -383,6 +514,24 @@ func addBigQueryPrebuiltToolsConfig(t *testing.T, config map[string]any) map[str
|
||||
"source": "my-client-auth-source",
|
||||
"description": "Tool to forecast time series data with auth.",
|
||||
}
|
||||
tools["my-analyze-contribution-tool"] = map[string]any{
|
||||
"kind": "bigquery-analyze-contribution",
|
||||
"source": "my-instance",
|
||||
"description": "Tool to analyze contribution.",
|
||||
}
|
||||
tools["my-auth-analyze-contribution-tool"] = map[string]any{
|
||||
"kind": "bigquery-analyze-contribution",
|
||||
"source": "my-instance",
|
||||
"description": "Tool to analyze contribution with auth.",
|
||||
"authRequired": []string{
|
||||
"my-google-auth",
|
||||
},
|
||||
}
|
||||
tools["my-client-auth-analyze-contribution-tool"] = map[string]any{
|
||||
"kind": "bigquery-analyze-contribution",
|
||||
"source": "my-client-auth-source",
|
||||
"description": "Tool to analyze contribution with auth.",
|
||||
}
|
||||
tools["my-list-dataset-ids-tool"] = map[string]any{
|
||||
"kind": "bigquery-list-dataset-ids",
|
||||
"source": "my-instance",
|
||||
@@ -952,6 +1101,127 @@ func runBigQueryForecastToolInvokeTest(t *testing.T, tableName string) {
|
||||
}
|
||||
}
|
||||
|
||||
func runBigQueryAnalyzeContributionToolInvokeTest(t *testing.T, tableName string) {
|
||||
idToken, err := tests.GetGoogleIdToken(tests.ClientId)
|
||||
if err != nil {
|
||||
t.Fatalf("error getting Google ID token: %s", err)
|
||||
}
|
||||
|
||||
// Get access token
|
||||
accessToken, err := sources.GetIAMAccessToken(t.Context())
|
||||
if err != nil {
|
||||
t.Fatalf("error getting access token from ADC: %s", err)
|
||||
}
|
||||
accessToken = "Bearer " + accessToken
|
||||
|
||||
dataTable := strings.ReplaceAll(tableName, "`", "")
|
||||
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
api string
|
||||
requestHeader map[string]string
|
||||
requestBody io.Reader
|
||||
want string
|
||||
isErr bool
|
||||
}{
|
||||
{
|
||||
name: "invoke my-analyze-contribution-tool without required params",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-analyze-contribution-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"input_data": "%s"}`, dataTable))),
|
||||
isErr: true,
|
||||
},
|
||||
{
|
||||
name: "invoke my-analyze-contribution-tool with table",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-analyze-contribution-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"input_data": "%s", "contribution_metric": "SUM(metric)", "is_test_col": "is_test", "dimension_id_cols": ["dim1", "dim2"]}`, dataTable))),
|
||||
want: `"relative_difference"`,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "invoke my-auth-analyze-contribution-tool with auth token",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-analyze-contribution-tool/invoke",
|
||||
requestHeader: map[string]string{"my-google-auth_token": idToken},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"input_data": "%s", "contribution_metric": "SUM(metric)", "is_test_col": "is_test", "dimension_id_cols": ["dim1", "dim2"]}`, dataTable))),
|
||||
want: `"relative_difference"`,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "invoke my-auth-analyze-contribution-tool with invalid auth token",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-analyze-contribution-tool/invoke",
|
||||
requestHeader: map[string]string{"my-google-auth_token": "INVALID_TOKEN"},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"input_data": "%s", "contribution_metric": "SUM(metric)", "is_test_col": "is_test", "dimension_id_cols": ["dim1", "dim2"]}`, dataTable))),
|
||||
isErr: true,
|
||||
},
|
||||
{
|
||||
name: "Invoke my-client-auth-analyze-contribution-tool with auth token",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-client-auth-analyze-contribution-tool/invoke",
|
||||
requestHeader: map[string]string{"Authorization": accessToken},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"input_data": "%s", "contribution_metric": "SUM(metric)", "is_test_col": "is_test", "dimension_id_cols": ["dim1", "dim2"]}`, dataTable))),
|
||||
want: `"relative_difference"`,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invoke my-client-auth-analyze-contribution-tool without auth token",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-client-auth-analyze-contribution-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"input_data": "%s", "contribution_metric": "SUM(metric)", "is_test_col": "is_test", "dimension_id_cols": ["dim1", "dim2"]}`, dataTable))),
|
||||
isErr: true,
|
||||
},
|
||||
{
|
||||
|
||||
name: "Invoke my-client-auth-analyze-contribution-tool with invalid auth token",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-client-auth-analyze-contribution-tool/invoke",
|
||||
requestHeader: map[string]string{"Authorization": "Bearer invalid-token"},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"input_data": "%s", "contribution_metric": "SUM(metric)", "is_test_col": "is_test", "dimension_id_cols": ["dim1", "dim2"]}`, dataTable))),
|
||||
isErr: true,
|
||||
},
|
||||
}
|
||||
for _, tc := range invokeTcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Send Tool invocation request
|
||||
req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create request: %s", err)
|
||||
}
|
||||
req.Header.Add("Content-type", "application/json")
|
||||
for k, v := range tc.requestHeader {
|
||||
req.Header.Add(k, v)
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if tc.isErr {
|
||||
return
|
||||
}
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
// Check response body
|
||||
var body map[string]interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing response body")
|
||||
}
|
||||
|
||||
got, ok := body["result"].(string)
|
||||
if !ok {
|
||||
t.Fatalf("unable to find result in response body")
|
||||
}
|
||||
|
||||
if !strings.Contains(got, tc.want) {
|
||||
t.Fatalf("expected %q to contain %q, but it did not", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func runBigQueryDataTypeTests(t *testing.T) {
|
||||
// Test tool invoke endpoint
|
||||
invokeTcs := []struct {
|
||||
@@ -1731,3 +2001,86 @@ func runBigQueryConversationalAnalyticsInvokeTest(t *testing.T, datasetName, tab
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func runListTableIdsWithRestriction(t *testing.T, allowedDatasetName, disallowedDatasetName string, allowedTableNames ...string) {
|
||||
sort.Strings(allowedTableNames)
|
||||
var quotedNames []string
|
||||
for _, name := range allowedTableNames {
|
||||
quotedNames = append(quotedNames, fmt.Sprintf(`"%s"`, name))
|
||||
}
|
||||
wantResult := fmt.Sprintf(`[%s]`, strings.Join(quotedNames, ","))
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
dataset string
|
||||
wantStatusCode int
|
||||
wantInResult string
|
||||
wantInError string
|
||||
}{
|
||||
{
|
||||
name: "invoke on allowed dataset",
|
||||
dataset: allowedDatasetName,
|
||||
wantStatusCode: http.StatusOK,
|
||||
wantInResult: wantResult,
|
||||
},
|
||||
{
|
||||
name: "invoke on disallowed dataset",
|
||||
dataset: disallowedDatasetName,
|
||||
wantStatusCode: http.StatusBadRequest, // Or the specific error code returned
|
||||
wantInError: fmt.Sprintf("access denied to dataset '%s'", disallowedDatasetName),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
body := bytes.NewBuffer([]byte(fmt.Sprintf(`{"dataset":"%s"}`, tc.dataset)))
|
||||
req, err := http.NewRequest(http.MethodPost, "http://127.0.0.1:5000/api/tool/list-table-ids-restricted/invoke", body)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create request: %s", err)
|
||||
}
|
||||
req.Header.Add("Content-type", "application/json")
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != tc.wantStatusCode {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("unexpected status code: got %d, want %d. Body: %s", resp.StatusCode, tc.wantStatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
if tc.wantInResult != "" {
|
||||
var respBody map[string]interface{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&respBody); err != nil {
|
||||
t.Fatalf("error parsing response body: %v", err)
|
||||
}
|
||||
got, ok := respBody["result"].(string)
|
||||
if !ok {
|
||||
t.Fatalf("unable to find result in response body")
|
||||
}
|
||||
|
||||
var gotSlice []string
|
||||
if err := json.Unmarshal([]byte(got), &gotSlice); err != nil {
|
||||
t.Fatalf("error unmarshalling result: %v", err)
|
||||
}
|
||||
sort.Strings(gotSlice)
|
||||
sortedGotBytes, err := json.Marshal(gotSlice)
|
||||
if err != nil {
|
||||
t.Fatalf("error marshalling sorted result: %v", err)
|
||||
}
|
||||
|
||||
if string(sortedGotBytes) != tc.wantInResult {
|
||||
t.Errorf("unexpected result: got %q, want %q", string(sortedGotBytes), tc.wantInResult)
|
||||
}
|
||||
}
|
||||
|
||||
if tc.wantInError != "" {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
if !strings.Contains(string(bodyBytes), tc.wantInError) {
|
||||
t.Errorf("unexpected error message: got %q, want to contain %q", string(bodyBytes), tc.wantInError)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,6 +31,7 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
clickhouseexecutesql "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouseexecutesql"
|
||||
clickhouselistdatabases "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouselistdatabases"
|
||||
clickhousesql "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhousesql"
|
||||
"github.com/googleapis/genai-toolbox/tests"
|
||||
"go.opentelemetry.io/otel/trace/noop"
|
||||
@@ -1012,3 +1013,103 @@ func setupClickHouseSQLTable(t *testing.T, ctx context.Context, pool *sql.DB, cr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClickHouseListDatabasesTool(t *testing.T) {
|
||||
_ = getClickHouseVars(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
|
||||
pool, err := initClickHouseConnectionPool(ClickHouseHost, ClickHousePort, ClickHouseUser, ClickHousePass, ClickHouseDatabase, ClickHouseProtocol)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create ClickHouse connection pool: %s", err)
|
||||
}
|
||||
defer pool.Close()
|
||||
|
||||
// Create a test database
|
||||
testDBName := "test_list_db_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:8]
|
||||
_, err = pool.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s", testDBName))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test database: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
_, _ = pool.ExecContext(ctx, fmt.Sprintf("DROP DATABASE IF EXISTS %s", testDBName))
|
||||
}()
|
||||
|
||||
t.Run("ListDatabases", func(t *testing.T) {
|
||||
toolConfig := clickhouselistdatabases.Config{
|
||||
Name: "test-list-databases",
|
||||
Kind: "clickhouse-list-databases",
|
||||
Source: "test-clickhouse",
|
||||
Description: "Test listing databases",
|
||||
}
|
||||
|
||||
source := createMockSource(t, pool)
|
||||
sourcesMap := map[string]sources.Source{
|
||||
"test-clickhouse": source,
|
||||
}
|
||||
|
||||
tool, err := toolConfig.Initialize(sourcesMap)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to initialize tool: %v", err)
|
||||
}
|
||||
|
||||
params := tools.ParamValues{}
|
||||
|
||||
result, err := tool.Invoke(ctx, params, "")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to list databases: %v", err)
|
||||
}
|
||||
|
||||
databases, ok := result.([]map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("Expected result to be []map[string]any, got %T", result)
|
||||
}
|
||||
|
||||
// Should contain at least the default database and our test database - system and default
|
||||
if len(databases) < 2 {
|
||||
t.Errorf("Expected at least 2 databases, got %d", len(databases))
|
||||
}
|
||||
|
||||
found := false
|
||||
foundDefault := false
|
||||
for _, db := range databases {
|
||||
if name, ok := db["name"].(string); ok {
|
||||
if name == testDBName {
|
||||
found = true
|
||||
}
|
||||
if name == "default" || name == "system" {
|
||||
foundDefault = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
t.Errorf("Test database %s not found in list", testDBName)
|
||||
}
|
||||
if !foundDefault {
|
||||
t.Errorf("Default/system database not found in list")
|
||||
}
|
||||
|
||||
t.Logf("Successfully listed %d databases", len(databases))
|
||||
})
|
||||
|
||||
t.Run("ListDatabasesWithInvalidSource", func(t *testing.T) {
|
||||
toolConfig := clickhouselistdatabases.Config{
|
||||
Name: "test-invalid-source",
|
||||
Kind: "clickhouse-list-databases",
|
||||
Source: "non-existent-source",
|
||||
Description: "Test with invalid source",
|
||||
}
|
||||
|
||||
sourcesMap := map[string]sources.Source{}
|
||||
|
||||
_, err := toolConfig.Initialize(sourcesMap)
|
||||
if err == nil {
|
||||
t.Error("Expected error for non-existent source, got nil")
|
||||
} else {
|
||||
t.Logf("Got expected error for invalid source: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Logf("✅ clickhouse-list-databases tool tests completed successfully")
|
||||
}
|
||||
|
||||
113
tests/cloudmonitoring/cloud_monitoring_integration_test.go
Normal file
113
tests/cloudmonitoring/cloud_monitoring_integration_test.go
Normal file
@@ -0,0 +1,113 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cloudmonitoring
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/cloudmonitoring"
|
||||
)
|
||||
|
||||
func TestTool_Invoke(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Mock the monitoring server
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/projects/test-project/location/global/prometheus/api/v1/query" {
|
||||
http.Error(w, "not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
query := r.URL.Query().Get("query")
|
||||
if query != "up" {
|
||||
http.Error(w, "bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
fmt.Fprintln(w, `{"status":"success","data":{"resultType":"vector","result":[]}}`)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Create a new observability tool
|
||||
tool := &cloudmonitoring.Tool{
|
||||
Name: "test-cloudmonitoring",
|
||||
Kind: "cloud-monitoring-query-prometheus",
|
||||
Description: "Test Cloudmonitoring Tool",
|
||||
AllParams: tools.Parameters{},
|
||||
BaseURL: server.URL,
|
||||
Client: &http.Client{},
|
||||
}
|
||||
|
||||
// Define the test parameters
|
||||
params := tools.ParamValues{
|
||||
{Name: "projectId", Value: "test-project"},
|
||||
{Name: "query", Value: "up"},
|
||||
}
|
||||
|
||||
// Invoke the tool
|
||||
result, err := tool.Invoke(context.Background(), params, "")
|
||||
if err != nil {
|
||||
t.Fatalf("Invoke() error = %v", err)
|
||||
}
|
||||
|
||||
// Check the result
|
||||
expected := map[string]any{
|
||||
"status": "success",
|
||||
"data": map[string]any{
|
||||
"resultType": "vector",
|
||||
"result": []any{},
|
||||
},
|
||||
}
|
||||
if diff := cmp.Diff(expected, result); diff != "" {
|
||||
t.Errorf("Invoke() result mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTool_Invoke_Error(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Mock the monitoring server to return an error
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "internal server error", http.StatusInternalServerError)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Create a new observability tool
|
||||
tool := &cloudmonitoring.Tool{
|
||||
Name: "test-cloudmonitoring",
|
||||
Kind: "clou-monitoring-query-prometheus",
|
||||
Description: "Test Cloudmonitoring Tool",
|
||||
AllParams: tools.Parameters{},
|
||||
BaseURL: server.URL,
|
||||
Client: &http.Client{},
|
||||
}
|
||||
|
||||
// Define the test parameters
|
||||
params := tools.ParamValues{
|
||||
{Name: "projectId", Value: "test-project"},
|
||||
{Name: "query", Value: "up"},
|
||||
}
|
||||
|
||||
// Invoke the tool
|
||||
_, err := tool.Invoke(context.Background(), params, "")
|
||||
if err == nil {
|
||||
t.Fatal("Invoke() error = nil, want error")
|
||||
}
|
||||
}
|
||||
@@ -1,284 +0,0 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cloudsql
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/tests"
|
||||
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlwaitforoperation"
|
||||
)
|
||||
|
||||
var (
|
||||
cloudsqlWaitToolKind = "cloud-sql-wait-for-operation"
|
||||
)
|
||||
|
||||
type cloudsqlOperation struct {
|
||||
Name string `json:"name"`
|
||||
Status string `json:"status"`
|
||||
TargetLink string `json:"targetLink"`
|
||||
OperationType string `json:"operationType"`
|
||||
Error *struct {
|
||||
Errors []struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
} `json:"errors"`
|
||||
} `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type cloudsqlInstance struct {
|
||||
Region string `json:"region"`
|
||||
DatabaseVersion string `json:"databaseVersion"`
|
||||
}
|
||||
|
||||
type cloudsqlHandler struct {
|
||||
mu sync.Mutex
|
||||
operations map[string]*cloudsqlOperation
|
||||
instances map[string]*cloudsqlInstance
|
||||
}
|
||||
|
||||
func (h *cloudsqlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
if match, _ := regexp.MatchString("/v1/projects/p1/operations/.*", r.URL.Path); match {
|
||||
parts := regexp.MustCompile("/").Split(r.URL.Path, -1)
|
||||
opName := parts[len(parts)-1]
|
||||
|
||||
op, ok := h.operations[opName]
|
||||
if !ok {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if op.Status != "DONE" {
|
||||
op.Status = "DONE"
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(op); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
} else if match, _ := regexp.MatchString("/v1/projects/p1/instances/.*", r.URL.Path); match {
|
||||
parts := regexp.MustCompile("/").Split(r.URL.Path, -1)
|
||||
instanceName := parts[len(parts)-1]
|
||||
|
||||
instance, ok := h.instances[instanceName]
|
||||
if !ok {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(instance); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
} else {
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloudSQLWaitToolEndpoints(t *testing.T) {
|
||||
h := &cloudsqlHandler{
|
||||
operations: map[string]*cloudsqlOperation{
|
||||
"op1": {Name: "op1", Status: "PENDING", OperationType: "CREATE_DATABASE"},
|
||||
"op2": {Name: "op2", Status: "PENDING", OperationType: "CREATE_DATABASE", Error: &struct {
|
||||
Errors []struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
} `json:"errors"`
|
||||
}{
|
||||
Errors: []struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}{
|
||||
{Code: "ERROR_CODE", Message: "failed"},
|
||||
},
|
||||
}},
|
||||
"op3": {Name: "op3", Status: "PENDING", OperationType: "CREATE"},
|
||||
},
|
||||
instances: map[string]*cloudsqlInstance{
|
||||
"i1": {Region: "r1", DatabaseVersion: "POSTGRES_13"},
|
||||
},
|
||||
}
|
||||
server := httptest.NewServer(h)
|
||||
defer server.Close()
|
||||
|
||||
h.operations["op1"].TargetLink = fmt.Sprintf("%s/v1/projects/p1/instances/i1/databases/d1", server.URL)
|
||||
h.operations["op2"].TargetLink = fmt.Sprintf("%s/v1/projects/p1/instances/i2/databases/d2", server.URL)
|
||||
h.operations["op3"].TargetLink = fmt.Sprintf("%s/v1/projects/p1/instances/i1", server.URL)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
|
||||
var args []string
|
||||
|
||||
toolsFile := getCloudSQLWaitToolsConfig(server.URL)
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
t.Fatalf("command initialization returned an error: %s", err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||
if err != nil {
|
||||
t.Logf("toolbox command logs: \n%s", out)
|
||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||
}
|
||||
|
||||
tcs := []struct {
|
||||
name string
|
||||
toolName string
|
||||
body string
|
||||
want string
|
||||
expectError bool
|
||||
wantSubstring bool
|
||||
}{
|
||||
{
|
||||
name: "successful operation",
|
||||
toolName: "wait-for-op1",
|
||||
body: `{"project": "p1", "operation": "op1"}`,
|
||||
want: "Your Cloud SQL resource is ready",
|
||||
wantSubstring: true,
|
||||
},
|
||||
{
|
||||
name: "failed operation",
|
||||
toolName: "wait-for-op2",
|
||||
body: `{"project": "p1", "operation": "op2"}`,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "non-database create operation",
|
||||
toolName: "wait-for-op3",
|
||||
body: `{"project": "p1", "operation": "op3"}`,
|
||||
want: `{"name":"op3","status":"DONE","targetLink":"` + h.operations["op3"].TargetLink + `","operationType":"CREATE"}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
api := fmt.Sprintf("http://127.0.0.1:5000/api/tool/%s/invoke", tc.toolName)
|
||||
req, err := http.NewRequest(http.MethodPost, api, bytes.NewBufferString(tc.body))
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create request: %s", err)
|
||||
}
|
||||
req.Header.Add("Content-type", "application/json")
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if tc.expectError {
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
t.Fatal("expected error but got status 200")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
if tc.wantSubstring {
|
||||
var result struct {
|
||||
Result string `json:"result"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if !bytes.Contains([]byte(result.Result), []byte(tc.want)) {
|
||||
t.Fatalf("unexpected result: got %q, want substring %q", result.Result, tc.want)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Result string `json:"result"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
var tempString string
|
||||
if err := json.Unmarshal([]byte(result.Result), &tempString); err != nil {
|
||||
t.Fatalf("failed to unmarshal outer JSON string: %v", err)
|
||||
}
|
||||
|
||||
var got, want map[string]any
|
||||
if err := json.Unmarshal([]byte(tempString), &got); err != nil {
|
||||
t.Fatalf("failed to unmarshal inner JSON object: %v", err)
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(tc.want), &want); err != nil {
|
||||
t.Fatalf("failed to unmarshal want: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Fatalf("unexpected result: got %+v, want %+v", got, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func getCloudSQLWaitToolsConfig(baseURL string) map[string]any {
|
||||
return map[string]any{
|
||||
"sources": map[string]any{
|
||||
"test-source": map[string]any{
|
||||
"kind": "http",
|
||||
"baseUrl": baseURL,
|
||||
},
|
||||
},
|
||||
"tools": map[string]any{
|
||||
"wait-for-op1": map[string]any{
|
||||
"kind": cloudsqlWaitToolKind,
|
||||
"source": "test-source",
|
||||
"description": "wait for op1",
|
||||
"baseURL": baseURL,
|
||||
"authRequired": []string{},
|
||||
},
|
||||
"wait-for-op2": map[string]any{
|
||||
"kind": cloudsqlWaitToolKind,
|
||||
"source": "test-source",
|
||||
"description": "wait for op2",
|
||||
"baseURL": baseURL,
|
||||
"authRequired": []string{},
|
||||
},
|
||||
"wait-for-op3": map[string]any{
|
||||
"kind": cloudsqlWaitToolKind,
|
||||
"source": "test-source",
|
||||
"description": "wait for op3",
|
||||
"baseURL": baseURL,
|
||||
"authRequired": []string{},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
114
tests/common.go
114
tests/common.go
@@ -23,6 +23,11 @@ import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmysql"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
@@ -645,3 +650,112 @@ func GetRedisValkeyToolsConfig(sourceConfig map[string]any, toolKind string) map
|
||||
}
|
||||
return toolsFile
|
||||
}
|
||||
|
||||
// TestCloudSQLMySQL_IPTypeParsingFromYAML verifies the IPType field parsing from YAML
|
||||
// for the cloud-sql-mysql source, mimicking the structure of tests in cloudsql_mysql_test.go.
|
||||
func TestCloudSQLMySQL_IPTypeParsingFromYAML(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.SourceConfigs
|
||||
}{
|
||||
{
|
||||
desc: "IPType Defaulting to Public",
|
||||
in: `
|
||||
sources:
|
||||
my-mysql-instance:
|
||||
kind: cloud-sql-mysql
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
`,
|
||||
want: server.SourceConfigs{
|
||||
"my-mysql-instance": cloudsqlmysql.Config{
|
||||
Name: "my-mysql-instance",
|
||||
Kind: cloudsqlmysql.SourceKind,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Instance: "my-instance",
|
||||
IPType: "public", // Default value
|
||||
Database: "my_db",
|
||||
User: "my_user",
|
||||
Password: "my_pass",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "IPType Explicit Public",
|
||||
in: `
|
||||
sources:
|
||||
my-mysql-instance:
|
||||
kind: cloud-sql-mysql
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
ipType: Public
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
`,
|
||||
want: server.SourceConfigs{
|
||||
"my-mysql-instance": cloudsqlmysql.Config{
|
||||
Name: "my-mysql-instance",
|
||||
Kind: cloudsqlmysql.SourceKind,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Instance: "my-instance",
|
||||
IPType: "public",
|
||||
Database: "my_db",
|
||||
User: "my_user",
|
||||
Password: "my_pass",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "IPType Explicit Private",
|
||||
in: `
|
||||
sources:
|
||||
my-mysql-instance:
|
||||
kind: cloud-sql-mysql
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
ipType: private
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
`,
|
||||
want: server.SourceConfigs{
|
||||
"my-mysql-instance": cloudsqlmysql.Config{
|
||||
Name: "my-mysql-instance",
|
||||
Kind: cloudsqlmysql.SourceKind,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Instance: "my-instance",
|
||||
IPType: "private",
|
||||
Database: "my_db",
|
||||
User: "my_user",
|
||||
Password: "my_pass",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect parse: diff (-want +got):\n%s", cmp.Diff(tc.want, got.Sources))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,15 +15,21 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"github.com/google/uuid"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/tests"
|
||||
@@ -32,6 +38,7 @@ import (
|
||||
var (
|
||||
MySQLSourceKind = "mysql"
|
||||
MySQLToolKind = "mysql-sql"
|
||||
MySQLListTablesToolKind = "mysql-list-tables"
|
||||
MySQLDatabase = os.Getenv("MYSQL_DATABASE")
|
||||
MySQLHost = os.Getenv("MYSQL_HOST")
|
||||
MySQLPort = os.Getenv("MYSQL_PORT")
|
||||
@@ -63,6 +70,20 @@ func getMySQLVars(t *testing.T) map[string]any {
|
||||
}
|
||||
}
|
||||
|
||||
func addPrebuiltToolConfig(t *testing.T, config map[string]any) map[string]any {
|
||||
tools, ok := config["tools"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("unable to get tools from config")
|
||||
}
|
||||
tools["list_tables"] = map[string]any{
|
||||
"kind": MySQLListTablesToolKind,
|
||||
"source": "my-instance",
|
||||
"description": "Lists tables in the database.",
|
||||
}
|
||||
config["tools"] = tools
|
||||
return config
|
||||
}
|
||||
|
||||
// Copied over from mysql.go
|
||||
func initMySQLConnectionPool(host, port, user, pass, dbname string) (*sql.DB, error) {
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?parseTime=true", user, pass, host, port, dbname)
|
||||
@@ -108,6 +129,8 @@ func TestMySQLToolEndpoints(t *testing.T) {
|
||||
tmplSelectCombined, tmplSelectFilterCombined := tests.GetMySQLTmplToolStatement()
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, MySQLToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||
|
||||
toolsFile = addPrebuiltToolConfig(t, toolsFile)
|
||||
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
t.Fatalf("command initialization returned an error: %s", err)
|
||||
@@ -131,4 +154,182 @@ func TestMySQLToolEndpoints(t *testing.T) {
|
||||
tests.RunMCPToolCallMethod(t, mcpMyFailToolWant, mcpSelect1Want)
|
||||
tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam)
|
||||
|
||||
// Run specific MySQL tool tests
|
||||
runMySQLListTablesTest(t, tableNameParam, tableNameAuth)
|
||||
}
|
||||
|
||||
func runMySQLListTablesTest(t *testing.T, tableNameParam, tableNameAuth string) {
|
||||
type tableInfo struct {
|
||||
ObjectName string `json:"object_name"`
|
||||
SchemaName string `json:"schema_name"`
|
||||
ObjectDetails string `json:"object_details"`
|
||||
}
|
||||
|
||||
type column struct {
|
||||
DataType string `json:"data_type"`
|
||||
ColumnName string `json:"column_name"`
|
||||
ColumnComment string `json:"column_comment"`
|
||||
ColumnDefault any `json:"column_default"`
|
||||
IsNotNullable int `json:"is_not_nullable"`
|
||||
OrdinalPosition int `json:"ordinal_position"`
|
||||
}
|
||||
|
||||
type objectDetails struct {
|
||||
Owner any `json:"owner"`
|
||||
Columns []column `json:"columns"`
|
||||
Comment string `json:"comment"`
|
||||
Indexes []any `json:"indexes"`
|
||||
Triggers []any `json:"triggers"`
|
||||
Constraints []any `json:"constraints"`
|
||||
ObjectName string `json:"object_name"`
|
||||
ObjectType string `json:"object_type"`
|
||||
SchemaName string `json:"schema_name"`
|
||||
}
|
||||
|
||||
paramTableWant := objectDetails{
|
||||
ObjectName: tableNameParam,
|
||||
SchemaName: MySQLDatabase,
|
||||
ObjectType: "TABLE",
|
||||
Columns: []column{
|
||||
{DataType: "int", ColumnName: "id", IsNotNullable: 1, OrdinalPosition: 1},
|
||||
{DataType: "varchar(255)", ColumnName: "name", OrdinalPosition: 2},
|
||||
},
|
||||
Indexes: []any{map[string]any{"index_columns": []any{"id"}, "index_name": "PRIMARY", "is_primary": float64(1), "is_unique": float64(1)}},
|
||||
Triggers: []any{},
|
||||
Constraints: []any{map[string]any{"constraint_columns": []any{"id"}, "constraint_name": "PRIMARY", "constraint_type": "PRIMARY KEY", "foreign_key_referenced_columns": any(nil), "foreign_key_referenced_table": any(nil), "constraint_definition": ""}},
|
||||
}
|
||||
|
||||
authTableWant := objectDetails{
|
||||
ObjectName: tableNameAuth,
|
||||
SchemaName: MySQLDatabase,
|
||||
ObjectType: "TABLE",
|
||||
Columns: []column{
|
||||
{DataType: "int", ColumnName: "id", IsNotNullable: 1, OrdinalPosition: 1},
|
||||
{DataType: "varchar(255)", ColumnName: "name", OrdinalPosition: 2},
|
||||
{DataType: "varchar(255)", ColumnName: "email", OrdinalPosition: 3},
|
||||
},
|
||||
Indexes: []any{map[string]any{"index_columns": []any{"id"}, "index_name": "PRIMARY", "is_primary": float64(1), "is_unique": float64(1)}},
|
||||
Triggers: []any{},
|
||||
Constraints: []any{map[string]any{"constraint_columns": []any{"id"}, "constraint_name": "PRIMARY", "constraint_type": "PRIMARY KEY", "foreign_key_referenced_columns": any(nil), "foreign_key_referenced_table": any(nil), "constraint_definition": ""}},
|
||||
}
|
||||
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
requestBody io.Reader
|
||||
wantStatusCode int
|
||||
want any
|
||||
isSimple bool
|
||||
}{
|
||||
{
|
||||
name: "invoke list_tables detailed output",
|
||||
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"table_names": "%s"}`, tableNameAuth)),
|
||||
wantStatusCode: http.StatusOK,
|
||||
want: []objectDetails{authTableWant},
|
||||
},
|
||||
{
|
||||
name: "invoke list_tables simple output",
|
||||
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"table_names": "%s", "output_format": "simple"}`, tableNameAuth)),
|
||||
wantStatusCode: http.StatusOK,
|
||||
want: []map[string]any{{"name": tableNameAuth}},
|
||||
isSimple: true,
|
||||
},
|
||||
{
|
||||
name: "invoke list_tables with multiple table names",
|
||||
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"table_names": "%s,%s"}`, tableNameParam, tableNameAuth)),
|
||||
wantStatusCode: http.StatusOK,
|
||||
want: []objectDetails{authTableWant, paramTableWant},
|
||||
},
|
||||
{
|
||||
name: "invoke list_tables with one existing and one non-existent table",
|
||||
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"table_names": "%s,non_existent_table"}`, tableNameAuth)),
|
||||
wantStatusCode: http.StatusOK,
|
||||
want: []objectDetails{authTableWant},
|
||||
},
|
||||
{
|
||||
name: "invoke list_tables with non-existent table",
|
||||
requestBody: bytes.NewBufferString(`{"table_names": "non_existent_table"}`),
|
||||
wantStatusCode: http.StatusOK,
|
||||
want: nil,
|
||||
},
|
||||
}
|
||||
for _, tc := range invokeTcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
const api = "http://127.0.0.1:5000/api/tool/list_tables/invoke"
|
||||
req, err := http.NewRequest(http.MethodPost, api, tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create request: %v", err)
|
||||
}
|
||||
req.Header.Add("Content-type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != tc.wantStatusCode {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(body))
|
||||
}
|
||||
if tc.wantStatusCode != http.StatusOK {
|
||||
return
|
||||
}
|
||||
|
||||
var bodyWrapper struct{ Result json.RawMessage `json:"result"` }
|
||||
if err := json.NewDecoder(resp.Body).Decode(&bodyWrapper); err != nil {
|
||||
t.Fatalf("error decoding response wrapper: %v", err)
|
||||
}
|
||||
|
||||
var resultString string
|
||||
if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil {
|
||||
resultString = string(bodyWrapper.Result)
|
||||
}
|
||||
|
||||
var got any
|
||||
if tc.isSimple {
|
||||
var tables []tableInfo
|
||||
if err := json.Unmarshal([]byte(resultString), &tables); err != nil {
|
||||
t.Fatalf("failed to unmarshal outer JSON array into []tableInfo: %v", err)
|
||||
}
|
||||
var details []map[string]any
|
||||
for _, table := range tables {
|
||||
var d map[string]any
|
||||
if err := json.Unmarshal([]byte(table.ObjectDetails), &d); err != nil {
|
||||
t.Fatalf("failed to unmarshal nested ObjectDetails string: %v", err)
|
||||
}
|
||||
details = append(details, d)
|
||||
}
|
||||
got = details
|
||||
} else {
|
||||
if resultString == "null" {
|
||||
got = nil
|
||||
} else {
|
||||
var tables []tableInfo
|
||||
if err := json.Unmarshal([]byte(resultString), &tables); err != nil {
|
||||
t.Fatalf("failed to unmarshal outer JSON array into []tableInfo: %v", err)
|
||||
}
|
||||
var details []objectDetails
|
||||
for _, table := range tables {
|
||||
var d objectDetails
|
||||
if err := json.Unmarshal([]byte(table.ObjectDetails), &d); err != nil {
|
||||
t.Fatalf("failed to unmarshal nested ObjectDetails string: %v", err)
|
||||
}
|
||||
details = append(details, d)
|
||||
}
|
||||
got = details
|
||||
}
|
||||
}
|
||||
|
||||
opts := []cmp.Option{
|
||||
cmpopts.SortSlices(func(a, b objectDetails) bool { return a.ObjectName < b.ObjectName }),
|
||||
cmpopts.SortSlices(func(a, b column) bool { return a.ColumnName < b.ColumnName }),
|
||||
cmpopts.SortSlices(func(a, b map[string]any) bool { return a["name"].(string) < b["name"].(string) }),
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tc.want, got, opts...); diff != "" {
|
||||
t.Errorf("Unexpected result: got %#v, want: %#v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,11 +15,17 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -31,13 +37,14 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
PostgresSourceKind = "postgres"
|
||||
PostgresToolKind = "postgres-sql"
|
||||
PostgresDatabase = os.Getenv("POSTGRES_DATABASE")
|
||||
PostgresHost = os.Getenv("POSTGRES_HOST")
|
||||
PostgresPort = os.Getenv("POSTGRES_PORT")
|
||||
PostgresUser = os.Getenv("POSTGRES_USER")
|
||||
PostgresPass = os.Getenv("POSTGRES_PASS")
|
||||
PostgresSourceKind = "postgres"
|
||||
PostgresToolKind = "postgres-sql"
|
||||
PostgresListTablesToolKind = "postgres-list-tables"
|
||||
PostgresDatabase = os.Getenv("POSTGRES_DATABASE")
|
||||
PostgresHost = os.Getenv("POSTGRES_HOST")
|
||||
PostgresPort = os.Getenv("POSTGRES_PORT")
|
||||
PostgresUser = os.Getenv("POSTGRES_USER")
|
||||
PostgresPass = os.Getenv("POSTGRES_PASS")
|
||||
)
|
||||
|
||||
func getPostgresVars(t *testing.T) map[string]any {
|
||||
@@ -64,6 +71,20 @@ func getPostgresVars(t *testing.T) map[string]any {
|
||||
}
|
||||
}
|
||||
|
||||
func addPrebuiltToolConfig(t *testing.T, config map[string]any) map[string]any {
|
||||
tools, ok := config["tools"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("unable to get tools from config")
|
||||
}
|
||||
tools["list_tables"] = map[string]any{
|
||||
"kind": PostgresListTablesToolKind,
|
||||
"source": "my-instance",
|
||||
"description": "Lists tables in the database.",
|
||||
}
|
||||
config["tools"] = tools
|
||||
return config
|
||||
}
|
||||
|
||||
// Copied over from postgres.go
|
||||
func initPostgresConnectionPool(host, port, user, pass, dbname string) (*pgxpool.Pool, error) {
|
||||
// urlExample := "postgres:dd//username:password@localhost:5432/database_name"
|
||||
@@ -114,6 +135,8 @@ func TestPostgres(t *testing.T) {
|
||||
tmplSelectCombined, tmplSelectFilterCombined := tests.GetPostgresSQLTmplToolStatement()
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, PostgresToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||
|
||||
toolsFile = addPrebuiltToolConfig(t, toolsFile)
|
||||
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
t.Fatalf("command initialization returned an error: %s", err)
|
||||
@@ -137,4 +160,165 @@ func TestPostgres(t *testing.T) {
|
||||
tests.RunMCPToolCallMethod(t, mcpMyFailToolWant, mcpSelect1Want)
|
||||
tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam)
|
||||
|
||||
// Run specific Postgres tool tests
|
||||
runPostgresListTablesTest(t, tableNameParam, tableNameAuth)
|
||||
}
|
||||
|
||||
func runPostgresListTablesTest(t *testing.T, tableNameParam, tableNameAuth string) {
|
||||
// TableNameParam columns to construct want
|
||||
paramTableColumns := fmt.Sprintf(`[
|
||||
{"data_type": "integer", "column_name": "id", "column_default": "nextval('%s_id_seq'::regclass)", "is_not_nullable": true, "ordinal_position": 1, "column_comment": null},
|
||||
{"data_type": "text", "column_name": "name", "column_default": null, "is_not_nullable": false, "ordinal_position": 2, "column_comment": null}
|
||||
]`, tableNameParam)
|
||||
|
||||
// TableNameAuth columns to construct want
|
||||
authTableColumns := fmt.Sprintf(`[
|
||||
{"data_type": "integer", "column_name": "id", "column_default": "nextval('%s_id_seq'::regclass)", "is_not_nullable": true, "ordinal_position": 1, "column_comment": null},
|
||||
{"data_type": "text", "column_name": "name", "column_default": null, "is_not_nullable": false, "ordinal_position": 2, "column_comment": null},
|
||||
{"data_type": "text", "column_name": "email", "column_default": null, "is_not_nullable": false, "ordinal_position": 3, "column_comment": null}
|
||||
]`, tableNameAuth)
|
||||
|
||||
const (
|
||||
// Template to construct detailed output want
|
||||
detailedObjectTemplate = `{
|
||||
"object_name": "%[1]s", "schema_name": "public",
|
||||
"object_details": {
|
||||
"owner": "%[3]s", "comment": null,
|
||||
"indexes": [{"is_primary": true, "is_unique": true, "index_name": "%[1]s_pkey", "index_method": "btree", "index_columns": ["id"], "index_definition": "CREATE UNIQUE INDEX %[1]s_pkey ON public.%[1]s USING btree (id)"}],
|
||||
"triggers": [], "columns": %[2]s, "object_name": "%[1]s", "object_type": "TABLE", "schema_name": "public",
|
||||
"constraints": [{"constraint_name": "%[1]s_pkey", "constraint_type": "PRIMARY KEY", "constraint_columns": ["id"], "constraint_definition": "PRIMARY KEY (id)", "foreign_key_referenced_table": null, "foreign_key_referenced_columns": null}]
|
||||
}
|
||||
}`
|
||||
|
||||
// Template to construct simple output want
|
||||
simpleObjectTemplate = `{"object_name":"%s", "schema_name":"public", "object_details":{"name":"%s"}}`
|
||||
)
|
||||
|
||||
// Helper to build json for detailed want
|
||||
getDetailedWant := func(tableName, columnJSON string) string {
|
||||
return fmt.Sprintf(detailedObjectTemplate, tableName, columnJSON, PostgresUser)
|
||||
}
|
||||
|
||||
// Helper to build template for simple want
|
||||
getSimpleWant := func(tableName string) string {
|
||||
return fmt.Sprintf(simpleObjectTemplate, tableName, tableName)
|
||||
}
|
||||
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
api string
|
||||
requestBody io.Reader
|
||||
wantStatusCode int
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "invoke list_tables detailed output",
|
||||
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"table_names": "%s"}`,tableNameAuth))),
|
||||
wantStatusCode: http.StatusOK,
|
||||
want: fmt.Sprintf("[%s]", getDetailedWant(tableNameAuth, authTableColumns)),
|
||||
},
|
||||
{
|
||||
name: "invoke list_tables simple output",
|
||||
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"table_names": "%s", "output_format": "simple"}`, tableNameAuth))),
|
||||
wantStatusCode: http.StatusOK,
|
||||
want: fmt.Sprintf("[%s]", getSimpleWant(tableNameAuth)),
|
||||
},
|
||||
{
|
||||
name: "invoke list_tables with invalid output format",
|
||||
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
||||
requestBody: bytes.NewBuffer([]byte(`{"table_names": "", "output_format": "abcd"}`)),
|
||||
wantStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "invoke list_tables with malformed table_names parameter",
|
||||
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
||||
requestBody: bytes.NewBuffer([]byte(`{"table_names": 12345, "output_format": "detailed"}`)),
|
||||
wantStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "invoke list_tables with multiple table names",
|
||||
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"table_names": "%s,%s"}`, tableNameParam, tableNameAuth))),
|
||||
wantStatusCode: http.StatusOK,
|
||||
want: fmt.Sprintf("[%s,%s]", getDetailedWant(tableNameAuth, authTableColumns), getDetailedWant(tableNameParam, paramTableColumns)),
|
||||
},
|
||||
{
|
||||
name: "invoke list_tables with non-existent table",
|
||||
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
||||
requestBody: bytes.NewBuffer([]byte(`{"table_names": "non_existent_table"}`)),
|
||||
wantStatusCode: http.StatusOK,
|
||||
want: `null`,
|
||||
},
|
||||
{
|
||||
name: "invoke list_tables with one existing and one non-existent table",
|
||||
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"table_names": "%s,non_existent_table"}`, tableNameParam))),
|
||||
wantStatusCode: http.StatusOK,
|
||||
want: fmt.Sprintf("[%s]", getDetailedWant(tableNameParam, paramTableColumns)),
|
||||
},
|
||||
}
|
||||
for _, tc := range invokeTcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create request: %s", err)
|
||||
}
|
||||
req.Header.Add("Content-type", "application/json")
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != tc.wantStatusCode {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
if tc.wantStatusCode == http.StatusOK {
|
||||
var bodyWrapper map[string]json.RawMessage
|
||||
respBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("error reading response body: %s", err)
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(respBytes, &bodyWrapper); err != nil {
|
||||
t.Fatalf("error parsing response wrapper: %s, body: %s", err, string(respBytes))
|
||||
}
|
||||
|
||||
resultJSON, ok := bodyWrapper["result"]
|
||||
if !ok {
|
||||
t.Fatal("unable to find 'result' in response body")
|
||||
}
|
||||
|
||||
var resultString string
|
||||
if err := json.Unmarshal(resultJSON, &resultString); err != nil {
|
||||
t.Fatalf("'result' is not a JSON-encoded string: %s", err)
|
||||
}
|
||||
|
||||
var got, want []any
|
||||
|
||||
if err := json.Unmarshal([]byte(resultString), &got); err != nil {
|
||||
t.Fatalf("failed to unmarshal actual result string: %v", err)
|
||||
}
|
||||
if err := json.Unmarshal([]byte(tc.want), &want); err != nil {
|
||||
t.Fatalf("failed to unmarshal expected want string: %v", err)
|
||||
}
|
||||
|
||||
sort.SliceStable(got, func(i, j int) bool {
|
||||
return fmt.Sprintf("%v", got[i]) < fmt.Sprintf("%v", got[j])
|
||||
})
|
||||
sort.SliceStable(want, func(i, j int) bool {
|
||||
return fmt.Sprintf("%v", want[i]) < fmt.Sprintf("%v", want[j])
|
||||
})
|
||||
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("Unexpected result: got %#v, want: %#v", got, want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user