Compare commits

...

8 Commits

Author SHA1 Message Date
Wenxin Du
e4f60e5633 fix(embeddingModel): add embedding model to MCP handler (#2310)
- Add embedding model to mcp handlers
- Add integration tests
2026-01-21 00:20:11 +00:00
vaibhavba-google
d7af21bdde tests(cloudhealthcare): use t.Cleanup() instead of defer (#2332)
## Description

Use t.Cleanup() to register cleanup of FHIR and DICOM stores immediately
after creation. This fixes the uncleaned FHIR/DICOM stores that remain
in the project(In the earlier implementation, teardown does not get
triggered if the test failed).

🛠️ Fixes #1986

---------

Co-authored-by: Yuan Teoh <yuanteoh@google.com>
Co-authored-by: Yuan Teoh <45984206+Yuan325@users.noreply.github.com>
2026-01-20 14:58:33 -08:00
Yuan Teoh
adc9589766 feat: add new user-agent-metadata flag (#2302)
## Description

Add a new `--user-agent-metadata` flag that allows user to append
additional user agent metadata. The flag takes in []string and will
concatenate it with `.`.

```
go run . --user-agent-metadata=foo
```
 produces `0.25.0+dev.darwin.arm64+foo` user agent string

```
go run . --user-agent-metadata=foo,bar
```
produces `0.25.0+dev.darwin.arm64+foo+bar` user agent string

## PR Checklist

> Thank you for opening a Pull Request! Before submitting your PR, there
are a
> few things you can do to make sure it goes smoothly:

- [x] Make sure you reviewed

[CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md)
- [x] Make sure to open an issue as a

[bug/issue](https://github.com/googleapis/genai-toolbox/issues/new/choose)
  before writing your code! That way we can discuss the change, evaluate
  designs, and agree on the general idea
- [x] Ensure the tests and linter pass
- [x] Code coverage does not decrease (if any source code was changed)
- [x] Appropriate docs were updated (if necessary)
- [x] Make sure to add `!` if this involve a breaking change

🛠️ Fixes #<issue_number_goes_here>
2026-01-20 19:23:50 +00:00
Yuan Teoh
c25a2330fe fix: add check for client authorization before retrieving token string (#2327)
Previous refactoring (#2273) accidentally removed the authorization
checks prior to token retrieval. This issue went unnoticed because the
integration tests were disabled. I am re-adding the necessary checks.
2026-01-20 18:57:11 +00:00
Juexin Wang
6e09b08c6a docs(tools/cloudgda): update cloud gda datasource references note (#2326)
## Description

Update the GDA source document to clarify that only `AlloyDbReference`,
`SpannerReference`, and `CloudSqlReference` are supported.

## PR Checklist

> Thank you for opening a Pull Request! Before submitting your PR, there
are a
> few things you can do to make sure it goes smoothly:

- [x] Make sure you reviewed

[CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md)
- [x] Make sure to open an issue as a

[bug/issue](https://github.com/googleapis/genai-toolbox/issues/new/choose)
  before writing your code! That way we can discuss the change, evaluate
  designs, and agree on the general idea
- [x] Ensure the tests and linter pass
- [x] Code coverage does not decrease (if any source code was changed)
- [x] Appropriate docs were updated (if necessary)
- [x] Make sure to add `!` if this involve a breaking change

🛠️ Fixes #2324

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-16 18:57:06 +00:00
Wenxin Du
1f15a111f1 docs: fix redis array sample (#2301)
The Redis tool code sample is missing the "items" field for the array
parameter, causing confusion.
fix: https://github.com/googleapis/genai-toolbox/issues/2293
2026-01-16 17:08:47 +00:00
Twisha Bansal
dfddeb528d docs: update cloud run connection docs (#2320)
## Description

Partially fixes
https://github.com/googleapis/mcp-toolbox-sdk-python/issues/496

## PR Checklist

> Thank you for opening a Pull Request! Before submitting your PR, there
are a
> few things you can do to make sure it goes smoothly:

- [x] Make sure you reviewed

[CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md)
- [ ] Make sure to open an issue as a

[bug/issue](https://github.com/googleapis/genai-toolbox/issues/new/choose)
  before writing your code! That way we can discuss the change, evaluate
  designs, and agree on the general idea
- [ ] Ensure the tests and linter pass
- [ ] Code coverage does not decrease (if any source code was changed)
- [ ] Appropriate docs were updated (if necessary)
- [ ] Make sure to add `!` if this involve a breaking change

🛠️ Fixes #<issue_number_goes_here>
2026-01-16 10:05:05 +05:30
Eric Wang
00c3e6d8cb feat(prebuilt/cloud-sql): Add restore backup tool for cloud sql (#2171)
## Description

This pull request adds a new tool, cloud-sql-restore-backup, which
enables restoring a backup onto a Cloud SQL instance from the toolbox
using the Cloud SQL Admin API. The tool supports restoring standard,
project level, and BackupDR backups.

Tested:
<img width="3758" height="532" alt="image"
src="https://github.com/user-attachments/assets/d1d61af7-d96e-417c-898c-65b876de4c5e"
/>


## PR Checklist

> Thank you for opening a Pull Request! Before submitting your PR, there
are a
> few things you can do to make sure it goes smoothly:

- [x] Make sure you reviewed

[CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md)
- [x] Make sure to open an issue as a

[bug/issue](https://github.com/googleapis/genai-toolbox/issues/new/choose)
  before writing your code! That way we can discuss the change, evaluate
  designs, and agree on the general idea
- [x] Ensure the tests and linter pass
- [x] Code coverage does not decrease (if any source code was changed)
- [x] Appropriate docs were updated (if necessary)
- [x] Make sure to add `!` if this involve a breaking change

🛠️ Fixes #2170

Co-authored-by: Averi Kitsch <akitsch@google.com>
2026-01-16 00:16:46 +00:00
46 changed files with 1108 additions and 76 deletions

View File

@@ -87,7 +87,7 @@ steps:
- "CLOUD_SQL_POSTGRES_REGION=$_REGION" - "CLOUD_SQL_POSTGRES_REGION=$_REGION"
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL" - "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
secretEnv: secretEnv:
["CLOUD_SQL_POSTGRES_USER", "CLOUD_SQL_POSTGRES_PASS", "CLIENT_ID"] ["CLOUD_SQL_POSTGRES_USER", "CLOUD_SQL_POSTGRES_PASS", "CLIENT_ID", "API_KEY"]
volumes: volumes:
- name: "go" - name: "go"
path: "/gopath" path: "/gopath"
@@ -134,7 +134,7 @@ steps:
- "ALLOYDB_POSTGRES_DATABASE=$_DATABASE_NAME" - "ALLOYDB_POSTGRES_DATABASE=$_DATABASE_NAME"
- "ALLOYDB_POSTGRES_REGION=$_REGION" - "ALLOYDB_POSTGRES_REGION=$_REGION"
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL" - "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
secretEnv: ["ALLOYDB_POSTGRES_USER", "ALLOYDB_POSTGRES_PASS", "CLIENT_ID"] secretEnv: ["ALLOYDB_POSTGRES_USER", "ALLOYDB_POSTGRES_PASS", "CLIENT_ID", "API_KEY"]
volumes: volumes:
- name: "go" - name: "go"
path: "/gopath" path: "/gopath"
@@ -293,7 +293,7 @@ steps:
.ci/test_with_coverage.sh \ .ci/test_with_coverage.sh \
"Cloud Healthcare API" \ "Cloud Healthcare API" \
cloudhealthcare \ cloudhealthcare \
cloudhealthcare || echo "Integration tests failed." cloudhealthcare
- id: "postgres" - id: "postgres"
name: golang:1 name: golang:1
@@ -305,7 +305,7 @@ steps:
- "POSTGRES_HOST=$_POSTGRES_HOST" - "POSTGRES_HOST=$_POSTGRES_HOST"
- "POSTGRES_PORT=$_POSTGRES_PORT" - "POSTGRES_PORT=$_POSTGRES_PORT"
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL" - "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
secretEnv: ["POSTGRES_USER", "POSTGRES_PASS", "CLIENT_ID"] secretEnv: ["POSTGRES_USER", "POSTGRES_PASS", "CLIENT_ID", "API_KEY"]
volumes: volumes:
- name: "go" - name: "go"
path: "/gopath" path: "/gopath"
@@ -964,6 +964,13 @@ steps:
availableSecrets: availableSecrets:
secretManager: secretManager:
# Common secrets
- versionName: projects/$PROJECT_ID/secrets/client_id/versions/latest
env: CLIENT_ID
- versionName: projects/$PROJECT_ID/secrets/api_key/versions/latest
env: API_KEY
# Resource-specific secrets
- versionName: projects/$PROJECT_ID/secrets/cloud_sql_pg_user/versions/latest - versionName: projects/$PROJECT_ID/secrets/cloud_sql_pg_user/versions/latest
env: CLOUD_SQL_POSTGRES_USER env: CLOUD_SQL_POSTGRES_USER
- versionName: projects/$PROJECT_ID/secrets/cloud_sql_pg_pass/versions/latest - versionName: projects/$PROJECT_ID/secrets/cloud_sql_pg_pass/versions/latest
@@ -980,8 +987,6 @@ availableSecrets:
env: POSTGRES_USER env: POSTGRES_USER
- versionName: projects/$PROJECT_ID/secrets/postgres_pass/versions/latest - versionName: projects/$PROJECT_ID/secrets/postgres_pass/versions/latest
env: POSTGRES_PASS env: POSTGRES_PASS
- versionName: projects/$PROJECT_ID/secrets/client_id/versions/latest
env: CLIENT_ID
- versionName: projects/$PROJECT_ID/secrets/neo4j_user/versions/latest - versionName: projects/$PROJECT_ID/secrets/neo4j_user/versions/latest
env: NEO4J_USER env: NEO4J_USER
- versionName: projects/$PROJECT_ID/secrets/neo4j_pass/versions/latest - versionName: projects/$PROJECT_ID/secrets/neo4j_pass/versions/latest

View File

@@ -98,6 +98,7 @@ import (
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlgetinstances" _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlgetinstances"
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqllistdatabases" _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqllistdatabases"
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqllistinstances" _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqllistinstances"
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlrestorebackup"
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlwaitforoperation" _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlwaitforoperation"
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsqlmssql/cloudsqlmssqlcreateinstance" _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsqlmssql/cloudsqlmssqlcreateinstance"
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsqlmysql/cloudsqlmysqlcreateinstance" _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsqlmysql/cloudsqlmysqlcreateinstance"
@@ -385,6 +386,7 @@ func NewCommand(opts ...Option) *Command {
// TODO: Insecure by default. Might consider updating this for v1.0.0 // TODO: Insecure by default. Might consider updating this for v1.0.0
flags.StringSliceVar(&cmd.cfg.AllowedOrigins, "allowed-origins", []string{"*"}, "Specifies a list of origins permitted to access this server. Defaults to '*'.") flags.StringSliceVar(&cmd.cfg.AllowedOrigins, "allowed-origins", []string{"*"}, "Specifies a list of origins permitted to access this server. Defaults to '*'.")
flags.StringSliceVar(&cmd.cfg.AllowedHosts, "allowed-hosts", []string{"*"}, "Specifies a list of hosts permitted to access this server. Defaults to '*'.") flags.StringSliceVar(&cmd.cfg.AllowedHosts, "allowed-hosts", []string{"*"}, "Specifies a list of hosts permitted to access this server. Defaults to '*'.")
flags.StringSliceVar(&cmd.cfg.UserAgentMetadata, "user-agent-metadata", []string{}, "Appends additional metadata to the User-Agent.")
// wrap RunE command so that we have access to original Command object // wrap RunE command so that we have access to original Command object
cmd.RunE = func(*cobra.Command, []string) error { return run(cmd) } cmd.RunE = func(*cobra.Command, []string) error { return run(cmd) }

View File

@@ -70,6 +70,9 @@ func withDefaults(c server.ServerConfig) server.ServerConfig {
if c.AllowedHosts == nil { if c.AllowedHosts == nil {
c.AllowedHosts = []string{"*"} c.AllowedHosts = []string{"*"}
} }
if c.UserAgentMetadata == nil {
c.UserAgentMetadata = []string{}
}
return c return c
} }
@@ -230,6 +233,13 @@ func TestServerConfigFlags(t *testing.T) {
AllowedHosts: []string{"http://foo.com", "http://bar.com"}, AllowedHosts: []string{"http://foo.com", "http://bar.com"},
}), }),
}, },
{
desc: "user agent metadata",
args: []string{"--user-agent-metadata", "foo,bar"},
want: withDefaults(server.ServerConfig{
UserAgentMetadata: []string{"foo", "bar"},
}),
},
} }
for _, tc := range tcs { for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) { t.Run(tc.desc, func(t *testing.T) {
@@ -1493,7 +1503,7 @@ func TestPrebuiltTools(t *testing.T) {
wantToolset: server.ToolsetConfigs{ wantToolset: server.ToolsetConfigs{
"cloud_sql_postgres_admin_tools": tools.ToolsetConfig{ "cloud_sql_postgres_admin_tools": tools.ToolsetConfig{
Name: "cloud_sql_postgres_admin_tools", Name: "cloud_sql_postgres_admin_tools",
ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "postgres_upgrade_precheck", "clone_instance", "create_backup"}, ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "postgres_upgrade_precheck", "clone_instance", "create_backup", "restore_backup"},
}, },
}, },
}, },
@@ -1503,7 +1513,7 @@ func TestPrebuiltTools(t *testing.T) {
wantToolset: server.ToolsetConfigs{ wantToolset: server.ToolsetConfigs{
"cloud_sql_mysql_admin_tools": tools.ToolsetConfig{ "cloud_sql_mysql_admin_tools": tools.ToolsetConfig{
Name: "cloud_sql_mysql_admin_tools", Name: "cloud_sql_mysql_admin_tools",
ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance", "create_backup"}, ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance", "create_backup", "restore_backup"},
}, },
}, },
}, },
@@ -1513,7 +1523,7 @@ func TestPrebuiltTools(t *testing.T) {
wantToolset: server.ToolsetConfigs{ wantToolset: server.ToolsetConfigs{
"cloud_sql_mssql_admin_tools": tools.ToolsetConfig{ "cloud_sql_mssql_admin_tools": tools.ToolsetConfig{
Name: "cloud_sql_mssql_admin_tools", Name: "cloud_sql_mssql_admin_tools",
ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance", "create_backup"}, ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance", "create_backup", "restore_backup"},
}, },
}, },
}, },

View File

@@ -54,6 +54,7 @@ instance, database and users:
* `create_instance` * `create_instance`
* `create_user` * `create_user`
* `clone_instance` * `clone_instance`
* `restore_backup`
## Install MCP Toolbox ## Install MCP Toolbox
@@ -301,6 +302,7 @@ instances and interacting with your database:
* **wait_for_operation**: Waits for a Cloud SQL operation to complete. * **wait_for_operation**: Waits for a Cloud SQL operation to complete.
* **clone_instance**: Creates a clone of an existing Cloud SQL for SQL Server instance. * **clone_instance**: Creates a clone of an existing Cloud SQL for SQL Server instance.
* **create_backup**: Creates a backup on a Cloud SQL instance. * **create_backup**: Creates a backup on a Cloud SQL instance.
* **restore_backup**: Restores a backup of a Cloud SQL instance.
{{< notice note >}} {{< notice note >}}
Prebuilt tools are pre-1.0, so expect some tool changes between versions. LLMs Prebuilt tools are pre-1.0, so expect some tool changes between versions. LLMs

View File

@@ -54,6 +54,7 @@ database and users:
* `create_instance` * `create_instance`
* `create_user` * `create_user`
* `clone_instance` * `clone_instance`
* `restore_backup`
## Install MCP Toolbox ## Install MCP Toolbox
@@ -301,6 +302,7 @@ instances and interacting with your database:
* **wait_for_operation**: Waits for a Cloud SQL operation to complete. * **wait_for_operation**: Waits for a Cloud SQL operation to complete.
* **clone_instance**: Creates a clone of an existing Cloud SQL for MySQL instance. * **clone_instance**: Creates a clone of an existing Cloud SQL for MySQL instance.
* **create_backup**: Creates a backup on a Cloud SQL instance. * **create_backup**: Creates a backup on a Cloud SQL instance.
* **restore_backup**: Restores a backup of a Cloud SQL instance.
{{< notice note >}} {{< notice note >}}
Prebuilt tools are pre-1.0, so expect some tool changes between versions. LLMs Prebuilt tools are pre-1.0, so expect some tool changes between versions. LLMs

View File

@@ -54,6 +54,7 @@ instance, database and users:
* `create_instance` * `create_instance`
* `create_user` * `create_user`
* `clone_instance` * `clone_instance`
* `restore_backup`
## Install MCP Toolbox ## Install MCP Toolbox
@@ -301,6 +302,7 @@ instances and interacting with your database:
* **wait_for_operation**: Waits for a Cloud SQL operation to complete. * **wait_for_operation**: Waits for a Cloud SQL operation to complete.
* **clone_instance**: Creates a clone of an existing Cloud SQL for PostgreSQL instance. * **clone_instance**: Creates a clone of an existing Cloud SQL for PostgreSQL instance.
* **create_backup**: Creates a backup on a Cloud SQL instance. * **create_backup**: Creates a backup on a Cloud SQL instance.
* **restore_backup**: Restores a backup of a Cloud SQL instance.
{{< notice note >}} {{< notice note >}}
Prebuilt tools are pre-1.0, so expect some tool changes between versions. LLMs Prebuilt tools are pre-1.0, so expect some tool changes between versions. LLMs

View File

@@ -207,6 +207,7 @@ You can connect to Toolbox Cloud Run instances directly through the SDK.
{{< tab header="Python" lang="python" >}} {{< tab header="Python" lang="python" >}}
import asyncio import asyncio
from toolbox_core import ToolboxClient, auth_methods from toolbox_core import ToolboxClient, auth_methods
from toolbox_core.protocol import Protocol
# Replace with the Cloud Run service URL generated in the previous step # Replace with the Cloud Run service URL generated in the previous step
URL = "https://cloud-run-url.app" URL = "https://cloud-run-url.app"
@@ -217,6 +218,7 @@ async def main():
async with ToolboxClient( async with ToolboxClient(
URL, URL,
client_headers={"Authorization": auth_token_provider}, client_headers={"Authorization": auth_token_provider},
protocol=Protocol.TOOLBOX,
) as toolbox: ) as toolbox:
toolset = await toolbox.load_toolset() toolset = await toolbox.load_toolset()
# ... # ...
@@ -281,3 +283,5 @@ contain the specific error message needed to diagnose the problem.
Manager, it means the Toolbox service account is missing permissions. Manager, it means the Toolbox service account is missing permissions.
- Ensure the `toolbox-identity` service account has the **Secret Manager - Ensure the `toolbox-identity` service account has the **Secret Manager
Secret Accessor** (`roles/secretmanager.secretAccessor`) IAM role. Secret Accessor** (`roles/secretmanager.secretAccessor`) IAM role.
- **Cloud Run Connections via IAP:** Currently we do not support Cloud Run connections via [IAP](https://docs.cloud.google.com/iap/docs/concepts-overview). Please disable IAP if you are using it.

View File

@@ -27,6 +27,7 @@ description: >
| | `--ui` | Launches the Toolbox UI web server. | | | | `--ui` | Launches the Toolbox UI web server. | |
| | `--allowed-origins` | Specifies a list of origins permitted to access this server for CORs access. | `*` | | | `--allowed-origins` | Specifies a list of origins permitted to access this server for CORs access. | `*` |
| | `--allowed-hosts` | Specifies a list of hosts permitted to access this server to prevent DNS rebinding attacks. | `*` | | | `--allowed-hosts` | Specifies a list of hosts permitted to access this server to prevent DNS rebinding attacks. | `*` |
| | `--user-agent-extra` | Appends additional metadata to the User-Agent. | |
| `-v` | `--version` | version for toolbox | | | `-v` | `--version` | version for toolbox | |
## Examples ## Examples

View File

@@ -194,6 +194,7 @@ See [Usage Examples](../reference/cli.md#examples).
* `create_instance` * `create_instance`
* `create_user` * `create_user`
* `clone_instance` * `clone_instance`
* `restore_backup`
* **Tools:** * **Tools:**
* `create_instance`: Creates a new Cloud SQL for MySQL instance. * `create_instance`: Creates a new Cloud SQL for MySQL instance.
@@ -205,6 +206,7 @@ See [Usage Examples](../reference/cli.md#examples).
* `wait_for_operation`: Waits for a Cloud SQL operation to complete. * `wait_for_operation`: Waits for a Cloud SQL operation to complete.
* `clone_instance`: Creates a clone for an existing Cloud SQL for MySQL instance. * `clone_instance`: Creates a clone for an existing Cloud SQL for MySQL instance.
* `create_backup`: Creates a backup on a Cloud SQL instance. * `create_backup`: Creates a backup on a Cloud SQL instance.
* `restore_backup`: Restores a backup of a Cloud SQL instance.
## Cloud SQL for PostgreSQL ## Cloud SQL for PostgreSQL
@@ -284,6 +286,7 @@ See [Usage Examples](../reference/cli.md#examples).
* `create_instance` * `create_instance`
* `create_user` * `create_user`
* `clone_instance` * `clone_instance`
* `restore_backup`
* **Tools:** * **Tools:**
* `create_instance`: Creates a new Cloud SQL for PostgreSQL instance. * `create_instance`: Creates a new Cloud SQL for PostgreSQL instance.
* `get_instance`: Gets information about a Cloud SQL instance. * `get_instance`: Gets information about a Cloud SQL instance.
@@ -294,6 +297,7 @@ See [Usage Examples](../reference/cli.md#examples).
* `wait_for_operation`: Waits for a Cloud SQL operation to complete. * `wait_for_operation`: Waits for a Cloud SQL operation to complete.
* `clone_instance`: Creates a clone for an existing Cloud SQL for PostgreSQL instance. * `clone_instance`: Creates a clone for an existing Cloud SQL for PostgreSQL instance.
* `create_backup`: Creates a backup on a Cloud SQL instance. * `create_backup`: Creates a backup on a Cloud SQL instance.
* `restore_backup`: Restores a backup of a Cloud SQL instance.
## Cloud SQL for SQL Server ## Cloud SQL for SQL Server
@@ -347,6 +351,7 @@ See [Usage Examples](../reference/cli.md#examples).
* `create_instance` * `create_instance`
* `create_user` * `create_user`
* `clone_instance` * `clone_instance`
* `restore_backup`
* **Tools:** * **Tools:**
* `create_instance`: Creates a new Cloud SQL for SQL Server instance. * `create_instance`: Creates a new Cloud SQL for SQL Server instance.
* `get_instance`: Gets information about a Cloud SQL instance. * `get_instance`: Gets information about a Cloud SQL instance.
@@ -357,6 +362,7 @@ See [Usage Examples](../reference/cli.md#examples).
* `wait_for_operation`: Waits for a Cloud SQL operation to complete. * `wait_for_operation`: Waits for a Cloud SQL operation to complete.
* `clone_instance`: Creates a clone for an existing Cloud SQL for SQL Server instance. * `clone_instance`: Creates a clone for an existing Cloud SQL for SQL Server instance.
* `create_backup`: Creates a backup on a Cloud SQL instance. * `create_backup`: Creates a backup on a Cloud SQL instance.
* `restore_backup`: Restores a backup of a Cloud SQL instance.
## Dataplex ## Dataplex

View File

@@ -12,6 +12,9 @@ aliases:
The `cloud-gemini-data-analytics-query` tool allows you to send natural language questions to the Gemini Data Analytics API and receive structured responses containing SQL queries, natural language answers, and explanations. For details on defining data agent context for database data sources, see the official [documentation](https://docs.cloud.google.com/gemini/docs/conversational-analytics-api/data-agent-authored-context-databases). The `cloud-gemini-data-analytics-query` tool allows you to send natural language questions to the Gemini Data Analytics API and receive structured responses containing SQL queries, natural language answers, and explanations. For details on defining data agent context for database data sources, see the official [documentation](https://docs.cloud.google.com/gemini/docs/conversational-analytics-api/data-agent-authored-context-databases).
> [!NOTE]
> Only `alloydb`, `spannerReference`, and `cloudSqlReference` are supported as [datasource references](https://clouddocs.devsite.corp.google.com/gemini/docs/conversational-analytics-api/reference/rest/v1beta/projects.locations.dataAgents#DatasourceReferences).
## Example ## Example
```yaml ```yaml

View File

@@ -0,0 +1,53 @@
---
title: cloud-sql-restore-backup
type: docs
weight: 10
description: "Restores a backup of a Cloud SQL instance."
---
The `cloud-sql-restore-backup` tool restores a backup on a Cloud SQL instance using the Cloud SQL Admin API.
{{< notice info dd>}}
This tool uses a `source` of kind `cloud-sql-admin`.
{{< /notice >}}
## Examples
Basic backup restore
```yaml
tools:
backup-restore-basic:
kind: cloud-sql-restore-backup
source: cloud-sql-admin-source
description: "Restores a backup onto the given Cloud SQL instance."
```
## Reference
### Tool Configuration
| **field** | **type** | **required** | **description** |
| -------------- | :------: | :----------: | ------------------------------------------------ |
| kind | string | true | Must be "cloud-sql-restore-backup". |
| source | string | true | The name of the `cloud-sql-admin` source to use. |
| description | string | false | A description of the tool. |
### Tool Inputs
| **parameter** | **type** | **required** | **description** |
| ------------------| :------: | :----------: | -----------------------------------------------------------------------------|
| target_project | string | true | The project ID of the instance to restore the backup onto. |
| target_instance | string | true | The instance to restore the backup onto. Does not include the project ID. |
| backup_id | string | true | The identifier of the backup being restored. |
| source_project | string | false | (Optional) The project ID of the instance that the backup belongs to. |
| source_instance | string | false | (Optional) Cloud SQL instance ID of the instance that the backup belongs to. |
## Usage Notes
- The `backup_id` field can be a BackupRun ID (which will be an int64), backup name, or BackupDR backup name.
- If the `backup_id` field contains a BackupRun ID (i.e. an int64), the optional fields `source_project` and `source_instance` must also be provided.
## See Also
- [Cloud SQL Admin API documentation](https://cloud.google.com/sql/docs/mysql/admin-api)
- [Toolbox Cloud SQL tools documentation](../cloudsql)
- [Cloud SQL Restore API documentation](https://cloud.google.com/sql/docs/mysql/backup-recovery/restoring)

View File

@@ -30,6 +30,10 @@ following config for example:
- name: userNames - name: userNames
type: array type: array
description: The user names to be set. description: The user names to be set.
items:
name: userName # the item name doesn't matter but it has to exist
type: string
description: username
``` ```
If the input is an array of strings `["Alice", "Sid", "Bob"]`, The final command If the input is an array of strings `["Alice", "Sid", "Bob"]`, The final command

View File

@@ -46,6 +46,9 @@ tools:
create_backup: create_backup:
kind: cloud-sql-create-backup kind: cloud-sql-create-backup
source: cloud-sql-admin-source source: cloud-sql-admin-source
restore_backup:
kind: cloud-sql-restore-backup
source: cloud-sql-admin-source
toolsets: toolsets:
cloud_sql_mssql_admin_tools: cloud_sql_mssql_admin_tools:
@@ -58,3 +61,4 @@ toolsets:
- wait_for_operation - wait_for_operation
- clone_instance - clone_instance
- create_backup - create_backup
- restore_backup

View File

@@ -46,6 +46,9 @@ tools:
create_backup: create_backup:
kind: cloud-sql-create-backup kind: cloud-sql-create-backup
source: cloud-sql-admin-source source: cloud-sql-admin-source
restore_backup:
kind: cloud-sql-restore-backup
source: cloud-sql-admin-source
toolsets: toolsets:
cloud_sql_mysql_admin_tools: cloud_sql_mysql_admin_tools:
@@ -58,3 +61,4 @@ toolsets:
- wait_for_operation - wait_for_operation
- clone_instance - clone_instance
- create_backup - create_backup
- restore_backup

View File

@@ -49,6 +49,9 @@ tools:
create_backup: create_backup:
kind: cloud-sql-create-backup kind: cloud-sql-create-backup
source: cloud-sql-admin-source source: cloud-sql-admin-source
restore_backup:
kind: cloud-sql-restore-backup
source: cloud-sql-admin-source
toolsets: toolsets:
cloud_sql_postgres_admin_tools: cloud_sql_postgres_admin_tools:
@@ -62,3 +65,4 @@ toolsets:
- postgres_upgrade_precheck - postgres_upgrade_precheck
- clone_instance - clone_instance
- create_backup - create_backup
- restore_backup

View File

@@ -64,12 +64,14 @@ type ServerConfig struct {
Stdio bool Stdio bool
// DisableReload indicates if the user has disabled dynamic reloading for Toolbox. // DisableReload indicates if the user has disabled dynamic reloading for Toolbox.
DisableReload bool DisableReload bool
// UI indicates if Toolbox UI endpoints (/ui) are available // UI indicates if Toolbox UI endpoints (/ui) are available.
UI bool UI bool
// Specifies a list of origins permitted to access this server. // Specifies a list of origins permitted to access this server.
AllowedOrigins []string AllowedOrigins []string
// Specifies a list of hosts permitted to access this server // Specifies a list of hosts permitted to access this server.
AllowedHosts []string AllowedHosts []string
// UserAgentMetadata specifies additional metadata to append to the User-Agent string.
UserAgentMetadata []string
} }
type logFormat string type logFormat string

View File

@@ -183,6 +183,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
} }
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params)) logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
embeddingModels := resourceMgr.GetEmbeddingModelMap()
params, err = tool.EmbedParams(ctx, params, embeddingModels)
if err != nil {
err = fmt.Errorf("error embedding parameters: %w", err)
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
}
// run tool invocation and generate response. // run tool invocation and generate response.
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken) results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
if err != nil { if err != nil {

View File

@@ -183,6 +183,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
} }
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params)) logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
embeddingModels := resourceMgr.GetEmbeddingModelMap()
params, err = tool.EmbedParams(ctx, params, embeddingModels)
if err != nil {
err = fmt.Errorf("error embedding parameters: %w", err)
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
}
// run tool invocation and generate response. // run tool invocation and generate response.
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken) results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
if err != nil { if err != nil {

View File

@@ -176,6 +176,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
} }
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params)) logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
embeddingModels := resourceMgr.GetEmbeddingModelMap()
params, err = tool.EmbedParams(ctx, params, embeddingModels)
if err != nil {
err = fmt.Errorf("error embedding parameters: %w", err)
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
}
// run tool invocation and generate response. // run tool invocation and generate response.
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken) results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
if err != nil { if err != nil {

View File

@@ -176,6 +176,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
} }
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params)) logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
embeddingModels := resourceMgr.GetEmbeddingModelMap()
params, err = tool.EmbedParams(ctx, params, embeddingModels)
if err != nil {
err = fmt.Errorf("error embedding parameters: %w", err)
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
}
// run tool invocation and generate response. // run tool invocation and generate response.
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken) results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
if err != nil { if err != nil {

View File

@@ -64,7 +64,11 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
map[string]prompts.Promptset, map[string]prompts.Promptset,
error, error,
) { ) {
ctx = util.WithUserAgent(ctx, cfg.Version) metadataStr := cfg.Version
if len(cfg.UserAgentMetadata) > 0 {
metadataStr += "+" + strings.Join(cfg.UserAgentMetadata, "+")
}
ctx = util.WithUserAgent(ctx, metadataStr)
instrumentation, err := util.InstrumentationFromContext(ctx) instrumentation, err := util.InstrumentationFromContext(ctx)
if err != nil { if err != nil {
panic(err) panic(err)

View File

@@ -19,6 +19,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"regexp" "regexp"
"strconv"
"strings" "strings"
"text/template" "text/template"
"time" "time"
@@ -36,7 +37,10 @@ import (
const SourceKind string = "cloud-sql-admin" const SourceKind string = "cloud-sql-admin"
var targetLinkRegex = regexp.MustCompile(`/projects/([^/]+)/instances/([^/]+)/databases/([^/]+)`) var (
targetLinkRegex = regexp.MustCompile(`/projects/([^/]+)/instances/([^/]+)/databases/([^/]+)`)
backupDRRegex = regexp.MustCompile(`^projects/([^/]+)/locations/([^/]+)/backupVaults/([^/]+)/dataSources/([^/]+)/backups/([^/]+)$`)
)
// validate interface // validate interface
var _ sources.SourceConfig = Config{} var _ sources.SourceConfig = Config{}
@@ -374,6 +378,48 @@ func (s *Source) InsertBackupRun(ctx context.Context, project, instance, locatio
return resp, nil return resp, nil
} }
func (s *Source) RestoreBackup(ctx context.Context, targetProject, targetInstance, sourceProject, sourceInstance, backupID, accessToken string) (any, error) {
request := &sqladmin.InstancesRestoreBackupRequest{}
// There are 3 scenarios for the backup identifier:
// 1. The identifier is an int64 containing the timestamp of the BackupRun.
// This is used to restore standard backups, and the RestoreBackupContext
// field should be populated with the backup ID and source instance info.
// 2. The identifier is a string of the format
// 'projects/{project-id}/locations/{location}/backupVaults/{backupvault}/dataSources/{datasource}/backups/{backup-uid}'.
// This is used to restore BackupDR backups, and the BackupdrBackup field
// should be populated.
// 3. The identifer is a string of the format
// 'projects/{project-id}/backups/{backup-uid}'. In this case, the Backup
// field should be populated.
if backupRunID, err := strconv.ParseInt(backupID, 10, 64); err == nil {
if sourceProject == "" || targetInstance == "" {
return nil, fmt.Errorf("source project and instance are required when restoring via backup ID")
}
request.RestoreBackupContext = &sqladmin.RestoreBackupContext{
Project: sourceProject,
InstanceId: sourceInstance,
BackupRunId: backupRunID,
}
} else if backupDRRegex.MatchString(backupID) {
request.BackupdrBackup = backupID
} else {
request.Backup = backupID
}
service, err := s.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
resp, err := service.Instances.RestoreBackup(targetProject, targetInstance, request).Do()
if err != nil {
return nil, fmt.Errorf("error restoring backup: %w", err)
}
return resp, nil
}
func generateCloudSQLConnectionMessage(ctx context.Context, source *Source, logger log.Logger, opResponse map[string]any, connectionMessageTemplate string) (string, bool) { func generateCloudSQLConnectionMessage(ctx context.Context, source *Source, logger log.Logger, opResponse map[string]any, connectionMessageTemplate string) (string, bool) {
operationType, ok := opResponse["operationType"].(string) operationType, ok := opResponse["operationType"].(string)
if !ok || operationType != "CREATE_DATABASE" { if !ok || operationType != "CREATE_DATABASE" {

View File

@@ -103,10 +103,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if !ok { if !ok {
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", pageURLKey) return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", pageURLKey)
} }
var tokenStr string
tokenStr, err := accessToken.ParseBearerToken() if source.UseClientAuthorization() {
if err != nil { tokenStr, err = accessToken.ParseBearerToken()
return nil, fmt.Errorf("error parsing access token: %w", err) if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
} }
return source.FHIRFetchPage(ctx, url, tokenStr) return source.FHIRFetchPage(ctx, url, tokenStr)
} }

View File

@@ -131,9 +131,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", patientIDKey) return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", patientIDKey)
} }
tokenStr, err := accessToken.ParseBearerToken() var tokenStr string
if err != nil { if source.UseClientAuthorization() {
return nil, fmt.Errorf("error parsing access token: %w", err) tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
} }
var opts []googleapi.CallOption var opts []googleapi.CallOption

View File

@@ -161,9 +161,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, err return nil, err
} }
tokenStr, err := accessToken.ParseBearerToken() var tokenStr string
if err != nil { if source.UseClientAuthorization() {
return nil, fmt.Errorf("error parsing access token: %w", err) tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
} }
var summary bool var summary bool

View File

@@ -95,9 +95,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil { if err != nil {
return nil, err return nil, err
} }
tokenStr, err := accessToken.ParseBearerToken() var tokenStr string
if err != nil { if source.UseClientAuthorization() {
return nil, fmt.Errorf("error parsing access token: %w", err) tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
} }
return source.GetDataset(tokenStr) return source.GetDataset(tokenStr)
} }

View File

@@ -116,9 +116,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil { if err != nil {
return nil, err return nil, err
} }
tokenStr, err := accessToken.ParseBearerToken() var tokenStr string
if err != nil { if source.UseClientAuthorization() {
return nil, fmt.Errorf("error parsing access token: %w", err) tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
} }
return source.GetDICOMStore(storeID, tokenStr) return source.GetDICOMStore(storeID, tokenStr)
} }

View File

@@ -116,9 +116,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil { if err != nil {
return nil, err return nil, err
} }
tokenStr, err := accessToken.ParseBearerToken() var tokenStr string
if err != nil { if source.UseClientAuthorization() {
return nil, fmt.Errorf("error parsing access token: %w", err) tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
} }
return source.GetDICOMStoreMetrics(storeID, tokenStr) return source.GetDICOMStoreMetrics(storeID, tokenStr)
} }

View File

@@ -130,9 +130,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if !ok { if !ok {
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", idKey) return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", idKey)
} }
tokenStr, err := accessToken.ParseBearerToken() var tokenStr string
if err != nil { if source.UseClientAuthorization() {
return nil, fmt.Errorf("error parsing access token: %w", err) tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
} }
return source.GetFHIRResource(storeID, resType, resID, tokenStr) return source.GetFHIRResource(storeID, resType, resID, tokenStr)
} }

View File

@@ -116,9 +116,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil { if err != nil {
return nil, err return nil, err
} }
tokenStr, err := accessToken.ParseBearerToken() var tokenStr string
if err != nil { if source.UseClientAuthorization() {
return nil, fmt.Errorf("error parsing access token: %w", err) tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
} }
return source.GetFHIRStore(storeID, tokenStr) return source.GetFHIRStore(storeID, tokenStr)
} }

View File

@@ -116,9 +116,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil { if err != nil {
return nil, err return nil, err
} }
tokenStr, err := accessToken.ParseBearerToken() var tokenStr string
if err != nil { if source.UseClientAuthorization() {
return nil, fmt.Errorf("error parsing access token: %w", err) tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
} }
return source.GetFHIRStoreMetrics(storeID, tokenStr) return source.GetFHIRStoreMetrics(storeID, tokenStr)
} }

View File

@@ -95,9 +95,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil { if err != nil {
return nil, err return nil, err
} }
tokenStr, err := accessToken.ParseBearerToken() var tokenStr string
if err != nil { if source.UseClientAuthorization() {
return nil, fmt.Errorf("error parsing access token: %w", err) tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
} }
return source.ListDICOMStores(tokenStr) return source.ListDICOMStores(tokenStr)
} }

View File

@@ -95,9 +95,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil { if err != nil {
return nil, err return nil, err
} }
tokenStr, err := accessToken.ParseBearerToken() var tokenStr string
if err != nil { if source.UseClientAuthorization() {
return nil, fmt.Errorf("error parsing access token: %w", err) tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
} }
return source.ListFHIRStores(tokenStr) return source.ListFHIRStores(tokenStr)
} }

View File

@@ -127,9 +127,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil { if err != nil {
return nil, err return nil, err
} }
tokenStr, err := accessToken.ParseBearerToken() var tokenStr string
if err != nil { if source.UseClientAuthorization() {
return nil, fmt.Errorf("error parsing access token: %w", err) tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
} }
study, ok := params.AsMap()[studyInstanceUIDKey].(string) study, ok := params.AsMap()[studyInstanceUIDKey].(string)
if !ok { if !ok {

View File

@@ -140,9 +140,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil { if err != nil {
return nil, err return nil, err
} }
tokenStr, err := accessToken.ParseBearerToken() var tokenStr string
if err != nil { if source.UseClientAuthorization() {
return nil, fmt.Errorf("error parsing access token: %w", err) tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
} }
opts, err := common.ParseDICOMSearchParameters(params, []string{sopInstanceUIDKey, patientNameKey, patientIDKey, accessionNumberKey, referringPhysicianNameKey, studyDateKey, modalityKey}) opts, err := common.ParseDICOMSearchParameters(params, []string{sopInstanceUIDKey, patientNameKey, patientIDKey, accessionNumberKey, referringPhysicianNameKey, studyDateKey, modalityKey})

View File

@@ -138,9 +138,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil { if err != nil {
return nil, err return nil, err
} }
tokenStr, err := accessToken.ParseBearerToken() var tokenStr string
if err != nil { if source.UseClientAuthorization() {
return nil, fmt.Errorf("error parsing access token: %w", err) tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
} }
opts, err := common.ParseDICOMSearchParameters(params, []string{seriesInstanceUIDKey, patientNameKey, patientIDKey, accessionNumberKey, referringPhysicianNameKey, studyDateKey, modalityKey}) opts, err := common.ParseDICOMSearchParameters(params, []string{seriesInstanceUIDKey, patientNameKey, patientIDKey, accessionNumberKey, referringPhysicianNameKey, studyDateKey, modalityKey})

View File

@@ -133,9 +133,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil { if err != nil {
return nil, err return nil, err
} }
tokenStr, err := accessToken.ParseBearerToken() var tokenStr string
if err != nil { if source.UseClientAuthorization() {
return nil, fmt.Errorf("error parsing access token: %w", err) tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
} }
opts, err := common.ParseDICOMSearchParameters(params, []string{studyInstanceUIDKey, patientNameKey, patientIDKey, accessionNumberKey, referringPhysicianNameKey, studyDateKey}) opts, err := common.ParseDICOMSearchParameters(params, []string{studyInstanceUIDKey, patientNameKey, patientIDKey, accessionNumberKey, referringPhysicianNameKey, studyDateKey})
if err != nil { if err != nil {

View File

@@ -0,0 +1,183 @@
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cloudsqlrestorebackup
import (
"context"
"fmt"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/sqladmin/v1"
)
const kind string = "cloud-sql-restore-backup"
var _ tools.ToolConfig = Config{}
type compatibleSource interface {
GetDefaultProject() string
GetService(context.Context, string) (*sqladmin.Service, error)
UseClientAuthorization() bool
RestoreBackup(ctx context.Context, targetProject, targetInstance, sourceProject, sourceInstance, backupID, accessToken string) (any, error)
}
// Config defines the configuration for the restore-backup tool.
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Description string `yaml:"description"`
Source string `yaml:"source" validate:"required"`
AuthRequired []string `yaml:"authRequired"`
}
func init() {
if !tools.Register(kind, newConfig) {
panic(fmt.Sprintf("tool kind %q already registered", kind))
}
}
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
actual := Config{Name: name}
if err := decoder.DecodeContext(ctx, &actual); err != nil {
return nil, err
}
return actual, nil
}
// ToolConfigKind returns the kind of the tool.
func (cfg Config) ToolConfigKind() string {
return kind
}
// Initialize initializes the tool from the configuration.
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
project := s.GetDefaultProject()
var targetProjectParam parameters.Parameter
if project != "" {
targetProjectParam = parameters.NewStringParameterWithDefault("target_project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
} else {
targetProjectParam = parameters.NewStringParameter("target_project", "The project ID")
}
allParameters := parameters.Parameters{
targetProjectParam,
parameters.NewStringParameter("target_instance", "Cloud SQL instance ID of the target instance. This does not include the project ID."),
parameters.NewStringParameter("backup_id", "Identifier of the backup being restored. Can be a BackupRun ID, backup name, or BackupDR backup name. Use the full backup ID as provided, do not try to parse it"),
parameters.NewStringParameterWithRequired("source_project", "GCP project ID of the instance that the backup belongs to. Only required if the backup_id is a BackupRun ID.", false),
parameters.NewStringParameterWithRequired("source_instance", "Cloud SQL instance ID of the instance that the backup belongs to. Only required if the backup_id is a BackupRun ID.", false),
}
paramManifest := allParameters.Manifest()
description := cfg.Description
if description == "" {
description = "Restores a backup on a Cloud SQL instance."
}
mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, allParameters, nil)
return Tool{
Config: cfg,
AllParams: allParameters,
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}, nil
}
// Tool represents the restore-backup tool.
type Tool struct {
Config
AllParams parameters.Parameters `yaml:"allParams"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
return t.Config
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
targetProject, ok := paramsMap["target_project"].(string)
if !ok {
return nil, fmt.Errorf("error casting 'target_project' parameter: %v", paramsMap["target_project"])
}
targetInstance, ok := paramsMap["target_instance"].(string)
if !ok {
return nil, fmt.Errorf("error casting 'target_instance' parameter: %v", paramsMap["target_instance"])
}
backupID, ok := paramsMap["backup_id"].(string)
if !ok {
return nil, fmt.Errorf("error casting 'backup_id' parameter: %v", paramsMap["backup_id"])
}
sourceProject, _ := paramsMap["source_project"].(string)
sourceInstance, _ := paramsMap["source_instance"].(string)
return source.RestoreBackup(ctx, targetProject, targetInstance, sourceProject, sourceInstance, backupID, string(accessToken))
}
// ParseParams parses the parameters for the tool.
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
return parameters.ParseParams(t.AllParams, data, claims)
}
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
return parameters.EmbedParams(ctx, t.AllParams, paramValues, embeddingModelsMap, nil)
}
// Manifest returns the tool's manifest.
func (t Tool) Manifest() tools.Manifest {
return t.manifest
}
// McpManifest returns the tool's MCP manifest.
func (t Tool) McpManifest() tools.McpManifest {
return t.mcpManifest
}
// Authorized checks if the tool is authorized.
func (t Tool) Authorized(verifiedAuthServices []string) bool {
return true
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -0,0 +1,71 @@
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cloudsqlrestorebackup_test
import (
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/testutils"
"github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlrestorebackup"
)
func TestParseFromYaml(t *testing.T) {
ctx, err := testutils.ContextWithNewLogger()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
tcs := []struct {
desc string
in string
want server.ToolConfigs
}{
{
desc: "basic example",
in: `
tools:
restore-backup-tool:
kind: cloud-sql-restore-backup
description: a test description
source: a-source
`,
want: server.ToolConfigs{
"restore-backup-tool": cloudsqlrestorebackup.Config{
Name: "restore-backup-tool",
Kind: "cloud-sql-restore-backup",
Description: "a test description",
Source: "a-source",
AuthRequired: []string{},
},
},
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Tools server.ToolConfigs `yaml:"tools"`
}{}
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
t.Fatalf("incorrect parse: diff %v", diff)
}
})
}
}

View File

@@ -147,12 +147,20 @@ func TestAlloyDBPgToolEndpoints(t *testing.T) {
teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams) teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
defer teardownTable2(t) defer teardownTable2(t)
// Set up table for semanti search
vectorTableName, tearDownVectorTable := tests.SetupPostgresVectorTable(t, ctx, pool)
defer tearDownVectorTable(t)
// Write config into a file and pass it to command // Write config into a file and pass it to command
toolsFile := tests.GetToolsConfig(sourceConfig, AlloyDBPostgresToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt) toolsFile := tests.GetToolsConfig(sourceConfig, AlloyDBPostgresToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
toolsFile = tests.AddExecuteSqlConfig(t, toolsFile, "postgres-execute-sql") toolsFile = tests.AddExecuteSqlConfig(t, toolsFile, "postgres-execute-sql")
tmplSelectCombined, tmplSelectFilterCombined := tests.GetPostgresSQLTmplToolStatement() tmplSelectCombined, tmplSelectFilterCombined := tests.GetPostgresSQLTmplToolStatement()
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, AlloyDBPostgresToolKind, tmplSelectCombined, tmplSelectFilterCombined, "") toolsFile = tests.AddTemplateParamConfig(t, toolsFile, AlloyDBPostgresToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
// Add semantic search tool config
insertStmt, searchStmt := tests.GetPostgresVectorSearchStmts(vectorTableName)
toolsFile = tests.AddSemanticSearchConfig(t, toolsFile, AlloyDBPostgresToolKind, insertStmt, searchStmt)
toolsFile = tests.AddPostgresPrebuiltConfig(t, toolsFile) toolsFile = tests.AddPostgresPrebuiltConfig(t, toolsFile)
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...) cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)

View File

@@ -112,8 +112,7 @@ func TestHealthcareToolEndpoints(t *testing.T) {
fhirStoreID := "fhir-store-" + uuid.New().String() fhirStoreID := "fhir-store-" + uuid.New().String()
dicomStoreID := "dicom-store-" + uuid.New().String() dicomStoreID := "dicom-store-" + uuid.New().String()
patient1ID, patient2ID, teardown := setupHealthcareResources(t, healthcareService, healthcareDataset, fhirStoreID, dicomStoreID) patient1ID, patient2ID := setupHealthcareResources(t, healthcareService, healthcareDataset, fhirStoreID, dicomStoreID)
defer teardown(t)
toolsFile := getToolsConfig(sourceConfig) toolsFile := getToolsConfig(sourceConfig)
toolsFile = addClientAuthSourceConfig(t, toolsFile) toolsFile = addClientAuthSourceConfig(t, toolsFile)
@@ -173,10 +172,8 @@ func TestHealthcareToolWithStoreRestriction(t *testing.T) {
disallowedFHIRStoreID := "fhir-store-disallowed-" + uuid.New().String() disallowedFHIRStoreID := "fhir-store-disallowed-" + uuid.New().String()
disallowedDICOMStoreID := "dicom-store-disallowed-" + uuid.New().String() disallowedDICOMStoreID := "dicom-store-disallowed-" + uuid.New().String()
_, _, teardownAllowedStores := setupHealthcareResources(t, healthcareService, healthcareDataset, allowedFHIRStoreID, allowedDICOMStoreID) setupHealthcareResources(t, healthcareService, healthcareDataset, allowedFHIRStoreID, allowedDICOMStoreID)
defer teardownAllowedStores(t) setupHealthcareResources(t, healthcareService, healthcareDataset, disallowedFHIRStoreID, disallowedDICOMStoreID)
_, _, teardownDisallowedStores := setupHealthcareResources(t, healthcareService, healthcareDataset, disallowedFHIRStoreID, disallowedDICOMStoreID)
defer teardownDisallowedStores(t)
// Configure source with dataset restriction. // Configure source with dataset restriction.
sourceConfig["allowedFhirStores"] = []string{allowedFHIRStoreID} sourceConfig["allowedFhirStores"] = []string{allowedFHIRStoreID}
@@ -257,7 +254,7 @@ func newHealthcareService(ctx context.Context) (*healthcare.Service, error) {
return healthcareService, nil return healthcareService, nil
} }
func setupHealthcareResources(t *testing.T, service *healthcare.Service, datasetID, fhirStoreID, dicomStoreID string) (string, string, func(*testing.T)) { func setupHealthcareResources(t *testing.T, service *healthcare.Service, datasetID, fhirStoreID, dicomStoreID string) (string, string) {
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", healthcareProject, healthcareRegion, datasetID) datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", healthcareProject, healthcareRegion, datasetID)
var err error var err error
@@ -266,12 +263,24 @@ func setupHealthcareResources(t *testing.T, service *healthcare.Service, dataset
if fhirStore, err = service.Projects.Locations.Datasets.FhirStores.Create(datasetName, fhirStore).FhirStoreId(fhirStoreID).Do(); err != nil { if fhirStore, err = service.Projects.Locations.Datasets.FhirStores.Create(datasetName, fhirStore).FhirStoreId(fhirStoreID).Do(); err != nil {
t.Fatalf("failed to create fhir store: %v", err) t.Fatalf("failed to create fhir store: %v", err)
} }
// Register cleanup
t.Cleanup(func() {
if _, err := service.Projects.Locations.Datasets.FhirStores.Delete(fhirStore.Name).Do(); err != nil {
t.Logf("failed to delete fhir store: %v", err)
}
})
// Create DICOM store // Create DICOM store
dicomStore := &healthcare.DicomStore{} dicomStore := &healthcare.DicomStore{}
if dicomStore, err = service.Projects.Locations.Datasets.DicomStores.Create(datasetName, dicomStore).DicomStoreId(dicomStoreID).Do(); err != nil { if dicomStore, err = service.Projects.Locations.Datasets.DicomStores.Create(datasetName, dicomStore).DicomStoreId(dicomStoreID).Do(); err != nil {
t.Fatalf("failed to create dicom store: %v", err) t.Fatalf("failed to create dicom store: %v", err)
} }
// Register cleanup
t.Cleanup(func() {
if _, err := service.Projects.Locations.Datasets.DicomStores.Delete(dicomStore.Name).Do(); err != nil {
t.Logf("failed to delete dicom store: %v", err)
}
})
// Create Patient 1 // Create Patient 1
patient1Body := bytes.NewBuffer([]byte(`{ patient1Body := bytes.NewBuffer([]byte(`{
@@ -317,15 +326,7 @@ func setupHealthcareResources(t *testing.T, service *healthcare.Service, dataset
createFHIRResource(t, service, fhirStore.Name, "Observation", observation2Body) createFHIRResource(t, service, fhirStore.Name, "Observation", observation2Body)
} }
teardown := func(t *testing.T) { return patient1ID, patient2ID
if _, err := service.Projects.Locations.Datasets.FhirStores.Delete(fhirStore.Name).Do(); err != nil {
t.Logf("failed to delete fhir store: %v", err)
}
if _, err := service.Projects.Locations.Datasets.DicomStores.Delete(dicomStore.Name).Do(); err != nil {
t.Logf("failed to delete dicom store: %v", err)
}
}
return patient1ID, patient2ID, teardown
} }
func getToolsConfig(sourceConfig map[string]any) map[string]any { func getToolsConfig(sourceConfig map[string]any) map[string]any {

View File

@@ -0,0 +1,267 @@
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cloudsql
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"regexp"
"strings"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/testutils"
"github.com/googleapis/genai-toolbox/tests"
"google.golang.org/api/sqladmin/v1"
)
var (
restoreBackupToolKind = "cloud-sql-restore-backup"
)
type restoreBackupTransport struct {
transport http.RoundTripper
url *url.URL
}
func (t *restoreBackupTransport) RoundTrip(req *http.Request) (*http.Response, error) {
if strings.HasPrefix(req.URL.String(), "https://sqladmin.googleapis.com") {
req.URL.Scheme = t.url.Scheme
req.URL.Host = t.url.Host
}
return t.transport.RoundTrip(req)
}
type masterRestoreBackupHandler struct {
t *testing.T
}
func (h *masterRestoreBackupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if !strings.Contains(r.UserAgent(), "genai-toolbox/") {
h.t.Errorf("User-Agent header not found")
}
var body sqladmin.InstancesRestoreBackupRequest
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
h.t.Fatalf("failed to decode request body: %v", err)
} else {
h.t.Logf("Received request body: %+v", body)
}
var expectedBody sqladmin.InstancesRestoreBackupRequest
var response any
var statusCode int
switch {
case body.Backup != "":
expectedBody = sqladmin.InstancesRestoreBackupRequest{
Backup: "projects/p1/backups/test-uid",
}
response = map[string]any{"name": "op1", "status": "PENDING"}
statusCode = http.StatusOK
case body.BackupdrBackup != "":
expectedBody = sqladmin.InstancesRestoreBackupRequest{
BackupdrBackup: "projects/p1/locations/us-central1/backupVaults/test-vault/dataSources/test-ds/backups/test-uid",
}
response = map[string]any{"name": "op1", "status": "PENDING"}
statusCode = http.StatusOK
case body.RestoreBackupContext != nil:
expectedBody = sqladmin.InstancesRestoreBackupRequest{
RestoreBackupContext: &sqladmin.RestoreBackupContext{
Project: "p1",
InstanceId: "source",
BackupRunId: 12345,
},
}
response = map[string]any{"name": "op1", "status": "PENDING"}
statusCode = http.StatusOK
default:
http.Error(w, fmt.Sprintf("unhandled restore request body: %v", body), http.StatusInternalServerError)
return
}
if diff := cmp.Diff(expectedBody, body); diff != "" {
h.t.Errorf("unexpected request body (-want +got):\n%s", diff)
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(statusCode)
if err := json.NewEncoder(w).Encode(response); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}
func TestRestoreBackupToolEndpoints(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
handler := &masterRestoreBackupHandler{t: t}
server := httptest.NewServer(handler)
defer server.Close()
serverURL, err := url.Parse(server.URL)
if err != nil {
t.Fatalf("failed to parse server URL: %v", err)
}
originalTransport := http.DefaultClient.Transport
if originalTransport == nil {
originalTransport = http.DefaultTransport
}
http.DefaultClient.Transport = &restoreBackupTransport{
transport: originalTransport,
url: serverURL,
}
t.Cleanup(func() {
http.DefaultClient.Transport = originalTransport
})
var args []string
toolsFile := getRestoreBackupToolsConfig()
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
if err != nil {
t.Fatalf("command initialization returned an error: %s", err)
}
defer cleanup()
waitCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
if err != nil {
t.Logf("toolbox command logs: \n%s", out)
t.Fatalf("toolbox didn't start successfully: %s", err)
}
tcs := []struct {
name string
toolName string
body string
want string
expectError bool
errorStatus int
}{
{
name: "successful restore with standard backup",
toolName: "restore-backup",
body: `{"target_project": "p1", "target_instance": "instance-standard", "backup_id": "12345", "source_project": "p1", "source_instance": "source"}`,
want: `{"name":"op1","status":"PENDING"}`,
},
{
name: "successful restore with project level backup",
toolName: "restore-backup",
body: `{"target_project": "p1", "target_instance": "instance-project-level", "backup_id": "projects/p1/backups/test-uid"}`,
want: `{"name":"op1","status":"PENDING"}`,
},
{
name: "successful restore with BackupDR backup",
toolName: "restore-backup",
body: `{"target_project": "p1", "target_instance": "instance-project-level", "backup_id": "projects/p1/locations/us-central1/backupVaults/test-vault/dataSources/test-ds/backups/test-uid"}`,
want: `{"name":"op1","status":"PENDING"}`,
},
{
name: "missing source instance info for standard backup",
toolName: "restore-backup",
body: `{"target_project": "p1", "target_instance": "instance-project-level", "backup_id": "12345"}`,
expectError: true,
errorStatus: http.StatusBadRequest,
},
{
name: "missing backup identifier",
toolName: "restore-backup",
body: `{"target_project": "p1", "target_instance": "instance-project-level"}`,
expectError: true,
errorStatus: http.StatusBadRequest,
},
{
name: "missing target instance info",
toolName: "restore-backup",
body: `{"backup_id": "12345"}`,
expectError: true,
errorStatus: http.StatusBadRequest,
},
}
for _, tc := range tcs {
tc := tc
t.Run(tc.name, func(t *testing.T) {
api := fmt.Sprintf("http://127.0.0.1:5000/api/tool/%s/invoke", tc.toolName)
req, err := http.NewRequest(http.MethodPost, api, bytes.NewBufferString(tc.body))
if err != nil {
t.Fatalf("unable to create request: %s", err)
}
req.Header.Add("Content-type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("unable to send request: %s", err)
}
defer resp.Body.Close()
if tc.expectError {
if resp.StatusCode != tc.errorStatus {
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("expected status %d but got %d: %s", tc.errorStatus, resp.StatusCode, string(bodyBytes))
}
return
}
if resp.StatusCode != http.StatusOK {
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
}
var result struct {
Result string `json:"result"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
var got, want map[string]any
if err := json.Unmarshal([]byte(result.Result), &got); err != nil {
t.Fatalf("failed to unmarshal result: %v", err)
}
if err := json.Unmarshal([]byte(tc.want), &want); err != nil {
t.Fatalf("failed to unmarshal want: %v", err)
}
if !reflect.DeepEqual(got, want) {
t.Fatalf("unexpected result: got %+v, want %+v", got, want)
}
})
}
}
func getRestoreBackupToolsConfig() map[string]any {
return map[string]any{
"sources": map[string]any{
"my-cloud-sql-source": map[string]any{
"kind": "cloud-sql-admin",
},
},
"tools": map[string]any{
"restore-backup": map[string]any{
"kind": restoreBackupToolKind,
"source": "my-cloud-sql-source",
},
},
}
}

View File

@@ -132,12 +132,20 @@ func TestCloudSQLPgSimpleToolEndpoints(t *testing.T) {
teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams) teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
defer teardownTable2(t) defer teardownTable2(t)
// Set up table for semantic search
vectorTableName, tearDownVectorTable := tests.SetupPostgresVectorTable(t, ctx, pool)
defer tearDownVectorTable(t)
// Write config into a file and pass it to command // Write config into a file and pass it to command
toolsFile := tests.GetToolsConfig(sourceConfig, CloudSQLPostgresToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt) toolsFile := tests.GetToolsConfig(sourceConfig, CloudSQLPostgresToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
toolsFile = tests.AddExecuteSqlConfig(t, toolsFile, "postgres-execute-sql") toolsFile = tests.AddExecuteSqlConfig(t, toolsFile, "postgres-execute-sql")
tmplSelectCombined, tmplSelectFilterCombined := tests.GetPostgresSQLTmplToolStatement() tmplSelectCombined, tmplSelectFilterCombined := tests.GetPostgresSQLTmplToolStatement()
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, CloudSQLPostgresToolKind, tmplSelectCombined, tmplSelectFilterCombined, "") toolsFile = tests.AddTemplateParamConfig(t, toolsFile, CloudSQLPostgresToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
// Add semantic search tool config
insertStmt, searchStmt := tests.GetPostgresVectorSearchStmts(vectorTableName)
toolsFile = tests.AddSemanticSearchConfig(t, toolsFile, CloudSQLPostgresToolKind, insertStmt, searchStmt)
toolsFile = tests.AddPostgresPrebuiltConfig(t, toolsFile) toolsFile = tests.AddPostgresPrebuiltConfig(t, toolsFile)
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...) cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
if err != nil { if err != nil {
@@ -186,6 +194,7 @@ func TestCloudSQLPgSimpleToolEndpoints(t *testing.T) {
tests.RunPostgresListDatabaseStatsTest(t, ctx, pool) tests.RunPostgresListDatabaseStatsTest(t, ctx, pool)
tests.RunPostgresListRolesTest(t, ctx, pool) tests.RunPostgresListRolesTest(t, ctx, pool)
tests.RunPostgresListStoredProcedureTest(t, ctx, pool) tests.RunPostgresListStoredProcedureTest(t, ctx, pool)
tests.RunSemanticSearchToolInvokeTest(t, "null", "", "The quick brown fox")
} }
// Test connection with different IP type // Test connection with different IP type

251
tests/embedding.go Normal file
View File

@@ -0,0 +1,251 @@
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package tests contains end to end tests meant to verify the Toolbox Server
// works as expected when executed as a binary.
package tests
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"strings"
"testing"
"github.com/google/uuid"
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
"github.com/jackc/pgx/v5/pgxpool"
)
var apiKey = os.Getenv("API_KEY")
// AddSemanticSearchConfig adds embedding models and semantic search tools to the config
// with configurable tool kind and SQL statements.
func AddSemanticSearchConfig(t *testing.T, config map[string]any, toolKind, insertStmt, searchStmt string) map[string]any {
config["embeddingModels"] = map[string]any{
"gemini_model": map[string]any{
"kind": "gemini",
"model": "gemini-embedding-001",
"apiKey": apiKey,
"dimension": 768,
},
}
tools, ok := config["tools"].(map[string]any)
if !ok {
t.Fatalf("unable to get tools from config")
}
tools["insert_docs"] = map[string]any{
"kind": toolKind,
"source": "my-instance",
"description": "Stores content and its vector embedding into the documents table.",
"statement": insertStmt,
"parameters": []any{
map[string]any{
"name": "content",
"type": "string",
"description": "The text content associated with the vector.",
},
map[string]any{
"name": "text_to_embed",
"type": "string",
"description": "The text content used to generate the vector.",
"embeddedBy": "gemini_model",
},
},
}
tools["search_docs"] = map[string]any{
"kind": toolKind,
"source": "my-instance",
"description": "Finds the most semantically similar document to the query vector.",
"statement": searchStmt,
"parameters": []any{
map[string]any{
"name": "query",
"type": "string",
"description": "The text content to search for.",
"embeddedBy": "gemini_model",
},
},
}
config["tools"] = tools
return config
}
// RunSemanticSearchToolInvokeTest runs the insert_docs and search_docs tools
// via both HTTP and MCP endpoints and verifies the output.
func RunSemanticSearchToolInvokeTest(t *testing.T, insertWant, mcpInsertWant, searchWant string) {
// Initialize MCP session once for the MCP test cases
sessionId := RunInitialize(t, "2024-11-05")
tcs := []struct {
name string
api string
isMcp bool
requestBody interface{}
want string
}{
{
name: "HTTP invoke insert_docs",
api: "http://127.0.0.1:5000/api/tool/insert_docs/invoke",
isMcp: false,
requestBody: `{"content": "The quick brown fox jumps over the lazy dog", "text_to_embed": "The quick brown fox jumps over the lazy dog"}`,
want: insertWant,
},
{
name: "HTTP invoke search_docs",
api: "http://127.0.0.1:5000/api/tool/search_docs/invoke",
isMcp: false,
requestBody: `{"query": "fast fox jumping"}`,
want: searchWant,
},
{
name: "MCP invoke insert_docs",
api: "http://127.0.0.1:5000/mcp",
isMcp: true,
requestBody: jsonrpc.JSONRPCRequest{
Jsonrpc: "2.0",
Id: "mcp-insert-docs",
Request: jsonrpc.Request{
Method: "tools/call",
},
Params: map[string]any{
"name": "insert_docs",
"arguments": map[string]any{
"content": "The quick brown fox jumps over the lazy dog",
"text_to_embed": "The quick brown fox jumps over the lazy dog",
},
},
},
want: mcpInsertWant,
},
{
name: "MCP invoke search_docs",
api: "http://127.0.0.1:5000/mcp",
isMcp: true,
requestBody: jsonrpc.JSONRPCRequest{
Jsonrpc: "2.0",
Id: "mcp-search-docs",
Request: jsonrpc.Request{
Method: "tools/call",
},
Params: map[string]any{
"name": "search_docs",
"arguments": map[string]any{
"query": "fast fox jumping",
},
},
},
want: searchWant,
},
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
var bodyReader io.Reader
headers := map[string]string{}
// Prepare Request Body and Headers
if tc.isMcp {
reqBytes, err := json.Marshal(tc.requestBody)
if err != nil {
t.Fatalf("failed to marshal mcp request: %v", err)
}
bodyReader = bytes.NewBuffer(reqBytes)
if sessionId != "" {
headers["Mcp-Session-Id"] = sessionId
}
} else {
bodyReader = bytes.NewBufferString(tc.requestBody.(string))
}
// Send Request
resp, respBody := RunRequest(t, http.MethodPost, tc.api, bodyReader, headers)
if resp.StatusCode != http.StatusOK {
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody))
}
// Normalize Response to get the actual tool result string
var got string
if tc.isMcp {
var mcpResp struct {
Result struct {
Content []struct {
Text string `json:"text"`
} `json:"content"`
} `json:"result"`
}
if err := json.Unmarshal(respBody, &mcpResp); err != nil {
t.Fatalf("error parsing mcp response: %s", err)
}
if len(mcpResp.Result.Content) > 0 {
got = mcpResp.Result.Content[0].Text
}
} else {
var httpResp map[string]interface{}
if err := json.Unmarshal(respBody, &httpResp); err != nil {
t.Fatalf("error parsing http response: %s", err)
}
if res, ok := httpResp["result"].(string); ok {
got = res
}
}
if !strings.Contains(got, tc.want) {
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
}
})
}
}
// SetupPostgresVectorTable sets up the vector extension and a vector table
func SetupPostgresVectorTable(t *testing.T, ctx context.Context, pool *pgxpool.Pool) (string, func(*testing.T)) {
t.Helper()
if _, err := pool.Exec(ctx, "CREATE EXTENSION IF NOT EXISTS vector"); err != nil {
t.Fatalf("failed to create vector extension: %v", err)
}
tableName := "vector_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
createTableStmt := fmt.Sprintf(`CREATE TABLE %s (
id SERIAL PRIMARY KEY,
content TEXT,
embedding vector(768)
)`, tableName)
if _, err := pool.Exec(ctx, createTableStmt); err != nil {
t.Fatalf("failed to create table %s: %v", tableName, err)
}
return tableName, func(t *testing.T) {
if _, err := pool.Exec(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", tableName)); err != nil {
t.Errorf("failed to drop table %s: %v", tableName, err)
}
}
}
func GetPostgresVectorSearchStmts(vectorTableName string) (string, string) {
insertStmt := fmt.Sprintf("INSERT INTO %s (content, embedding) VALUES ($1, $2)", vectorTableName)
searchStmt := fmt.Sprintf("SELECT id, content, embedding <-> $1 AS distance FROM %s ORDER BY distance LIMIT 1", vectorTableName)
return insertStmt, searchStmt
}

View File

@@ -111,6 +111,10 @@ func TestPostgres(t *testing.T) {
teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams) teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
defer teardownTable2(t) defer teardownTable2(t)
// Set up table for semantic search
vectorTableName, tearDownVectorTable := tests.SetupPostgresVectorTable(t, ctx, pool)
defer tearDownVectorTable(t)
// Write config into a file and pass it to command // Write config into a file and pass it to command
toolsFile := tests.GetToolsConfig(sourceConfig, PostgresToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt) toolsFile := tests.GetToolsConfig(sourceConfig, PostgresToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
toolsFile = tests.AddExecuteSqlConfig(t, toolsFile, "postgres-execute-sql") toolsFile = tests.AddExecuteSqlConfig(t, toolsFile, "postgres-execute-sql")
@@ -118,6 +122,10 @@ func TestPostgres(t *testing.T) {
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, PostgresToolKind, tmplSelectCombined, tmplSelectFilterCombined, "") toolsFile = tests.AddTemplateParamConfig(t, toolsFile, PostgresToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
toolsFile = tests.AddPostgresPrebuiltConfig(t, toolsFile) toolsFile = tests.AddPostgresPrebuiltConfig(t, toolsFile)
// Add semantic search tool config
insertStmt, searchStmt := tests.GetPostgresVectorSearchStmts(vectorTableName)
toolsFile = tests.AddSemanticSearchConfig(t, toolsFile, PostgresToolKind, insertStmt, searchStmt)
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...) cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
if err != nil { if err != nil {
t.Fatalf("command initialization returned an error: %s", err) t.Fatalf("command initialization returned an error: %s", err)
@@ -165,4 +173,5 @@ func TestPostgres(t *testing.T) {
tests.RunPostgresListDatabaseStatsTest(t, ctx, pool) tests.RunPostgresListDatabaseStatsTest(t, ctx, pool)
tests.RunPostgresListRolesTest(t, ctx, pool) tests.RunPostgresListRolesTest(t, ctx, pool)
tests.RunPostgresListStoredProcedureTest(t, ctx, pool) tests.RunPostgresListStoredProcedureTest(t, ctx, pool)
tests.RunSemanticSearchToolInvokeTest(t, "null", "", "The quick brown fox")
} }

View File

@@ -1240,7 +1240,10 @@ func RunPostgresListTablesTest(t *testing.T, tableNameParam, tableNameAuth, user
var filteredGot []any var filteredGot []any
for _, item := range got { for _, item := range got {
if tableMap, ok := item.(map[string]interface{}); ok { if tableMap, ok := item.(map[string]interface{}); ok {
if schema, ok := tableMap["schema_name"]; ok && schema == "public" { name, _ := tableMap["object_name"].(string)
// Only keep the table if it matches expected test tables
if name == tableNameParam || name == tableNameAuth {
filteredGot = append(filteredGot, item) filteredGot = append(filteredGot, item)
} }
} }