Mirror of https://github.com/googleapis/genai-toolbox.git
Synced 2026-01-21 05:18:14 -05:00

Compare commits: refactor-p ... main (8 commits)
| Author | SHA1 | Date |
|---|---|---|
| | e4f60e5633 | |
| | d7af21bdde | |
| | adc9589766 | |
| | c25a2330fe | |
| | 6e09b08c6a | |
| | 1f15a111f1 | |
| | dfddeb528d | |
| | 00c3e6d8cb | |
@@ -87,7 +87,7 @@ steps:
 - "CLOUD_SQL_POSTGRES_REGION=$_REGION"
 - "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
 secretEnv:
-["CLOUD_SQL_POSTGRES_USER", "CLOUD_SQL_POSTGRES_PASS", "CLIENT_ID"]
+["CLOUD_SQL_POSTGRES_USER", "CLOUD_SQL_POSTGRES_PASS", "CLIENT_ID", "API_KEY"]
 volumes:
 - name: "go"
 path: "/gopath"

@@ -134,7 +134,7 @@ steps:
 - "ALLOYDB_POSTGRES_DATABASE=$_DATABASE_NAME"
 - "ALLOYDB_POSTGRES_REGION=$_REGION"
 - "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
-secretEnv: ["ALLOYDB_POSTGRES_USER", "ALLOYDB_POSTGRES_PASS", "CLIENT_ID"]
+secretEnv: ["ALLOYDB_POSTGRES_USER", "ALLOYDB_POSTGRES_PASS", "CLIENT_ID", "API_KEY"]
 volumes:
 - name: "go"
 path: "/gopath"

@@ -293,7 +293,7 @@ steps:
 .ci/test_with_coverage.sh \
 "Cloud Healthcare API" \
 cloudhealthcare \
-cloudhealthcare || echo "Integration tests failed."
+cloudhealthcare

 - id: "postgres"
 name: golang:1

@@ -305,7 +305,7 @@ steps:
 - "POSTGRES_HOST=$_POSTGRES_HOST"
 - "POSTGRES_PORT=$_POSTGRES_PORT"
 - "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
-secretEnv: ["POSTGRES_USER", "POSTGRES_PASS", "CLIENT_ID"]
+secretEnv: ["POSTGRES_USER", "POSTGRES_PASS", "CLIENT_ID", "API_KEY"]
 volumes:
 - name: "go"
 path: "/gopath"

@@ -964,6 +964,13 @@ steps:

 availableSecrets:
 secretManager:
+# Common secrets
+- versionName: projects/$PROJECT_ID/secrets/client_id/versions/latest
+env: CLIENT_ID
+- versionName: projects/$PROJECT_ID/secrets/api_key/versions/latest
+env: API_KEY
+
+# Resource-specific secrets
 - versionName: projects/$PROJECT_ID/secrets/cloud_sql_pg_user/versions/latest
 env: CLOUD_SQL_POSTGRES_USER
 - versionName: projects/$PROJECT_ID/secrets/cloud_sql_pg_pass/versions/latest

@@ -980,8 +987,6 @@ availableSecrets:
 env: POSTGRES_USER
 - versionName: projects/$PROJECT_ID/secrets/postgres_pass/versions/latest
 env: POSTGRES_PASS
-- versionName: projects/$PROJECT_ID/secrets/client_id/versions/latest
-env: CLIENT_ID
 - versionName: projects/$PROJECT_ID/secrets/neo4j_user/versions/latest
 env: NEO4J_USER
 - versionName: projects/$PROJECT_ID/secrets/neo4j_pass/versions/latest
@@ -98,6 +98,7 @@ import (
 _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlgetinstances"
 _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqllistdatabases"
 _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqllistinstances"
+_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlrestorebackup"
 _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlwaitforoperation"
 _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsqlmssql/cloudsqlmssqlcreateinstance"
 _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsqlmysql/cloudsqlmysqlcreateinstance"

@@ -385,6 +386,7 @@ func NewCommand(opts ...Option) *Command {
 // TODO: Insecure by default. Might consider updating this for v1.0.0
 flags.StringSliceVar(&cmd.cfg.AllowedOrigins, "allowed-origins", []string{"*"}, "Specifies a list of origins permitted to access this server. Defaults to '*'.")
 flags.StringSliceVar(&cmd.cfg.AllowedHosts, "allowed-hosts", []string{"*"}, "Specifies a list of hosts permitted to access this server. Defaults to '*'.")
+flags.StringSliceVar(&cmd.cfg.UserAgentMetadata, "user-agent-metadata", []string{}, "Appends additional metadata to the User-Agent.")

 // wrap RunE command so that we have access to original Command object
 cmd.RunE = func(*cobra.Command, []string) error { return run(cmd) }
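The new `--user-agent-metadata` flag is registered with `StringSliceVar`, so a comma-separated value splits into a string slice before the server config is built. A minimal standalone sketch of that parsing behavior using `spf13/pflag` directly; the flag-set name and sample values here are illustrative, not taken from the Toolbox code:

```go
package main

import (
	"fmt"

	"github.com/spf13/pflag"
)

func main() {
	var meta []string
	// Hypothetical standalone flag set mirroring the registration added in NewCommand.
	flags := pflag.NewFlagSet("toolbox", pflag.ContinueOnError)
	flags.StringSliceVar(&meta, "user-agent-metadata", []string{}, "Appends additional metadata to the User-Agent.")

	// "foo,bar" is split on commas by StringSliceVar.
	if err := flags.Parse([]string{"--user-agent-metadata", "foo,bar"}); err != nil {
		panic(err)
	}
	fmt.Println(meta) // [foo bar]
}
```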
@@ -70,6 +70,9 @@ func withDefaults(c server.ServerConfig) server.ServerConfig {
 if c.AllowedHosts == nil {
 c.AllowedHosts = []string{"*"}
 }
+if c.UserAgentMetadata == nil {
+c.UserAgentMetadata = []string{}
+}
 return c
 }

@@ -230,6 +233,13 @@ func TestServerConfigFlags(t *testing.T) {
 AllowedHosts: []string{"http://foo.com", "http://bar.com"},
 }),
 },
+{
+desc: "user agent metadata",
+args: []string{"--user-agent-metadata", "foo,bar"},
+want: withDefaults(server.ServerConfig{
+UserAgentMetadata: []string{"foo", "bar"},
+}),
+},
 }
 for _, tc := range tcs {
 t.Run(tc.desc, func(t *testing.T) {
@@ -1493,7 +1503,7 @@ func TestPrebuiltTools(t *testing.T) {
 wantToolset: server.ToolsetConfigs{
 "cloud_sql_postgres_admin_tools": tools.ToolsetConfig{
 Name: "cloud_sql_postgres_admin_tools",
-ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "postgres_upgrade_precheck", "clone_instance", "create_backup"},
+ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "postgres_upgrade_precheck", "clone_instance", "create_backup", "restore_backup"},
 },
 },
 },

@@ -1503,7 +1513,7 @@ func TestPrebuiltTools(t *testing.T) {
 wantToolset: server.ToolsetConfigs{
 "cloud_sql_mysql_admin_tools": tools.ToolsetConfig{
 Name: "cloud_sql_mysql_admin_tools",
-ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance", "create_backup"},
+ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance", "create_backup", "restore_backup"},
 },
 },
 },

@@ -1513,7 +1523,7 @@ func TestPrebuiltTools(t *testing.T) {
 wantToolset: server.ToolsetConfigs{
 "cloud_sql_mssql_admin_tools": tools.ToolsetConfig{
 Name: "cloud_sql_mssql_admin_tools",
-ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance", "create_backup"},
+ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance", "create_backup", "restore_backup"},
 },
 },
 },
@@ -54,6 +54,7 @@ instance, database and users:
 * `create_instance`
 * `create_user`
 * `clone_instance`
+* `restore_backup`

 ## Install MCP Toolbox

@@ -301,6 +302,7 @@ instances and interacting with your database:
 * **wait_for_operation**: Waits for a Cloud SQL operation to complete.
 * **clone_instance**: Creates a clone of an existing Cloud SQL for SQL Server instance.
 * **create_backup**: Creates a backup on a Cloud SQL instance.
+* **restore_backup**: Restores a backup of a Cloud SQL instance.

 {{< notice note >}}
 Prebuilt tools are pre-1.0, so expect some tool changes between versions. LLMs
@@ -54,6 +54,7 @@ database and users:
 * `create_instance`
 * `create_user`
 * `clone_instance`
+* `restore_backup`

 ## Install MCP Toolbox

@@ -301,6 +302,7 @@ instances and interacting with your database:
 * **wait_for_operation**: Waits for a Cloud SQL operation to complete.
 * **clone_instance**: Creates a clone of an existing Cloud SQL for MySQL instance.
 * **create_backup**: Creates a backup on a Cloud SQL instance.
+* **restore_backup**: Restores a backup of a Cloud SQL instance.

 {{< notice note >}}
 Prebuilt tools are pre-1.0, so expect some tool changes between versions. LLMs
@@ -54,6 +54,7 @@ instance, database and users:
 * `create_instance`
 * `create_user`
 * `clone_instance`
+* `restore_backup`

 ## Install MCP Toolbox

@@ -301,6 +302,7 @@ instances and interacting with your database:
 * **wait_for_operation**: Waits for a Cloud SQL operation to complete.
 * **clone_instance**: Creates a clone of an existing Cloud SQL for PostgreSQL instance.
 * **create_backup**: Creates a backup on a Cloud SQL instance.
+* **restore_backup**: Restores a backup of a Cloud SQL instance.

 {{< notice note >}}
 Prebuilt tools are pre-1.0, so expect some tool changes between versions. LLMs
@@ -207,6 +207,7 @@ You can connect to Toolbox Cloud Run instances directly through the SDK.
 {{< tab header="Python" lang="python" >}}
 import asyncio
 from toolbox_core import ToolboxClient, auth_methods
+from toolbox_core.protocol import Protocol

 # Replace with the Cloud Run service URL generated in the previous step
 URL = "https://cloud-run-url.app"

@@ -217,6 +218,7 @@ async def main():
 async with ToolboxClient(
 URL,
 client_headers={"Authorization": auth_token_provider},
+protocol=Protocol.TOOLBOX,
 ) as toolbox:
 toolset = await toolbox.load_toolset()
 # ...

@@ -281,3 +283,5 @@ contain the specific error message needed to diagnose the problem.
 Manager, it means the Toolbox service account is missing permissions.
 - Ensure the `toolbox-identity` service account has the **Secret Manager
 Secret Accessor** (`roles/secretmanager.secretAccessor`) IAM role.
+
+- **Cloud Run Connections via IAP:** Currently we do not support Cloud Run connections via [IAP](https://docs.cloud.google.com/iap/docs/concepts-overview). Please disable IAP if you are using it.

@@ -27,6 +27,7 @@ description: >
 | | `--ui` | Launches the Toolbox UI web server. | |
 | | `--allowed-origins` | Specifies a list of origins permitted to access this server for CORs access. | `*` |
 | | `--allowed-hosts` | Specifies a list of hosts permitted to access this server to prevent DNS rebinding attacks. | `*` |
+| | `--user-agent-extra` | Appends additional metadata to the User-Agent. | |
 | `-v` | `--version` | version for toolbox | |

 ## Examples
@@ -194,6 +194,7 @@ See [Usage Examples](../reference/cli.md#examples).
 * `create_instance`
 * `create_user`
 * `clone_instance`
+* `restore_backup`

 * **Tools:**
 * `create_instance`: Creates a new Cloud SQL for MySQL instance.

@@ -205,6 +206,7 @@ See [Usage Examples](../reference/cli.md#examples).
 * `wait_for_operation`: Waits for a Cloud SQL operation to complete.
 * `clone_instance`: Creates a clone for an existing Cloud SQL for MySQL instance.
 * `create_backup`: Creates a backup on a Cloud SQL instance.
+* `restore_backup`: Restores a backup of a Cloud SQL instance.

 ## Cloud SQL for PostgreSQL

@@ -284,6 +286,7 @@ See [Usage Examples](../reference/cli.md#examples).
 * `create_instance`
 * `create_user`
 * `clone_instance`
+* `restore_backup`
 * **Tools:**
 * `create_instance`: Creates a new Cloud SQL for PostgreSQL instance.
 * `get_instance`: Gets information about a Cloud SQL instance.

@@ -294,6 +297,7 @@ See [Usage Examples](../reference/cli.md#examples).
 * `wait_for_operation`: Waits for a Cloud SQL operation to complete.
 * `clone_instance`: Creates a clone for an existing Cloud SQL for PostgreSQL instance.
 * `create_backup`: Creates a backup on a Cloud SQL instance.
+* `restore_backup`: Restores a backup of a Cloud SQL instance.

 ## Cloud SQL for SQL Server

@@ -347,6 +351,7 @@ See [Usage Examples](../reference/cli.md#examples).
 * `create_instance`
 * `create_user`
 * `clone_instance`
+* `restore_backup`
 * **Tools:**
 * `create_instance`: Creates a new Cloud SQL for SQL Server instance.
 * `get_instance`: Gets information about a Cloud SQL instance.

@@ -357,6 +362,7 @@ See [Usage Examples](../reference/cli.md#examples).
 * `wait_for_operation`: Waits for a Cloud SQL operation to complete.
 * `clone_instance`: Creates a clone for an existing Cloud SQL for SQL Server instance.
 * `create_backup`: Creates a backup on a Cloud SQL instance.
+* `restore_backup`: Restores a backup of a Cloud SQL instance.

 ## Dataplex
@@ -12,6 +12,9 @@ aliases:

 The `cloud-gemini-data-analytics-query` tool allows you to send natural language questions to the Gemini Data Analytics API and receive structured responses containing SQL queries, natural language answers, and explanations. For details on defining data agent context for database data sources, see the official [documentation](https://docs.cloud.google.com/gemini/docs/conversational-analytics-api/data-agent-authored-context-databases).

+> [!NOTE]
+> Only `alloydb`, `spannerReference`, and `cloudSqlReference` are supported as [datasource references](https://clouddocs.devsite.corp.google.com/gemini/docs/conversational-analytics-api/reference/rest/v1beta/projects.locations.dataAgents#DatasourceReferences).
+
 ## Example

 ```yaml
docs/en/resources/tools/cloudsql/cloudsqlrestorebackup.md (new file, 53 lines)

@@ -0,0 +1,53 @@
+---
+title: cloud-sql-restore-backup
+type: docs
+weight: 10
+description: "Restores a backup of a Cloud SQL instance."
+---
+
+The `cloud-sql-restore-backup` tool restores a backup on a Cloud SQL instance using the Cloud SQL Admin API.
+
+{{< notice info >}}
+This tool uses a `source` of kind `cloud-sql-admin`.
+{{< /notice >}}
+
+## Examples
+
+Basic backup restore
+
+```yaml
+tools:
+backup-restore-basic:
+kind: cloud-sql-restore-backup
+source: cloud-sql-admin-source
+description: "Restores a backup onto the given Cloud SQL instance."
+```
+
+## Reference
+
+### Tool Configuration
+| **field** | **type** | **required** | **description** |
+| -------------- | :------: | :----------: | ------------------------------------------------ |
+| kind | string | true | Must be "cloud-sql-restore-backup". |
+| source | string | true | The name of the `cloud-sql-admin` source to use. |
+| description | string | false | A description of the tool. |
+
+### Tool Inputs
+
+| **parameter** | **type** | **required** | **description** |
+| ------------------| :------: | :----------: | -----------------------------------------------------------------------------|
+| target_project | string | true | The project ID of the instance to restore the backup onto. |
+| target_instance | string | true | The instance to restore the backup onto. Does not include the project ID. |
+| backup_id | string | true | The identifier of the backup being restored. |
+| source_project | string | false | (Optional) The project ID of the instance that the backup belongs to. |
+| source_instance | string | false | (Optional) Cloud SQL instance ID of the instance that the backup belongs to. |
+
+## Usage Notes
+
+- The `backup_id` field can be a BackupRun ID (which will be an int64), backup name, or BackupDR backup name.
+- If the `backup_id` field contains a BackupRun ID (i.e. an int64), the optional fields `source_project` and `source_instance` must also be provided.
+
+## See Also
+- [Cloud SQL Admin API documentation](https://cloud.google.com/sql/docs/mysql/admin-api)
+- [Toolbox Cloud SQL tools documentation](../cloudsql)
+- [Cloud SQL Restore API documentation](https://cloud.google.com/sql/docs/mysql/backup-recovery/restoring)
@@ -30,6 +30,10 @@ following config for example:
 - name: userNames
 type: array
 description: The user names to be set.
+items:
+name: userName # the item name doesn't matter but it has to exist
+type: string
+description: username
 ```

 If the input is an array of strings `["Alice", "Sid", "Bob"]`, The final command
@@ -46,6 +46,9 @@ tools:
 create_backup:
 kind: cloud-sql-create-backup
 source: cloud-sql-admin-source
+restore_backup:
+kind: cloud-sql-restore-backup
+source: cloud-sql-admin-source

 toolsets:
 cloud_sql_mssql_admin_tools:

@@ -58,3 +61,4 @@ toolsets:
 - wait_for_operation
 - clone_instance
 - create_backup
+- restore_backup

@@ -46,6 +46,9 @@ tools:
 create_backup:
 kind: cloud-sql-create-backup
 source: cloud-sql-admin-source
+restore_backup:
+kind: cloud-sql-restore-backup
+source: cloud-sql-admin-source

 toolsets:
 cloud_sql_mysql_admin_tools:

@@ -58,3 +61,4 @@ toolsets:
 - wait_for_operation
 - clone_instance
 - create_backup
+- restore_backup

@@ -49,6 +49,9 @@ tools:
 create_backup:
 kind: cloud-sql-create-backup
 source: cloud-sql-admin-source
+restore_backup:
+kind: cloud-sql-restore-backup
+source: cloud-sql-admin-source

 toolsets:
 cloud_sql_postgres_admin_tools:

@@ -62,3 +65,4 @@ toolsets:
 - postgres_upgrade_precheck
 - clone_instance
 - create_backup
+- restore_backup
@@ -64,12 +64,14 @@ type ServerConfig struct {
 Stdio bool
 // DisableReload indicates if the user has disabled dynamic reloading for Toolbox.
 DisableReload bool
-// UI indicates if Toolbox UI endpoints (/ui) are available
+// UI indicates if Toolbox UI endpoints (/ui) are available.
 UI bool
 // Specifies a list of origins permitted to access this server.
 AllowedOrigins []string
-// Specifies a list of hosts permitted to access this server
+// Specifies a list of hosts permitted to access this server.
 AllowedHosts []string
+// UserAgentMetadata specifies additional metadata to append to the User-Agent string.
+UserAgentMetadata []string
 }

 type logFormat string
@@ -183,6 +183,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
 }
 logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
+
+embeddingModels := resourceMgr.GetEmbeddingModelMap()
+params, err = tool.EmbedParams(ctx, params, embeddingModels)
+if err != nil {
+err = fmt.Errorf("error embedding parameters: %w", err)
+return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
+}

 // run tool invocation and generate response.
 results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
 if err != nil {

@@ -183,6 +183,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
 }
 logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
+
+embeddingModels := resourceMgr.GetEmbeddingModelMap()
+params, err = tool.EmbedParams(ctx, params, embeddingModels)
+if err != nil {
+err = fmt.Errorf("error embedding parameters: %w", err)
+return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
+}

 // run tool invocation and generate response.
 results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
 if err != nil {

@@ -176,6 +176,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
 }
 logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
+
+embeddingModels := resourceMgr.GetEmbeddingModelMap()
+params, err = tool.EmbedParams(ctx, params, embeddingModels)
+if err != nil {
+err = fmt.Errorf("error embedding parameters: %w", err)
+return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
+}

 // run tool invocation and generate response.
 results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
 if err != nil {

@@ -176,6 +176,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *re
 }
 logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
+
+embeddingModels := resourceMgr.GetEmbeddingModelMap()
+params, err = tool.EmbedParams(ctx, params, embeddingModels)
+if err != nil {
+err = fmt.Errorf("error embedding parameters: %w", err)
+return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
+}

 // run tool invocation and generate response.
 results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
 if err != nil {
@@ -64,7 +64,11 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
 map[string]prompts.Promptset,
 error,
 ) {
-ctx = util.WithUserAgent(ctx, cfg.Version)
+metadataStr := cfg.Version
+if len(cfg.UserAgentMetadata) > 0 {
+metadataStr += "+" + strings.Join(cfg.UserAgentMetadata, "+")
+}
+ctx = util.WithUserAgent(ctx, metadataStr)
 instrumentation, err := util.InstrumentationFromContext(ctx)
 if err != nil {
 panic(err)
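The change above appends each metadata entry to the version string with a `+` separator before it is stored as the User-Agent. A small sketch of the resulting string; the version number and the `buildUserAgent` helper are assumptions for illustration, not part of the Toolbox API:

```go
package main

import (
	"fmt"
	"strings"
)

// buildUserAgent mirrors the joining logic added to InitializeConfigs:
// start from the version string and append each metadata entry with "+".
func buildUserAgent(version string, metadata []string) string {
	s := version
	if len(metadata) > 0 {
		s += "+" + strings.Join(metadata, "+")
	}
	return s
}

func main() {
	fmt.Println(buildUserAgent("0.17.0", []string{"foo", "bar"})) // 0.17.0+foo+bar
	fmt.Println(buildUserAgent("0.17.0", nil))                    // 0.17.0
}
```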
@@ -19,6 +19,7 @@ import (
 "fmt"
 "net/http"
 "regexp"
+"strconv"
 "strings"
 "text/template"
 "time"

@@ -36,7 +37,10 @@ import (

 const SourceKind string = "cloud-sql-admin"

-var targetLinkRegex = regexp.MustCompile(`/projects/([^/]+)/instances/([^/]+)/databases/([^/]+)`)
+var (
+targetLinkRegex = regexp.MustCompile(`/projects/([^/]+)/instances/([^/]+)/databases/([^/]+)`)
+backupDRRegex = regexp.MustCompile(`^projects/([^/]+)/locations/([^/]+)/backupVaults/([^/]+)/dataSources/([^/]+)/backups/([^/]+)$`)
+)

 // validate interface
 var _ sources.SourceConfig = Config{}

@@ -374,6 +378,48 @@ func (s *Source) InsertBackupRun(ctx context.Context, project, instance, locatio
 return resp, nil
 }
+
+func (s *Source) RestoreBackup(ctx context.Context, targetProject, targetInstance, sourceProject, sourceInstance, backupID, accessToken string) (any, error) {
+request := &sqladmin.InstancesRestoreBackupRequest{}
+
+// There are 3 scenarios for the backup identifier:
+// 1. The identifier is an int64 containing the timestamp of the BackupRun.
+// This is used to restore standard backups, and the RestoreBackupContext
+// field should be populated with the backup ID and source instance info.
+// 2. The identifier is a string of the format
+// 'projects/{project-id}/locations/{location}/backupVaults/{backupvault}/dataSources/{datasource}/backups/{backup-uid}'.
+// This is used to restore BackupDR backups, and the BackupdrBackup field
+// should be populated.
+// 3. The identifer is a string of the format
+// 'projects/{project-id}/backups/{backup-uid}'. In this case, the Backup
+// field should be populated.
+if backupRunID, err := strconv.ParseInt(backupID, 10, 64); err == nil {
+if sourceProject == "" || targetInstance == "" {
+return nil, fmt.Errorf("source project and instance are required when restoring via backup ID")
+}
+request.RestoreBackupContext = &sqladmin.RestoreBackupContext{
+Project: sourceProject,
+InstanceId: sourceInstance,
+BackupRunId: backupRunID,
+}
+} else if backupDRRegex.MatchString(backupID) {
+request.BackupdrBackup = backupID
+} else {
+request.Backup = backupID
+}
+
+service, err := s.GetService(ctx, string(accessToken))
+if err != nil {
+return nil, err
+}
+
+resp, err := service.Instances.RestoreBackup(targetProject, targetInstance, request).Do()
+if err != nil {
+return nil, fmt.Errorf("error restoring backup: %w", err)
+}
+
+return resp, nil
+}

 func generateCloudSQLConnectionMessage(ctx context.Context, source *Source, logger log.Logger, opResponse map[string]any, connectionMessageTemplate string) (string, bool) {
 operationType, ok := opResponse["operationType"].(string)
 if !ok || operationType != "CREATE_DATABASE" {
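The comment block in the new `RestoreBackup` helper distinguishes three forms of backup identifier. A standalone sketch of that classification under the same `strconv.ParseInt`/regex checks; the `classify` helper and the sample IDs are illustrative only, not the actual request construction:

```go
package main

import (
	"fmt"
	"regexp"
	"strconv"
)

// backupDRRe matches BackupDR backup names, mirroring backupDRRegex in the source change.
var backupDRRe = regexp.MustCompile(`^projects/([^/]+)/locations/([^/]+)/backupVaults/([^/]+)/dataSources/([^/]+)/backups/([^/]+)$`)

// classify reports which request field the given backup_id would populate.
func classify(backupID string) string {
	if _, err := strconv.ParseInt(backupID, 10, 64); err == nil {
		return "RestoreBackupContext (BackupRun ID; source_project and source_instance required)"
	}
	if backupDRRe.MatchString(backupID) {
		return "BackupdrBackup (BackupDR backup name)"
	}
	return "Backup (backup name)"
}

func main() {
	fmt.Println(classify("1700000000000"))
	fmt.Println(classify("projects/p/locations/us-central1/backupVaults/v/dataSources/d/backups/b"))
	fmt.Println(classify("projects/p/backups/b"))
}
```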
@@ -103,10 +103,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 if !ok {
 return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", pageURLKey)
 }
-tokenStr, err := accessToken.ParseBearerToken()
-if err != nil {
-return nil, fmt.Errorf("error parsing access token: %w", err)
+var tokenStr string
+if source.UseClientAuthorization() {
+tokenStr, err = accessToken.ParseBearerToken()
+if err != nil {
+return nil, fmt.Errorf("error parsing access token: %w", err)
+}
 }
 return source.FHIRFetchPage(ctx, url, tokenStr)
 }
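The same conditional-parsing pattern recurs across the Cloud Healthcare tool hunks that follow: the bearer token is only parsed when the source requires client authorization, and an empty token is passed through otherwise. A self-contained sketch of that pattern, assuming simplified stand-ins for the real `tools.AccessToken` type and its `ParseBearerToken` method:

```go
package main

import (
	"fmt"
	"strings"
)

// accessToken and parseBearerToken are simplified stand-ins for the Toolbox types.
type accessToken string

func (a accessToken) parseBearerToken() (string, error) {
	const prefix = "Bearer "
	if !strings.HasPrefix(string(a), prefix) {
		return "", fmt.Errorf("token is not a bearer token")
	}
	return strings.TrimPrefix(string(a), prefix), nil
}

// tokenForSource mirrors the edited pattern: parse the token only when the
// source is configured for client authorization; otherwise return "" so the
// source falls back to its own credentials.
func tokenForSource(useClientAuth bool, tok accessToken) (string, error) {
	var tokenStr string
	if useClientAuth {
		var err error
		tokenStr, err = tok.parseBearerToken()
		if err != nil {
			return "", fmt.Errorf("error parsing access token: %w", err)
		}
	}
	return tokenStr, nil
}

func main() {
	fmt.Println(tokenForSource(true, "Bearer abc123")) // abc123 <nil>
	fmt.Println(tokenForSource(false, "Bearer abc123")) // "" <nil>
}
```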
@@ -131,9 +131,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", patientIDKey)
 }

-tokenStr, err := accessToken.ParseBearerToken()
-if err != nil {
-return nil, fmt.Errorf("error parsing access token: %w", err)
+var tokenStr string
+if source.UseClientAuthorization() {
+tokenStr, err = accessToken.ParseBearerToken()
+if err != nil {
+return nil, fmt.Errorf("error parsing access token: %w", err)
+}
 }

 var opts []googleapi.CallOption

@@ -161,9 +161,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 return nil, err
 }

-tokenStr, err := accessToken.ParseBearerToken()
-if err != nil {
-return nil, fmt.Errorf("error parsing access token: %w", err)
+var tokenStr string
+if source.UseClientAuthorization() {
+tokenStr, err = accessToken.ParseBearerToken()
+if err != nil {
+return nil, fmt.Errorf("error parsing access token: %w", err)
+}
 }

 var summary bool

@@ -95,9 +95,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 if err != nil {
 return nil, err
 }
-tokenStr, err := accessToken.ParseBearerToken()
-if err != nil {
-return nil, fmt.Errorf("error parsing access token: %w", err)
+var tokenStr string
+if source.UseClientAuthorization() {
+tokenStr, err = accessToken.ParseBearerToken()
+if err != nil {
+return nil, fmt.Errorf("error parsing access token: %w", err)
+}
 }
 return source.GetDataset(tokenStr)
 }

@@ -116,9 +116,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 if err != nil {
 return nil, err
 }
-tokenStr, err := accessToken.ParseBearerToken()
-if err != nil {
-return nil, fmt.Errorf("error parsing access token: %w", err)
+var tokenStr string
+if source.UseClientAuthorization() {
+tokenStr, err = accessToken.ParseBearerToken()
+if err != nil {
+return nil, fmt.Errorf("error parsing access token: %w", err)
+}
 }
 return source.GetDICOMStore(storeID, tokenStr)
 }

@@ -116,9 +116,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 if err != nil {
 return nil, err
 }
-tokenStr, err := accessToken.ParseBearerToken()
-if err != nil {
-return nil, fmt.Errorf("error parsing access token: %w", err)
+var tokenStr string
+if source.UseClientAuthorization() {
+tokenStr, err = accessToken.ParseBearerToken()
+if err != nil {
+return nil, fmt.Errorf("error parsing access token: %w", err)
+}
 }
 return source.GetDICOMStoreMetrics(storeID, tokenStr)
 }

@@ -130,9 +130,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 if !ok {
 return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", idKey)
 }
-tokenStr, err := accessToken.ParseBearerToken()
-if err != nil {
-return nil, fmt.Errorf("error parsing access token: %w", err)
+var tokenStr string
+if source.UseClientAuthorization() {
+tokenStr, err = accessToken.ParseBearerToken()
+if err != nil {
+return nil, fmt.Errorf("error parsing access token: %w", err)
+}
 }
 return source.GetFHIRResource(storeID, resType, resID, tokenStr)
 }

@@ -116,9 +116,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 if err != nil {
 return nil, err
 }
-tokenStr, err := accessToken.ParseBearerToken()
-if err != nil {
-return nil, fmt.Errorf("error parsing access token: %w", err)
+var tokenStr string
+if source.UseClientAuthorization() {
+tokenStr, err = accessToken.ParseBearerToken()
+if err != nil {
+return nil, fmt.Errorf("error parsing access token: %w", err)
+}
 }
 return source.GetFHIRStore(storeID, tokenStr)
 }

@@ -116,9 +116,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 if err != nil {
 return nil, err
 }
-tokenStr, err := accessToken.ParseBearerToken()
-if err != nil {
-return nil, fmt.Errorf("error parsing access token: %w", err)
+var tokenStr string
+if source.UseClientAuthorization() {
+tokenStr, err = accessToken.ParseBearerToken()
+if err != nil {
+return nil, fmt.Errorf("error parsing access token: %w", err)
+}
 }
 return source.GetFHIRStoreMetrics(storeID, tokenStr)
 }

@@ -95,9 +95,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 if err != nil {
 return nil, err
 }
-tokenStr, err := accessToken.ParseBearerToken()
-if err != nil {
-return nil, fmt.Errorf("error parsing access token: %w", err)
+var tokenStr string
+if source.UseClientAuthorization() {
+tokenStr, err = accessToken.ParseBearerToken()
+if err != nil {
+return nil, fmt.Errorf("error parsing access token: %w", err)
+}
 }
 return source.ListDICOMStores(tokenStr)
 }

@@ -95,9 +95,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 if err != nil {
 return nil, err
 }
-tokenStr, err := accessToken.ParseBearerToken()
-if err != nil {
-return nil, fmt.Errorf("error parsing access token: %w", err)
+var tokenStr string
+if source.UseClientAuthorization() {
+tokenStr, err = accessToken.ParseBearerToken()
+if err != nil {
+return nil, fmt.Errorf("error parsing access token: %w", err)
+}
 }
 return source.ListFHIRStores(tokenStr)
 }

@@ -127,9 +127,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 if err != nil {
 return nil, err
 }
-tokenStr, err := accessToken.ParseBearerToken()
-if err != nil {
-return nil, fmt.Errorf("error parsing access token: %w", err)
+var tokenStr string
+if source.UseClientAuthorization() {
+tokenStr, err = accessToken.ParseBearerToken()
+if err != nil {
+return nil, fmt.Errorf("error parsing access token: %w", err)
+}
 }
 study, ok := params.AsMap()[studyInstanceUIDKey].(string)
 if !ok {

@@ -140,9 +140,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 if err != nil {
 return nil, err
 }
-tokenStr, err := accessToken.ParseBearerToken()
-if err != nil {
-return nil, fmt.Errorf("error parsing access token: %w", err)
+var tokenStr string
+if source.UseClientAuthorization() {
+tokenStr, err = accessToken.ParseBearerToken()
+if err != nil {
+return nil, fmt.Errorf("error parsing access token: %w", err)
+}
 }

 opts, err := common.ParseDICOMSearchParameters(params, []string{sopInstanceUIDKey, patientNameKey, patientIDKey, accessionNumberKey, referringPhysicianNameKey, studyDateKey, modalityKey})

@@ -138,9 +138,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 if err != nil {
 return nil, err
 }
-tokenStr, err := accessToken.ParseBearerToken()
-if err != nil {
-return nil, fmt.Errorf("error parsing access token: %w", err)
+var tokenStr string
+if source.UseClientAuthorization() {
+tokenStr, err = accessToken.ParseBearerToken()
+if err != nil {
+return nil, fmt.Errorf("error parsing access token: %w", err)
+}
 }

 opts, err := common.ParseDICOMSearchParameters(params, []string{seriesInstanceUIDKey, patientNameKey, patientIDKey, accessionNumberKey, referringPhysicianNameKey, studyDateKey, modalityKey})

@@ -133,9 +133,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 if err != nil {
 return nil, err
 }
-tokenStr, err := accessToken.ParseBearerToken()
-if err != nil {
-return nil, fmt.Errorf("error parsing access token: %w", err)
+var tokenStr string
+if source.UseClientAuthorization() {
+tokenStr, err = accessToken.ParseBearerToken()
+if err != nil {
+return nil, fmt.Errorf("error parsing access token: %w", err)
+}
 }
 opts, err := common.ParseDICOMSearchParameters(params, []string{studyInstanceUIDKey, patientNameKey, patientIDKey, accessionNumberKey, referringPhysicianNameKey, studyDateKey})
 if err != nil {
@@ -0,0 +1,183 @@
|
|||||||
|
// Copyright 2026 Google LLC
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package cloudsqlrestorebackup
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/goccy/go-yaml"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
|
"google.golang.org/api/sqladmin/v1"
|
||||||
|
)
|
||||||
|
|
||||||
|
const kind string = "cloud-sql-restore-backup"
|
||||||
|
|
||||||
|
var _ tools.ToolConfig = Config{}
|
||||||
|
|
||||||
|
type compatibleSource interface {
|
||||||
|
GetDefaultProject() string
|
||||||
|
GetService(context.Context, string) (*sqladmin.Service, error)
|
||||||
|
UseClientAuthorization() bool
|
||||||
|
RestoreBackup(ctx context.Context, targetProject, targetInstance, sourceProject, sourceInstance, backupID, accessToken string) (any, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Config defines the configuration for the restore-backup tool.
|
||||||
|
type Config struct {
|
||||||
|
Name string `yaml:"name" validate:"required"`
|
||||||
|
Kind string `yaml:"kind" validate:"required"`
|
||||||
|
Description string `yaml:"description"`
|
||||||
|
Source string `yaml:"source" validate:"required"`
|
||||||
|
AuthRequired []string `yaml:"authRequired"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
if !tools.Register(kind, newConfig) {
|
||||||
|
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||||
|
actual := Config{Name: name}
|
||||||
|
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return actual, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToolConfigKind returns the kind of the tool.
|
||||||
|
func (cfg Config) ToolConfigKind() string {
|
||||||
|
return kind
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize initializes the tool from the configuration.
|
||||||
|
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||||
|
rawS, ok := srcs[cfg.Source]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||||
|
}
|
||||||
|
s, ok := rawS.(compatibleSource)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||||
|
}
|
||||||
|
|
||||||
|
project := s.GetDefaultProject()
|
||||||
|
var targetProjectParam parameters.Parameter
|
||||||
|
if project != "" {
|
||||||
|
targetProjectParam = parameters.NewStringParameterWithDefault("target_project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||||
|
} else {
|
||||||
|
targetProjectParam = parameters.NewStringParameter("target_project", "The project ID")
|
||||||
|
}
|
||||||
|
|
||||||
|
allParameters := parameters.Parameters{
|
||||||
|
targetProjectParam,
|
||||||
|
parameters.NewStringParameter("target_instance", "Cloud SQL instance ID of the target instance. This does not include the project ID."),
|
||||||
|
parameters.NewStringParameter("backup_id", "Identifier of the backup being restored. Can be a BackupRun ID, backup name, or BackupDR backup name. Use the full backup ID as provided, do not try to parse it"),
|
||||||
|
parameters.NewStringParameterWithRequired("source_project", "GCP project ID of the instance that the backup belongs to. Only required if the backup_id is a BackupRun ID.", false),
|
||||||
|
parameters.NewStringParameterWithRequired("source_instance", "Cloud SQL instance ID of the instance that the backup belongs to. Only required if the backup_id is a BackupRun ID.", false),
|
||||||
|
}
|
||||||
|
paramManifest := allParameters.Manifest()
|
||||||
|
|
||||||
|
description := cfg.Description
|
||||||
|
	if description == "" {
		description = "Restores a backup on a Cloud SQL instance."
	}

	mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, allParameters, nil)

	return Tool{
		Config:      cfg,
		AllParams:   allParameters,
		manifest:    tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
		mcpManifest: mcpManifest,
	}, nil
}

// Tool represents the restore-backup tool.
type Tool struct {
	Config
	AllParams parameters.Parameters `yaml:"allParams"`

	manifest    tools.Manifest
	mcpManifest tools.McpManifest
}

func (t Tool) ToConfig() tools.ToolConfig {
	return t.Config
}

func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
	source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
	if err != nil {
		return nil, err
	}
	paramsMap := params.AsMap()

	targetProject, ok := paramsMap["target_project"].(string)
	if !ok {
		return nil, fmt.Errorf("error casting 'target_project' parameter: %v", paramsMap["target_project"])
	}
	targetInstance, ok := paramsMap["target_instance"].(string)
	if !ok {
		return nil, fmt.Errorf("error casting 'target_instance' parameter: %v", paramsMap["target_instance"])
	}
	backupID, ok := paramsMap["backup_id"].(string)
	if !ok {
		return nil, fmt.Errorf("error casting 'backup_id' parameter: %v", paramsMap["backup_id"])
	}
	sourceProject, _ := paramsMap["source_project"].(string)
	sourceInstance, _ := paramsMap["source_instance"].(string)

	return source.RestoreBackup(ctx, targetProject, targetInstance, sourceProject, sourceInstance, backupID, string(accessToken))
}

// ParseParams parses the parameters for the tool.
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
	return parameters.ParseParams(t.AllParams, data, claims)
}

func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
	return parameters.EmbedParams(ctx, t.AllParams, paramValues, embeddingModelsMap, nil)
}

// Manifest returns the tool's manifest.
func (t Tool) Manifest() tools.Manifest {
	return t.manifest
}

// McpManifest returns the tool's MCP manifest.
func (t Tool) McpManifest() tools.McpManifest {
	return t.mcpManifest
}

// Authorized checks if the tool is authorized.
func (t Tool) Authorized(verifiedAuthServices []string) bool {
	return true
}

func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
	source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
	if err != nil {
		return false, err
	}

	return source.UseClientAuthorization(), nil
}

func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
	return "Authorization", nil
}
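For orientation only (this block is not part of the change): a minimal sketch of the three request shapes the restore-backup tool is expected to send to the Cloud SQL Admin API, inferred from the mock handler (masterRestoreBackupHandler) in the integration test later in this compare view. The literal values ("p1", "source", 12345, the backup resource names) are test fixtures, not real resources.

package main

import (
	"encoding/json"
	"fmt"

	sqladmin "google.golang.org/api/sqladmin/v1"
)

func main() {
	shapes := map[string]sqladmin.InstancesRestoreBackupRequest{
		// backup_id is a numeric backup run ID; source_project and
		// source_instance are required to locate the run.
		"backup run ID": {
			RestoreBackupContext: &sqladmin.RestoreBackupContext{
				Project:     "p1",
				InstanceId:  "source",
				BackupRunId: 12345,
			},
		},
		// backup_id is a project-level backup resource name.
		"project-level backup": {
			Backup: "projects/p1/backups/test-uid",
		},
		// backup_id is a Backup and DR (BackupDR) backup resource name.
		"BackupDR backup": {
			BackupdrBackup: "projects/p1/locations/us-central1/backupVaults/test-vault/dataSources/test-ds/backups/test-uid",
		},
	}
	for name, req := range shapes {
		b, _ := json.Marshal(req)
		fmt.Printf("%s: %s\n", name, b)
	}
}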
@@ -0,0 +1,71 @@
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package cloudsqlrestorebackup_test

import (
	"testing"

	yaml "github.com/goccy/go-yaml"
	"github.com/google/go-cmp/cmp"
	"github.com/googleapis/genai-toolbox/internal/server"
	"github.com/googleapis/genai-toolbox/internal/testutils"
	"github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlrestorebackup"
)

func TestParseFromYaml(t *testing.T) {
	ctx, err := testutils.ContextWithNewLogger()
	if err != nil {
		t.Fatalf("unexpected error: %s", err)
	}
	tcs := []struct {
		desc string
		in   string
		want server.ToolConfigs
	}{
		{
			desc: "basic example",
			in: `
			tools:
				restore-backup-tool:
					kind: cloud-sql-restore-backup
					description: a test description
					source: a-source
			`,
			want: server.ToolConfigs{
				"restore-backup-tool": cloudsqlrestorebackup.Config{
					Name:         "restore-backup-tool",
					Kind:         "cloud-sql-restore-backup",
					Description:  "a test description",
					Source:       "a-source",
					AuthRequired: []string{},
				},
			},
		},
	}
	for _, tc := range tcs {
		t.Run(tc.desc, func(t *testing.T) {
			got := struct {
				Tools server.ToolConfigs `yaml:"tools"`
			}{}
			err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
			if err != nil {
				t.Fatalf("unable to unmarshal: %s", err)
			}
			if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
				t.Fatalf("incorrect parse: diff %v", diff)
			}
		})
	}
}
@@ -147,12 +147,20 @@ func TestAlloyDBPgToolEndpoints(t *testing.T) {
 	teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
 	defer teardownTable2(t)
 
+	// Set up table for semantic search
+	vectorTableName, tearDownVectorTable := tests.SetupPostgresVectorTable(t, ctx, pool)
+	defer tearDownVectorTable(t)
+
 	// Write config into a file and pass it to command
 	toolsFile := tests.GetToolsConfig(sourceConfig, AlloyDBPostgresToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
 	toolsFile = tests.AddExecuteSqlConfig(t, toolsFile, "postgres-execute-sql")
 	tmplSelectCombined, tmplSelectFilterCombined := tests.GetPostgresSQLTmplToolStatement()
 	toolsFile = tests.AddTemplateParamConfig(t, toolsFile, AlloyDBPostgresToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
+
+	// Add semantic search tool config
+	insertStmt, searchStmt := tests.GetPostgresVectorSearchStmts(vectorTableName)
+	toolsFile = tests.AddSemanticSearchConfig(t, toolsFile, AlloyDBPostgresToolKind, insertStmt, searchStmt)
+
 	toolsFile = tests.AddPostgresPrebuiltConfig(t, toolsFile)
 
 	cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
@@ -112,8 +112,7 @@ func TestHealthcareToolEndpoints(t *testing.T) {
 	fhirStoreID := "fhir-store-" + uuid.New().String()
 	dicomStoreID := "dicom-store-" + uuid.New().String()
 
-	patient1ID, patient2ID, teardown := setupHealthcareResources(t, healthcareService, healthcareDataset, fhirStoreID, dicomStoreID)
-	defer teardown(t)
+	patient1ID, patient2ID := setupHealthcareResources(t, healthcareService, healthcareDataset, fhirStoreID, dicomStoreID)
 
 	toolsFile := getToolsConfig(sourceConfig)
 	toolsFile = addClientAuthSourceConfig(t, toolsFile)
@@ -173,10 +172,8 @@ func TestHealthcareToolWithStoreRestriction(t *testing.T) {
 	disallowedFHIRStoreID := "fhir-store-disallowed-" + uuid.New().String()
 	disallowedDICOMStoreID := "dicom-store-disallowed-" + uuid.New().String()
 
-	_, _, teardownAllowedStores := setupHealthcareResources(t, healthcareService, healthcareDataset, allowedFHIRStoreID, allowedDICOMStoreID)
-	defer teardownAllowedStores(t)
-	_, _, teardownDisallowedStores := setupHealthcareResources(t, healthcareService, healthcareDataset, disallowedFHIRStoreID, disallowedDICOMStoreID)
-	defer teardownDisallowedStores(t)
+	setupHealthcareResources(t, healthcareService, healthcareDataset, allowedFHIRStoreID, allowedDICOMStoreID)
+	setupHealthcareResources(t, healthcareService, healthcareDataset, disallowedFHIRStoreID, disallowedDICOMStoreID)
 
 	// Configure source with dataset restriction.
 	sourceConfig["allowedFhirStores"] = []string{allowedFHIRStoreID}
@@ -257,7 +254,7 @@ func newHealthcareService(ctx context.Context) (*healthcare.Service, error) {
 	return healthcareService, nil
 }
 
-func setupHealthcareResources(t *testing.T, service *healthcare.Service, datasetID, fhirStoreID, dicomStoreID string) (string, string, func(*testing.T)) {
+func setupHealthcareResources(t *testing.T, service *healthcare.Service, datasetID, fhirStoreID, dicomStoreID string) (string, string) {
 	datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", healthcareProject, healthcareRegion, datasetID)
 	var err error
 
@@ -266,12 +263,24 @@ func setupHealthcareResources(t *testing.T, service *healthcare.Service, dataset
 	if fhirStore, err = service.Projects.Locations.Datasets.FhirStores.Create(datasetName, fhirStore).FhirStoreId(fhirStoreID).Do(); err != nil {
 		t.Fatalf("failed to create fhir store: %v", err)
 	}
+	// Register cleanup
+	t.Cleanup(func() {
+		if _, err := service.Projects.Locations.Datasets.FhirStores.Delete(fhirStore.Name).Do(); err != nil {
+			t.Logf("failed to delete fhir store: %v", err)
+		}
+	})
 
 	// Create DICOM store
 	dicomStore := &healthcare.DicomStore{}
 	if dicomStore, err = service.Projects.Locations.Datasets.DicomStores.Create(datasetName, dicomStore).DicomStoreId(dicomStoreID).Do(); err != nil {
 		t.Fatalf("failed to create dicom store: %v", err)
 	}
+	// Register cleanup
+	t.Cleanup(func() {
+		if _, err := service.Projects.Locations.Datasets.DicomStores.Delete(dicomStore.Name).Do(); err != nil {
+			t.Logf("failed to delete dicom store: %v", err)
+		}
+	})
 
 	// Create Patient 1
 	patient1Body := bytes.NewBuffer([]byte(`{
@@ -317,15 +326,7 @@ func setupHealthcareResources(t *testing.T, service *healthcare.Service, dataset
 		createFHIRResource(t, service, fhirStore.Name, "Observation", observation2Body)
 	}
 
-	teardown := func(t *testing.T) {
-		if _, err := service.Projects.Locations.Datasets.FhirStores.Delete(fhirStore.Name).Do(); err != nil {
-			t.Logf("failed to delete fhir store: %v", err)
-		}
-		if _, err := service.Projects.Locations.Datasets.DicomStores.Delete(dicomStore.Name).Do(); err != nil {
-			t.Logf("failed to delete dicom store: %v", err)
-		}
-	}
-	return patient1ID, patient2ID, teardown
+	return patient1ID, patient2ID
 }
 
 func getToolsConfig(sourceConfig map[string]any) map[string]any {
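The healthcare hunks above swap the returned teardown closure for cleanup registered inside the helper via t.Cleanup. A minimal sketch of that pattern, with hypothetical createResource/deleteResource helpers standing in for the real Healthcare API calls:

package example

import "testing"

// createResource and deleteResource are hypothetical stand-ins that only exist
// to make the sketch compile.
func createResource() string           { return "resource-name" }
func deleteResource(name string) error { return nil }

// setupResource registers its own cleanup, so callers no longer need to
// remember a `defer teardown(t)`.
func setupResource(t *testing.T) string {
	t.Helper()
	name := createResource()
	t.Cleanup(func() {
		if err := deleteResource(name); err != nil {
			t.Logf("failed to delete %s: %v", name, err)
		}
	})
	return name
}

With t.Cleanup the teardown runs even if a later setup step fails the test, which is the motivation for dropping the returned closures here.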
tests/cloudsql/cloud_sql_restore_backup_test.go (new file, 267 lines)
@@ -0,0 +1,267 @@
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package cloudsql

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"net/url"
	"reflect"
	"regexp"
	"strings"
	"testing"
	"time"

	"github.com/google/go-cmp/cmp"
	"github.com/googleapis/genai-toolbox/internal/testutils"
	"github.com/googleapis/genai-toolbox/tests"
	"google.golang.org/api/sqladmin/v1"
)

var (
	restoreBackupToolKind = "cloud-sql-restore-backup"
)

type restoreBackupTransport struct {
	transport http.RoundTripper
	url       *url.URL
}

func (t *restoreBackupTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	if strings.HasPrefix(req.URL.String(), "https://sqladmin.googleapis.com") {
		req.URL.Scheme = t.url.Scheme
		req.URL.Host = t.url.Host
	}
	return t.transport.RoundTrip(req)
}

type masterRestoreBackupHandler struct {
	t *testing.T
}

func (h *masterRestoreBackupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	if !strings.Contains(r.UserAgent(), "genai-toolbox/") {
		h.t.Errorf("User-Agent header not found")
	}
	var body sqladmin.InstancesRestoreBackupRequest
	if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
		h.t.Fatalf("failed to decode request body: %v", err)
	} else {
		h.t.Logf("Received request body: %+v", body)
	}

	var expectedBody sqladmin.InstancesRestoreBackupRequest
	var response any
	var statusCode int

	switch {
	case body.Backup != "":
		expectedBody = sqladmin.InstancesRestoreBackupRequest{
			Backup: "projects/p1/backups/test-uid",
		}
		response = map[string]any{"name": "op1", "status": "PENDING"}
		statusCode = http.StatusOK
	case body.BackupdrBackup != "":
		expectedBody = sqladmin.InstancesRestoreBackupRequest{
			BackupdrBackup: "projects/p1/locations/us-central1/backupVaults/test-vault/dataSources/test-ds/backups/test-uid",
		}
		response = map[string]any{"name": "op1", "status": "PENDING"}
		statusCode = http.StatusOK
	case body.RestoreBackupContext != nil:
		expectedBody = sqladmin.InstancesRestoreBackupRequest{
			RestoreBackupContext: &sqladmin.RestoreBackupContext{
				Project:     "p1",
				InstanceId:  "source",
				BackupRunId: 12345,
			},
		}
		response = map[string]any{"name": "op1", "status": "PENDING"}
		statusCode = http.StatusOK
	default:
		http.Error(w, fmt.Sprintf("unhandled restore request body: %v", body), http.StatusInternalServerError)
		return
	}

	if diff := cmp.Diff(expectedBody, body); diff != "" {
		h.t.Errorf("unexpected request body (-want +got):\n%s", diff)
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(statusCode)
	if err := json.NewEncoder(w).Encode(response); err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
	}
}

func TestRestoreBackupToolEndpoints(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
	defer cancel()

	handler := &masterRestoreBackupHandler{t: t}
	server := httptest.NewServer(handler)
	defer server.Close()

	serverURL, err := url.Parse(server.URL)
	if err != nil {
		t.Fatalf("failed to parse server URL: %v", err)
	}

	originalTransport := http.DefaultClient.Transport
	if originalTransport == nil {
		originalTransport = http.DefaultTransport
	}
	http.DefaultClient.Transport = &restoreBackupTransport{
		transport: originalTransport,
		url:       serverURL,
	}
	t.Cleanup(func() {
		http.DefaultClient.Transport = originalTransport
	})

	var args []string
	toolsFile := getRestoreBackupToolsConfig()
	cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
	if err != nil {
		t.Fatalf("command initialization returned an error: %s", err)
	}
	defer cleanup()

	waitCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
	defer cancel()
	out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
	if err != nil {
		t.Logf("toolbox command logs: \n%s", out)
		t.Fatalf("toolbox didn't start successfully: %s", err)
	}

	tcs := []struct {
		name        string
		toolName    string
		body        string
		want        string
		expectError bool
		errorStatus int
	}{
		{
			name:     "successful restore with standard backup",
			toolName: "restore-backup",
			body:     `{"target_project": "p1", "target_instance": "instance-standard", "backup_id": "12345", "source_project": "p1", "source_instance": "source"}`,
			want:     `{"name":"op1","status":"PENDING"}`,
		},
		{
			name:     "successful restore with project level backup",
			toolName: "restore-backup",
			body:     `{"target_project": "p1", "target_instance": "instance-project-level", "backup_id": "projects/p1/backups/test-uid"}`,
			want:     `{"name":"op1","status":"PENDING"}`,
		},
		{
			name:     "successful restore with BackupDR backup",
			toolName: "restore-backup",
			body:     `{"target_project": "p1", "target_instance": "instance-project-level", "backup_id": "projects/p1/locations/us-central1/backupVaults/test-vault/dataSources/test-ds/backups/test-uid"}`,
			want:     `{"name":"op1","status":"PENDING"}`,
		},
		{
			name:        "missing source instance info for standard backup",
			toolName:    "restore-backup",
			body:        `{"target_project": "p1", "target_instance": "instance-project-level", "backup_id": "12345"}`,
			expectError: true,
			errorStatus: http.StatusBadRequest,
		},
		{
			name:        "missing backup identifier",
			toolName:    "restore-backup",
			body:        `{"target_project": "p1", "target_instance": "instance-project-level"}`,
			expectError: true,
			errorStatus: http.StatusBadRequest,
		},
		{
			name:        "missing target instance info",
			toolName:    "restore-backup",
			body:        `{"backup_id": "12345"}`,
			expectError: true,
			errorStatus: http.StatusBadRequest,
		},
	}

	for _, tc := range tcs {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			api := fmt.Sprintf("http://127.0.0.1:5000/api/tool/%s/invoke", tc.toolName)
			req, err := http.NewRequest(http.MethodPost, api, bytes.NewBufferString(tc.body))
			if err != nil {
				t.Fatalf("unable to create request: %s", err)
			}
			req.Header.Add("Content-type", "application/json")
			resp, err := http.DefaultClient.Do(req)
			if err != nil {
				t.Fatalf("unable to send request: %s", err)
			}
			defer resp.Body.Close()

			if tc.expectError {
				if resp.StatusCode != tc.errorStatus {
					bodyBytes, _ := io.ReadAll(resp.Body)
					t.Fatalf("expected status %d but got %d: %s", tc.errorStatus, resp.StatusCode, string(bodyBytes))
				}
				return
			}

			if resp.StatusCode != http.StatusOK {
				bodyBytes, _ := io.ReadAll(resp.Body)
				t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
			}

			var result struct {
				Result string `json:"result"`
			}
			if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
				t.Fatalf("failed to decode response: %v", err)
			}

			var got, want map[string]any
			if err := json.Unmarshal([]byte(result.Result), &got); err != nil {
				t.Fatalf("failed to unmarshal result: %v", err)
			}
			if err := json.Unmarshal([]byte(tc.want), &want); err != nil {
				t.Fatalf("failed to unmarshal want: %v", err)
			}

			if !reflect.DeepEqual(got, want) {
				t.Fatalf("unexpected result: got %+v, want %+v", got, want)
			}
		})
	}
}

func getRestoreBackupToolsConfig() map[string]any {
	return map[string]any{
		"sources": map[string]any{
			"my-cloud-sql-source": map[string]any{
				"kind": "cloud-sql-admin",
			},
		},
		"tools": map[string]any{
			"restore-backup": map[string]any{
				"kind":   restoreBackupToolKind,
				"source": "my-cloud-sql-source",
			},
		},
	}
}
@@ -132,12 +132,20 @@ func TestCloudSQLPgSimpleToolEndpoints(t *testing.T) {
 	teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
 	defer teardownTable2(t)
 
+	// Set up table for semantic search
+	vectorTableName, tearDownVectorTable := tests.SetupPostgresVectorTable(t, ctx, pool)
+	defer tearDownVectorTable(t)
+
 	// Write config into a file and pass it to command
 	toolsFile := tests.GetToolsConfig(sourceConfig, CloudSQLPostgresToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
 	toolsFile = tests.AddExecuteSqlConfig(t, toolsFile, "postgres-execute-sql")
 	tmplSelectCombined, tmplSelectFilterCombined := tests.GetPostgresSQLTmplToolStatement()
 	toolsFile = tests.AddTemplateParamConfig(t, toolsFile, CloudSQLPostgresToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
+
+	// Add semantic search tool config
+	insertStmt, searchStmt := tests.GetPostgresVectorSearchStmts(vectorTableName)
+	toolsFile = tests.AddSemanticSearchConfig(t, toolsFile, CloudSQLPostgresToolKind, insertStmt, searchStmt)
+
 	toolsFile = tests.AddPostgresPrebuiltConfig(t, toolsFile)
 	cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
 	if err != nil {
@@ -186,6 +194,7 @@ func TestCloudSQLPgSimpleToolEndpoints(t *testing.T) {
 	tests.RunPostgresListDatabaseStatsTest(t, ctx, pool)
 	tests.RunPostgresListRolesTest(t, ctx, pool)
 	tests.RunPostgresListStoredProcedureTest(t, ctx, pool)
+	tests.RunSemanticSearchToolInvokeTest(t, "null", "", "The quick brown fox")
 }
 
 // Test connection with different IP type
tests/embedding.go (new file, 251 lines)
@@ -0,0 +1,251 @@
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package tests contains end to end tests meant to verify the Toolbox Server
// works as expected when executed as a binary.

package tests

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"os"
	"strings"
	"testing"

	"github.com/google/uuid"
	"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
	"github.com/jackc/pgx/v5/pgxpool"
)

var apiKey = os.Getenv("API_KEY")

// AddSemanticSearchConfig adds embedding models and semantic search tools to the config
// with configurable tool kind and SQL statements.
func AddSemanticSearchConfig(t *testing.T, config map[string]any, toolKind, insertStmt, searchStmt string) map[string]any {
	config["embeddingModels"] = map[string]any{
		"gemini_model": map[string]any{
			"kind":      "gemini",
			"model":     "gemini-embedding-001",
			"apiKey":    apiKey,
			"dimension": 768,
		},
	}

	tools, ok := config["tools"].(map[string]any)
	if !ok {
		t.Fatalf("unable to get tools from config")
	}

	tools["insert_docs"] = map[string]any{
		"kind":        toolKind,
		"source":      "my-instance",
		"description": "Stores content and its vector embedding into the documents table.",
		"statement":   insertStmt,
		"parameters": []any{
			map[string]any{
				"name":        "content",
				"type":        "string",
				"description": "The text content associated with the vector.",
			},
			map[string]any{
				"name":        "text_to_embed",
				"type":        "string",
				"description": "The text content used to generate the vector.",
				"embeddedBy":  "gemini_model",
			},
		},
	}

	tools["search_docs"] = map[string]any{
		"kind":        toolKind,
		"source":      "my-instance",
		"description": "Finds the most semantically similar document to the query vector.",
		"statement":   searchStmt,
		"parameters": []any{
			map[string]any{
				"name":        "query",
				"type":        "string",
				"description": "The text content to search for.",
				"embeddedBy":  "gemini_model",
			},
		},
	}

	config["tools"] = tools
	return config
}

// RunSemanticSearchToolInvokeTest runs the insert_docs and search_docs tools
// via both HTTP and MCP endpoints and verifies the output.
func RunSemanticSearchToolInvokeTest(t *testing.T, insertWant, mcpInsertWant, searchWant string) {
	// Initialize MCP session once for the MCP test cases
	sessionId := RunInitialize(t, "2024-11-05")

	tcs := []struct {
		name        string
		api         string
		isMcp       bool
		requestBody interface{}
		want        string
	}{
		{
			name:        "HTTP invoke insert_docs",
			api:         "http://127.0.0.1:5000/api/tool/insert_docs/invoke",
			isMcp:       false,
			requestBody: `{"content": "The quick brown fox jumps over the lazy dog", "text_to_embed": "The quick brown fox jumps over the lazy dog"}`,
			want:        insertWant,
		},
		{
			name:        "HTTP invoke search_docs",
			api:         "http://127.0.0.1:5000/api/tool/search_docs/invoke",
			isMcp:       false,
			requestBody: `{"query": "fast fox jumping"}`,
			want:        searchWant,
		},
		{
			name:  "MCP invoke insert_docs",
			api:   "http://127.0.0.1:5000/mcp",
			isMcp: true,
			requestBody: jsonrpc.JSONRPCRequest{
				Jsonrpc: "2.0",
				Id:      "mcp-insert-docs",
				Request: jsonrpc.Request{
					Method: "tools/call",
				},
				Params: map[string]any{
					"name": "insert_docs",
					"arguments": map[string]any{
						"content":       "The quick brown fox jumps over the lazy dog",
						"text_to_embed": "The quick brown fox jumps over the lazy dog",
					},
				},
			},
			want: mcpInsertWant,
		},
		{
			name:  "MCP invoke search_docs",
			api:   "http://127.0.0.1:5000/mcp",
			isMcp: true,
			requestBody: jsonrpc.JSONRPCRequest{
				Jsonrpc: "2.0",
				Id:      "mcp-search-docs",
				Request: jsonrpc.Request{
					Method: "tools/call",
				},
				Params: map[string]any{
					"name": "search_docs",
					"arguments": map[string]any{
						"query": "fast fox jumping",
					},
				},
			},
			want: searchWant,
		},
	}

	for _, tc := range tcs {
		t.Run(tc.name, func(t *testing.T) {
			var bodyReader io.Reader
			headers := map[string]string{}

			// Prepare Request Body and Headers
			if tc.isMcp {
				reqBytes, err := json.Marshal(tc.requestBody)
				if err != nil {
					t.Fatalf("failed to marshal mcp request: %v", err)
				}
				bodyReader = bytes.NewBuffer(reqBytes)
				if sessionId != "" {
					headers["Mcp-Session-Id"] = sessionId
				}
			} else {
				bodyReader = bytes.NewBufferString(tc.requestBody.(string))
			}

			// Send Request
			resp, respBody := RunRequest(t, http.MethodPost, tc.api, bodyReader, headers)

			if resp.StatusCode != http.StatusOK {
				t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody))
			}

			// Normalize Response to get the actual tool result string
			var got string
			if tc.isMcp {
				var mcpResp struct {
					Result struct {
						Content []struct {
							Text string `json:"text"`
						} `json:"content"`
					} `json:"result"`
				}
				if err := json.Unmarshal(respBody, &mcpResp); err != nil {
					t.Fatalf("error parsing mcp response: %s", err)
				}
				if len(mcpResp.Result.Content) > 0 {
					got = mcpResp.Result.Content[0].Text
				}
			} else {
				var httpResp map[string]interface{}
				if err := json.Unmarshal(respBody, &httpResp); err != nil {
					t.Fatalf("error parsing http response: %s", err)
				}
				if res, ok := httpResp["result"].(string); ok {
					got = res
				}
			}

			if !strings.Contains(got, tc.want) {
				t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
			}
		})
	}
}

// SetupPostgresVectorTable sets up the vector extension and a vector table
func SetupPostgresVectorTable(t *testing.T, ctx context.Context, pool *pgxpool.Pool) (string, func(*testing.T)) {
	t.Helper()
	if _, err := pool.Exec(ctx, "CREATE EXTENSION IF NOT EXISTS vector"); err != nil {
		t.Fatalf("failed to create vector extension: %v", err)
	}

	tableName := "vector_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")

	createTableStmt := fmt.Sprintf(`CREATE TABLE %s (
		id SERIAL PRIMARY KEY,
		content TEXT,
		embedding vector(768)
	)`, tableName)

	if _, err := pool.Exec(ctx, createTableStmt); err != nil {
		t.Fatalf("failed to create table %s: %v", tableName, err)
	}

	return tableName, func(t *testing.T) {
		if _, err := pool.Exec(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", tableName)); err != nil {
			t.Errorf("failed to drop table %s: %v", tableName, err)
		}
	}
}

func GetPostgresVectorSearchStmts(vectorTableName string) (string, string) {
	insertStmt := fmt.Sprintf("INSERT INTO %s (content, embedding) VALUES ($1, $2)", vectorTableName)
	searchStmt := fmt.Sprintf("SELECT id, content, embedding <-> $1 AS distance FROM %s ORDER BY distance LIMIT 1", vectorTableName)
	return insertStmt, searchStmt
}
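To make the new helpers concrete: a small standalone sketch (assuming a Toolbox server configured via AddSemanticSearchConfig is already listening on 127.0.0.1:5000, as in the tests below) that invokes the two registered tools over HTTP. The request bodies are the ones used in RunSemanticSearchToolInvokeTest; fields marked embeddedBy in the config (text_to_embed, query) are sent as plain text and are expected to be embedded server-side before the SQL statements run.

package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
)

// invoke posts a JSON body to a Toolbox tool's invoke endpoint and prints the
// raw response. It assumes the server from the test setup is already running.
func invoke(tool, body string) error {
	url := fmt.Sprintf("http://127.0.0.1:5000/api/tool/%s/invoke", tool)
	resp, err := http.Post(url, "application/json", bytes.NewBufferString(body))
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	out, err := io.ReadAll(resp.Body)
	if err != nil {
		return err
	}
	fmt.Printf("%s -> %s: %s\n", tool, resp.Status, out)
	return nil
}

func main() {
	// insert_docs: "content" is stored as-is, "text_to_embed" is embedded by gemini_model.
	_ = invoke("insert_docs", `{"content": "The quick brown fox jumps over the lazy dog", "text_to_embed": "The quick brown fox jumps over the lazy dog"}`)
	// search_docs: "query" is embedded and compared using pgvector's <-> distance operator.
	_ = invoke("search_docs", `{"query": "fast fox jumping"}`)
}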
@@ -111,6 +111,10 @@ func TestPostgres(t *testing.T) {
 	teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
 	defer teardownTable2(t)
 
+	// Set up table for semantic search
+	vectorTableName, tearDownVectorTable := tests.SetupPostgresVectorTable(t, ctx, pool)
+	defer tearDownVectorTable(t)
+
 	// Write config into a file and pass it to command
 	toolsFile := tests.GetToolsConfig(sourceConfig, PostgresToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
 	toolsFile = tests.AddExecuteSqlConfig(t, toolsFile, "postgres-execute-sql")
@@ -118,6 +122,10 @@ func TestPostgres(t *testing.T) {
 	toolsFile = tests.AddTemplateParamConfig(t, toolsFile, PostgresToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
 	toolsFile = tests.AddPostgresPrebuiltConfig(t, toolsFile)
 
+	// Add semantic search tool config
+	insertStmt, searchStmt := tests.GetPostgresVectorSearchStmts(vectorTableName)
+	toolsFile = tests.AddSemanticSearchConfig(t, toolsFile, PostgresToolKind, insertStmt, searchStmt)
+
 	cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
 	if err != nil {
 		t.Fatalf("command initialization returned an error: %s", err)
@@ -165,4 +173,5 @@ func TestPostgres(t *testing.T) {
 	tests.RunPostgresListDatabaseStatsTest(t, ctx, pool)
 	tests.RunPostgresListRolesTest(t, ctx, pool)
 	tests.RunPostgresListStoredProcedureTest(t, ctx, pool)
+	tests.RunSemanticSearchToolInvokeTest(t, "null", "", "The quick brown fox")
 }
@@ -1240,7 +1240,10 @@ func RunPostgresListTablesTest(t *testing.T, tableNameParam, tableNameAuth, user
 	var filteredGot []any
 	for _, item := range got {
 		if tableMap, ok := item.(map[string]interface{}); ok {
-			if schema, ok := tableMap["schema_name"]; ok && schema == "public" {
+			name, _ := tableMap["object_name"].(string)
+
+			// Only keep the table if it matches expected test tables
+			if name == tableNameParam || name == tableNameAuth {
 				filteredGot = append(filteredGot, item)
 			}
 		}