mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-17 03:18:00 -05:00
Compare commits
23 Commits
spanner-cr
...
healthcare
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7ef303a4d5 | ||
|
|
96f13dd517 | ||
|
|
6e09b08c6a | ||
|
|
1f15a111f1 | ||
|
|
dfddeb528d | ||
|
|
00c3e6d8cb | ||
|
|
d00b6fdf18 | ||
|
|
4d23a3bbf2 | ||
|
|
5e0999ebf5 | ||
|
|
6b02591703 | ||
|
|
8e0fb03483 | ||
|
|
2817dd1e5d | ||
|
|
68a218407e | ||
|
|
d69792d843 | ||
|
|
647b04d3a7 | ||
|
|
030df9766f | ||
|
|
5dbf207162 | ||
|
|
9c3720e31d | ||
|
|
3cd3c39d66 | ||
|
|
0691a6f715 | ||
|
|
467b96a23b | ||
|
|
4abf0c39e7 | ||
|
|
dd7b9de623 |
@@ -59,6 +59,13 @@ You can manually trigger the bot by commenting on your Pull Request:
|
||||
* `/gemini summary`: Posts a summary of the changes in the pull request.
|
||||
* `/gemini help`: Overview of the available commands
|
||||
|
||||
## Guidelines for Pull Requests
|
||||
|
||||
1. Please keep your PR small for more thorough review and easier updates. In case of regression, it also allows us to roll back a single feature instead of multiple ones.
|
||||
1. For non-trivial changes, consider opening an issue and discussing it with the code owners first.
|
||||
1. Provide a good PR description as a record of what change is being made and why it was made. Link to a GitHub issue if it exists.
|
||||
1. Make sure your code is thoroughly tested with unit tests and integration tests. Remember to clean up the test instances properly in your code to avoid memory leaks.
|
||||
|
||||
## Adding a New Database Source or Tool
|
||||
|
||||
Please create an
|
||||
@@ -110,6 +117,8 @@ implementation](https://github.com/googleapis/genai-toolbox/blob/main/internal/s
|
||||
We recommend looking at an [example tool
|
||||
implementation](https://github.com/googleapis/genai-toolbox/tree/main/internal/tools/postgres/postgressql).
|
||||
|
||||
Remember to keep your PRs small. For example, if you are contributing a new Source, only include one or two core Tools within the same PR, the rest of the Tools can come in subsequent PRs.
|
||||
|
||||
* **Create a new directory** under `internal/tools` for your tool type (e.g., `internal/tools/newdb/newdbtool`).
|
||||
* **Define a configuration struct** for your tool in a file named `newdbtool.go`.
|
||||
Create a `Config` struct and a `Tool` struct to store necessary parameters for
|
||||
@@ -163,6 +172,8 @@ tools.
|
||||
parameters][temp-param-doc]. Only run this test if template
|
||||
parameters apply to your tool.
|
||||
|
||||
* **Add additional tests** for the tools that are not covered by the predefined tests. Every tool must be tested!
|
||||
|
||||
* **Add the new database to the integration test workflow** in
|
||||
[integration.cloudbuild.yaml](.ci/integration.cloudbuild.yaml).
|
||||
|
||||
@@ -244,4 +255,4 @@ resources.
|
||||
* **PR Description:** PR description should **always** be included. It should
|
||||
include a concise description of the changes, it's impact, along with a
|
||||
summary of the solution. If the PR is related to a specific issue, the issue
|
||||
number should be mentioned in the PR description (e.g. `Fixes #1`).
|
||||
number should be mentioned in the PR description (e.g. `Fixes #1`).
|
||||
|
||||
17
DEVELOPER.md
17
DEVELOPER.md
@@ -379,6 +379,23 @@ to approve PRs for main. TeamSync is used to create this team from the MDB
|
||||
Group `toolbox-contributors`. Googlers who are developing for MCP-Toolbox
|
||||
but aren't part of the core team should join this group.
|
||||
|
||||
### Issue/PR Triage and SLO
|
||||
After an issue is created, maintainers will assign the following labels:
|
||||
* `Priority` (defaulted to P0)
|
||||
* `Type` (if applicable)
|
||||
* `Product` (if applicable)
|
||||
|
||||
All incoming issues and PRs will follow the following SLO:
|
||||
| Type | Priority | Objective |
|
||||
|-----------------|----------|------------------------------------------------------------------------|
|
||||
| Feature Request | P0 | Must respond within **5 days** |
|
||||
| Process | P0 | Must respond within **5 days** |
|
||||
| Bugs | P0 | Must respond within **5 days**, and resolve/closure within **14 days** |
|
||||
| Bugs | P1 | Must respond within **7 days**, and resolve/closure within **90 days** |
|
||||
| Bugs | P2 | Must respond within **30 days**
|
||||
|
||||
_Types that are not listed in the table do not adhere to any SLO._
|
||||
|
||||
### Releasing
|
||||
|
||||
Toolbox has two types of releases: versioned and continuous. It uses Google
|
||||
|
||||
@@ -272,7 +272,7 @@ To run Toolbox from binary:
|
||||
To run the server after pulling the [container image](#installing-the-server):
|
||||
|
||||
```sh
|
||||
export VERSION=0.11.0 # Use the version you pulled
|
||||
export VERSION=0.24.0 # Use the version you pulled
|
||||
docker run -p 5000:5000 \
|
||||
-v $(pwd)/tools.yaml:/app/tools.yaml \
|
||||
us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:$VERSION \
|
||||
|
||||
@@ -92,11 +92,13 @@ import (
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaresearchdicomstudies"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudmonitoring"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcloneinstance"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcreatebackup"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcreatedatabase"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcreateusers"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlgetinstances"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqllistdatabases"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqllistinstances"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlrestorebackup"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlwaitforoperation"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsqlmssql/cloudsqlmssqlcreateinstance"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsqlmysql/cloudsqlmysqlcreateinstance"
|
||||
|
||||
@@ -1493,7 +1493,7 @@ func TestPrebuiltTools(t *testing.T) {
|
||||
wantToolset: server.ToolsetConfigs{
|
||||
"cloud_sql_postgres_admin_tools": tools.ToolsetConfig{
|
||||
Name: "cloud_sql_postgres_admin_tools",
|
||||
ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "postgres_upgrade_precheck", "clone_instance"},
|
||||
ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "postgres_upgrade_precheck", "clone_instance", "create_backup", "restore_backup"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -1503,7 +1503,7 @@ func TestPrebuiltTools(t *testing.T) {
|
||||
wantToolset: server.ToolsetConfigs{
|
||||
"cloud_sql_mysql_admin_tools": tools.ToolsetConfig{
|
||||
Name: "cloud_sql_mysql_admin_tools",
|
||||
ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance"},
|
||||
ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance", "create_backup", "restore_backup"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -1513,7 +1513,7 @@ func TestPrebuiltTools(t *testing.T) {
|
||||
wantToolset: server.ToolsetConfigs{
|
||||
"cloud_sql_mssql_admin_tools": tools.ToolsetConfig{
|
||||
Name: "cloud_sql_mssql_admin_tools",
|
||||
ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance"},
|
||||
ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance", "create_backup", "restore_backup"},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -48,11 +48,13 @@ instance, database and users:
|
||||
* `roles/cloudsql.editor`: Provides permissions to manage existing resources.
|
||||
* All `viewer` tools
|
||||
* `create_database`
|
||||
* `create_backup`
|
||||
* `roles/cloudsql.admin`: Provides full control over all resources.
|
||||
* All `editor` and `viewer` tools
|
||||
* `create_instance`
|
||||
* `create_user`
|
||||
* `clone_instance`
|
||||
* `restore_backup`
|
||||
|
||||
## Install MCP Toolbox
|
||||
|
||||
@@ -299,6 +301,8 @@ instances and interacting with your database:
|
||||
* **create_user**: Creates a new user in a Cloud SQL instance.
|
||||
* **wait_for_operation**: Waits for a Cloud SQL operation to complete.
|
||||
* **clone_instance**: Creates a clone of an existing Cloud SQL for SQL Server instance.
|
||||
* **create_backup**: Creates a backup on a Cloud SQL instance.
|
||||
* **restore_backup**: Restores a backup of a Cloud SQL instance.
|
||||
|
||||
{{< notice note >}}
|
||||
Prebuilt tools are pre-1.0, so expect some tool changes between versions. LLMs
|
||||
|
||||
@@ -48,11 +48,13 @@ database and users:
|
||||
* `roles/cloudsql.editor`: Provides permissions to manage existing resources.
|
||||
* All `viewer` tools
|
||||
* `create_database`
|
||||
* `create_backup`
|
||||
* `roles/cloudsql.admin`: Provides full control over all resources.
|
||||
* All `editor` and `viewer` tools
|
||||
* `create_instance`
|
||||
* `create_user`
|
||||
* `clone_instance`
|
||||
* `restore_backup`
|
||||
|
||||
## Install MCP Toolbox
|
||||
|
||||
@@ -299,6 +301,8 @@ instances and interacting with your database:
|
||||
* **create_user**: Creates a new user in a Cloud SQL instance.
|
||||
* **wait_for_operation**: Waits for a Cloud SQL operation to complete.
|
||||
* **clone_instance**: Creates a clone of an existing Cloud SQL for MySQL instance.
|
||||
* **create_backup**: Creates a backup on a Cloud SQL instance.
|
||||
* **restore_backup**: Restores a backup of a Cloud SQL instance.
|
||||
|
||||
{{< notice note >}}
|
||||
Prebuilt tools are pre-1.0, so expect some tool changes between versions. LLMs
|
||||
|
||||
@@ -48,11 +48,13 @@ instance, database and users:
|
||||
* `roles/cloudsql.editor`: Provides permissions to manage existing resources.
|
||||
* All `viewer` tools
|
||||
* `create_database`
|
||||
* `create_backup`
|
||||
* `roles/cloudsql.admin`: Provides full control over all resources.
|
||||
* All `editor` and `viewer` tools
|
||||
* `create_instance`
|
||||
* `create_user`
|
||||
* `clone_instance`
|
||||
* `restore_backup`
|
||||
|
||||
## Install MCP Toolbox
|
||||
|
||||
@@ -299,6 +301,8 @@ instances and interacting with your database:
|
||||
* **create_user**: Creates a new user in a Cloud SQL instance.
|
||||
* **wait_for_operation**: Waits for a Cloud SQL operation to complete.
|
||||
* **clone_instance**: Creates a clone of an existing Cloud SQL for PostgreSQL instance.
|
||||
* **create_backup**: Creates a backup on a Cloud SQL instance.
|
||||
* **restore_backup**: Restores a backup of a Cloud SQL instance.
|
||||
|
||||
{{< notice note >}}
|
||||
Prebuilt tools are pre-1.0, so expect some tool changes between versions. LLMs
|
||||
|
||||
@@ -20,6 +20,7 @@ The native SDKs can be combined with MCP clients in many cases.
|
||||
|
||||
Toolbox currently supports the following versions of MCP specification:
|
||||
|
||||
* [2025-11-25](https://modelcontextprotocol.io/specification/2025-11-25)
|
||||
* [2025-06-18](https://modelcontextprotocol.io/specification/2025-06-18)
|
||||
* [2025-03-26](https://modelcontextprotocol.io/specification/2025-03-26)
|
||||
* [2024-11-05](https://modelcontextprotocol.io/specification/2024-11-05)
|
||||
|
||||
@@ -207,6 +207,7 @@ You can connect to Toolbox Cloud Run instances directly through the SDK.
|
||||
{{< tab header="Python" lang="python" >}}
|
||||
import asyncio
|
||||
from toolbox_core import ToolboxClient, auth_methods
|
||||
from toolbox_core.protocol import Protocol
|
||||
|
||||
# Replace with the Cloud Run service URL generated in the previous step
|
||||
URL = "https://cloud-run-url.app"
|
||||
@@ -217,6 +218,7 @@ async def main():
|
||||
async with ToolboxClient(
|
||||
URL,
|
||||
client_headers={"Authorization": auth_token_provider},
|
||||
protocol=Protocol.TOOLBOX,
|
||||
) as toolbox:
|
||||
toolset = await toolbox.load_toolset()
|
||||
# ...
|
||||
@@ -281,3 +283,5 @@ contain the specific error message needed to diagnose the problem.
|
||||
Manager, it means the Toolbox service account is missing permissions.
|
||||
- Ensure the `toolbox-identity` service account has the **Secret Manager
|
||||
Secret Accessor** (`roles/secretmanager.secretAccessor`) IAM role.
|
||||
|
||||
- **Cloud Run Connections via IAP:** Currently we do not support Cloud Run connections via [IAP](https://docs.cloud.google.com/iap/docs/concepts-overview). Please disable IAP if you are using it.
|
||||
@@ -187,12 +187,14 @@ See [Usage Examples](../reference/cli.md#examples).
|
||||
manage existing resources.
|
||||
* All `viewer` tools
|
||||
* `create_database`
|
||||
* `create_backup`
|
||||
* **Cloud SQL Admin** (`roles/cloudsql.admin`): Provides full control over
|
||||
all resources.
|
||||
* All `editor` and `viewer` tools
|
||||
* `create_instance`
|
||||
* `create_user`
|
||||
* `clone_instance`
|
||||
* `restore_backup`
|
||||
|
||||
* **Tools:**
|
||||
* `create_instance`: Creates a new Cloud SQL for MySQL instance.
|
||||
@@ -203,6 +205,8 @@ See [Usage Examples](../reference/cli.md#examples).
|
||||
* `create_user`: Creates a new user in a Cloud SQL instance.
|
||||
* `wait_for_operation`: Waits for a Cloud SQL operation to complete.
|
||||
* `clone_instance`: Creates a clone for an existing Cloud SQL for MySQL instance.
|
||||
* `create_backup`: Creates a backup on a Cloud SQL instance.
|
||||
* `restore_backup`: Restores a backup of a Cloud SQL instance.
|
||||
|
||||
## Cloud SQL for PostgreSQL
|
||||
|
||||
@@ -275,12 +279,14 @@ See [Usage Examples](../reference/cli.md#examples).
|
||||
manage existing resources.
|
||||
* All `viewer` tools
|
||||
* `create_database`
|
||||
* `create_backup`
|
||||
* **Cloud SQL Admin** (`roles/cloudsql.admin`): Provides full control over
|
||||
all resources.
|
||||
* All `editor` and `viewer` tools
|
||||
* `create_instance`
|
||||
* `create_user`
|
||||
* `clone_instance`
|
||||
* `restore_backup`
|
||||
* **Tools:**
|
||||
* `create_instance`: Creates a new Cloud SQL for PostgreSQL instance.
|
||||
* `get_instance`: Gets information about a Cloud SQL instance.
|
||||
@@ -290,6 +296,8 @@ See [Usage Examples](../reference/cli.md#examples).
|
||||
* `create_user`: Creates a new user in a Cloud SQL instance.
|
||||
* `wait_for_operation`: Waits for a Cloud SQL operation to complete.
|
||||
* `clone_instance`: Creates a clone for an existing Cloud SQL for PostgreSQL instance.
|
||||
* `create_backup`: Creates a backup on a Cloud SQL instance.
|
||||
* `restore_backup`: Restores a backup of a Cloud SQL instance.
|
||||
|
||||
## Cloud SQL for SQL Server
|
||||
|
||||
@@ -336,12 +344,14 @@ See [Usage Examples](../reference/cli.md#examples).
|
||||
manage existing resources.
|
||||
* All `viewer` tools
|
||||
* `create_database`
|
||||
* `create_backup`
|
||||
* **Cloud SQL Admin** (`roles/cloudsql.admin`): Provides full control over
|
||||
all resources.
|
||||
* All `editor` and `viewer` tools
|
||||
* `create_instance`
|
||||
* `create_user`
|
||||
* `clone_instance`
|
||||
* `restore_backup`
|
||||
* **Tools:**
|
||||
* `create_instance`: Creates a new Cloud SQL for SQL Server instance.
|
||||
* `get_instance`: Gets information about a Cloud SQL instance.
|
||||
@@ -351,6 +361,8 @@ See [Usage Examples](../reference/cli.md#examples).
|
||||
* `create_user`: Creates a new user in a Cloud SQL instance.
|
||||
* `wait_for_operation`: Waits for a Cloud SQL operation to complete.
|
||||
* `clone_instance`: Creates a clone for an existing Cloud SQL for SQL Server instance.
|
||||
* `create_backup`: Creates a backup on a Cloud SQL instance.
|
||||
* `restore_backup`: Restores a backup of a Cloud SQL instance.
|
||||
|
||||
## Dataplex
|
||||
|
||||
|
||||
@@ -134,6 +134,7 @@ sources:
|
||||
# scopes: # Optional: List of OAuth scopes to request.
|
||||
# - "https://www.googleapis.com/auth/bigquery"
|
||||
# - "https://www.googleapis.com/auth/drive.readonly"
|
||||
# maxQueryResultRows: 50 # Optional: Limits the number of rows returned by queries. Defaults to 50.
|
||||
```
|
||||
|
||||
Initialize a BigQuery source that uses the client's access token:
|
||||
@@ -153,6 +154,7 @@ sources:
|
||||
# scopes: # Optional: List of OAuth scopes to request.
|
||||
# - "https://www.googleapis.com/auth/bigquery"
|
||||
# - "https://www.googleapis.com/auth/drive.readonly"
|
||||
# maxQueryResultRows: 50 # Optional: Limits the number of rows returned by queries. Defaults to 50.
|
||||
```
|
||||
|
||||
## Reference
|
||||
@@ -167,3 +169,4 @@ sources:
|
||||
| useClientOAuth | bool | false | If true, forwards the client's OAuth access token from the "Authorization" header to downstream queries. **Note:** This cannot be used with `writeMode: protected`. |
|
||||
| scopes | []string | false | A list of OAuth 2.0 scopes to use for the credentials. If not provided, default scopes are used. |
|
||||
| impersonateServiceAccount | string | false | Service account email to impersonate when making BigQuery and Dataplex API calls. The authenticated principal must have the `roles/iam.serviceAccountTokenCreator` role on the target service account. [Learn More](https://cloud.google.com/iam/docs/service-account-impersonation) |
|
||||
| maxQueryResultRows | int | false | The maximum number of rows to return from a query. Defaults to 50. |
|
||||
|
||||
@@ -91,8 +91,8 @@ visible to the LLM.
|
||||
https://cloud.google.com/alloydb/docs/parameterized-secure-views-overview
|
||||
|
||||
{{< notice tip >}} Make sure to enable the `parameterized_views` extension
|
||||
before running this tool. You can do so by running this command in the AlloyDB
|
||||
studio:
|
||||
to utilize PSV feature (`nlConfigParameters`) with this tool. You can do so by
|
||||
running this command in the AlloyDB studio:
|
||||
|
||||
```sql
|
||||
CREATE EXTENSION IF NOT EXISTS parameterized_views;
|
||||
|
||||
@@ -12,6 +12,9 @@ aliases:
|
||||
|
||||
The `cloud-gemini-data-analytics-query` tool allows you to send natural language questions to the Gemini Data Analytics API and receive structured responses containing SQL queries, natural language answers, and explanations. For details on defining data agent context for database data sources, see the official [documentation](https://docs.cloud.google.com/gemini/docs/conversational-analytics-api/data-agent-authored-context-databases).
|
||||
|
||||
> [!NOTE]
|
||||
> Only `alloydb`, `spannerReference`, and `cloudSqlReference` are supported as [datasource references](https://clouddocs.devsite.corp.google.com/gemini/docs/conversational-analytics-api/reference/rest/v1beta/projects.locations.dataAgents#DatasourceReferences).
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
@@ -41,13 +44,13 @@ tools:
|
||||
|
||||
### Usage Flow
|
||||
|
||||
When using this tool, a `prompt` parameter containing a natural language query is provided to the tool (typically by an agent). The tool then interacts with the Gemini Data Analytics API using the context defined in your configuration.
|
||||
When using this tool, a `query` parameter containing a natural language query is provided to the tool (typically by an agent). The tool then interacts with the Gemini Data Analytics API using the context defined in your configuration.
|
||||
|
||||
The structure of the response depends on the `generationOptions` configured in your tool definition (e.g., enabling `generateQueryResult` will include the SQL query results).
|
||||
|
||||
See [Data Analytics API REST documentation](https://clouddocs.devsite.corp.google.com/gemini/docs/conversational-analytics-api/reference/rest/v1alpha/projects.locations/queryData?rep_location=global) for details.
|
||||
|
||||
**Example Input Prompt:**
|
||||
**Example Input Query:**
|
||||
|
||||
```text
|
||||
How many accounts who have region in Prague are eligible for loans? A3 contains the data of region.
|
||||
|
||||
45
docs/en/resources/tools/cloudsql/cloudsqlcreatebackup.md
Normal file
45
docs/en/resources/tools/cloudsql/cloudsqlcreatebackup.md
Normal file
@@ -0,0 +1,45 @@
|
||||
---
|
||||
title: cloud-sql-create-backup
|
||||
type: docs
|
||||
weight: 10
|
||||
description: "Creates a backup on a Cloud SQL instance."
|
||||
---
|
||||
|
||||
The `cloud-sql-create-backup` tool creates an on-demand backup on a Cloud SQL instance using the Cloud SQL Admin API.
|
||||
|
||||
{{< notice info dd>}}
|
||||
This tool uses a `source` of kind `cloud-sql-admin`.
|
||||
{{< /notice >}}
|
||||
|
||||
## Examples
|
||||
|
||||
Basic backup creation (current state)
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
backup-creation-basic:
|
||||
kind: cloud-sql-create-backup
|
||||
source: cloud-sql-admin-source
|
||||
description: "Creates a backup on the given Cloud SQL instance."
|
||||
```
|
||||
## Reference
|
||||
### Tool Configuration
|
||||
| **field** | **type** | **required** | **description** |
|
||||
| -------------- | :------: | :----------: | ------------------------------------------------------------- |
|
||||
| kind | string | true | Must be "cloud-sql-create-backup". |
|
||||
| source | string | true | The name of the `cloud-sql-admin` source to use. |
|
||||
| description | string | false | A description of the tool. |
|
||||
|
||||
### Tool Inputs
|
||||
|
||||
| **parameter** | **type** | **required** | **description** |
|
||||
| -------------------------- | :------: | :----------: | ------------------------------------------------------------------------------- |
|
||||
| project | string | true | The project ID. |
|
||||
| instance | string | true | The name of the instance to take a backup on. Does not include the project ID. |
|
||||
| location | string | false | (Optional) Location of the backup run. |
|
||||
| backup_description | string | false | (Optional) The description of this backup run. |
|
||||
|
||||
## See Also
|
||||
- [Cloud SQL Admin API documentation](https://cloud.google.com/sql/docs/mysql/admin-api)
|
||||
- [Toolbox Cloud SQL tools documentation](../cloudsql)
|
||||
- [Cloud SQL Backup API documentation](https://cloud.google.com/sql/docs/mysql/backup-recovery/backups)
|
||||
53
docs/en/resources/tools/cloudsql/cloudsqlrestorebackup.md
Normal file
53
docs/en/resources/tools/cloudsql/cloudsqlrestorebackup.md
Normal file
@@ -0,0 +1,53 @@
|
||||
---
|
||||
title: cloud-sql-restore-backup
|
||||
type: docs
|
||||
weight: 10
|
||||
description: "Restores a backup of a Cloud SQL instance."
|
||||
---
|
||||
|
||||
The `cloud-sql-restore-backup` tool restores a backup on a Cloud SQL instance using the Cloud SQL Admin API.
|
||||
|
||||
{{< notice info dd>}}
|
||||
This tool uses a `source` of kind `cloud-sql-admin`.
|
||||
{{< /notice >}}
|
||||
|
||||
## Examples
|
||||
|
||||
Basic backup restore
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
backup-restore-basic:
|
||||
kind: cloud-sql-restore-backup
|
||||
source: cloud-sql-admin-source
|
||||
description: "Restores a backup onto the given Cloud SQL instance."
|
||||
```
|
||||
|
||||
## Reference
|
||||
|
||||
### Tool Configuration
|
||||
| **field** | **type** | **required** | **description** |
|
||||
| -------------- | :------: | :----------: | ------------------------------------------------ |
|
||||
| kind | string | true | Must be "cloud-sql-restore-backup". |
|
||||
| source | string | true | The name of the `cloud-sql-admin` source to use. |
|
||||
| description | string | false | A description of the tool. |
|
||||
|
||||
### Tool Inputs
|
||||
|
||||
| **parameter** | **type** | **required** | **description** |
|
||||
| ------------------| :------: | :----------: | -----------------------------------------------------------------------------|
|
||||
| target_project | string | true | The project ID of the instance to restore the backup onto. |
|
||||
| target_instance | string | true | The instance to restore the backup onto. Does not include the project ID. |
|
||||
| backup_id | string | true | The identifier of the backup being restored. |
|
||||
| source_project | string | false | (Optional) The project ID of the instance that the backup belongs to. |
|
||||
| source_instance | string | false | (Optional) Cloud SQL instance ID of the instance that the backup belongs to. |
|
||||
|
||||
## Usage Notes
|
||||
|
||||
- The `backup_id` field can be a BackupRun ID (which will be an int64), backup name, or BackupDR backup name.
|
||||
- If the `backup_id` field contains a BackupRun ID (i.e. an int64), the optional fields `source_project` and `source_instance` must also be provided.
|
||||
|
||||
## See Also
|
||||
- [Cloud SQL Admin API documentation](https://cloud.google.com/sql/docs/mysql/admin-api)
|
||||
- [Toolbox Cloud SQL tools documentation](../cloudsql)
|
||||
- [Cloud SQL Restore API documentation](https://cloud.google.com/sql/docs/mysql/backup-recovery/restoring)
|
||||
@@ -30,6 +30,10 @@ following config for example:
|
||||
- name: userNames
|
||||
type: array
|
||||
description: The user names to be set.
|
||||
items:
|
||||
name: userName # the item name doesn't matter but it has to exist
|
||||
type: string
|
||||
description: username
|
||||
```
|
||||
|
||||
If the input is an array of strings `["Alice", "Sid", "Bob"]`, The final command
|
||||
|
||||
7
docs/en/samples/neo4j/_index.md
Normal file
7
docs/en/samples/neo4j/_index.md
Normal file
@@ -0,0 +1,7 @@
|
||||
---
|
||||
title: "Neo4j"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
How to get started with Toolbox using Neo4j.
|
||||
---
|
||||
141
docs/en/samples/neo4j/mcp_quickstart.md
Normal file
141
docs/en/samples/neo4j/mcp_quickstart.md
Normal file
@@ -0,0 +1,141 @@
|
||||
---
|
||||
title: "Quickstart (MCP with Neo4j)"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
How to get started running Toolbox with MCP Inspector and Neo4j as the source.
|
||||
---
|
||||
|
||||
## Overview
|
||||
|
||||
[Model Context Protocol](https://modelcontextprotocol.io) is an open protocol that standardizes how applications provide context to LLMs. Check out this page on how to [connect to Toolbox via MCP](../../how-to/connect_via_mcp.md).
|
||||
|
||||
|
||||
## Step 1: Set up your Neo4j Database and Data
|
||||
|
||||
In this section, you'll set up a database and populate it with sample data for a movies-related agent. This guide assumes you have a running Neo4j instance, either locally or in the cloud.
|
||||
|
||||
. **Populate the database with data.**
|
||||
To make this quickstart straightforward, we'll use the built-in Movies dataset available in Neo4j.
|
||||
|
||||
. In your Neo4j Browser, run the following command to create and populate the database:
|
||||
+
|
||||
```cypher
|
||||
:play movies
|
||||
````
|
||||
|
||||
. Follow the instructions to load the data. This will create a graph with `Movie`, `Person`, and `Actor` nodes and their relationships.
|
||||
|
||||
|
||||
## Step 2: Install and configure Toolbox
|
||||
|
||||
In this section, we will install the MCP Toolbox, configure our tools in a `tools.yaml` file, and then run the Toolbox server.
|
||||
|
||||
. **Install the Toolbox binary.**
|
||||
The simplest way to get started is to download the latest binary for your operating system.
|
||||
|
||||
. Download the latest version of Toolbox as a binary:
|
||||
\+
|
||||
|
||||
```bash
|
||||
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
||||
curl -O [https://storage.googleapis.com/genai-toolbox/v0.16.0/$OS/toolbox](https://storage.googleapis.com/genai-toolbox/v0.16.0/$OS/toolbox)
|
||||
```
|
||||
|
||||
+
|
||||
. Make the binary executable:
|
||||
\+
|
||||
|
||||
```bash
|
||||
chmod +x toolbox
|
||||
```
|
||||
|
||||
. **Create the `tools.yaml` file.**
|
||||
This file defines your Neo4j source and the specific tools that will be exposed to your AI agent.
|
||||
\+
|
||||
{{\< notice tip \>}}
|
||||
Authentication for the Neo4j source uses standard username and password fields. For production use, it is highly recommended to use environment variables for sensitive information like passwords.
|
||||
{{\< /notice \>}}
|
||||
\+
|
||||
Write the following into a `tools.yaml` file:
|
||||
\+
|
||||
|
||||
```yaml
|
||||
sources:
|
||||
my-neo4j-source:
|
||||
kind: neo4j
|
||||
uri: bolt://localhost:7687
|
||||
user: neo4j
|
||||
password: my-password # Replace with your actual password
|
||||
|
||||
tools:
|
||||
search-movies-by-actor:
|
||||
kind: neo4j-cypher
|
||||
source: my-neo4j-source
|
||||
description: "Searches for movies an actor has appeared in based on their name. Useful for questions like 'What movies has Tom Hanks been in?'"
|
||||
parameters:
|
||||
- name: actor_name
|
||||
type: string
|
||||
description: The full name of the actor to search for.
|
||||
statement: |
|
||||
MATCH (p:Person {name: $actor_name}) -[:ACTED_IN]-> (m:Movie)
|
||||
RETURN m.title AS title, m.year AS year, m.genre AS genre
|
||||
|
||||
get-actor-for-movie:
|
||||
kind: neo4j-cypher
|
||||
source: my-neo4j-source
|
||||
description: "Finds the actors who starred in a specific movie. Useful for questions like 'Who acted in Inception?'"
|
||||
parameters:
|
||||
- name: movie_title
|
||||
type: string
|
||||
description: The exact title of the movie.
|
||||
statement: |
|
||||
MATCH (p:Person) -[:ACTED_IN]-> (m:Movie {title: $movie_title})
|
||||
RETURN p.name AS actor
|
||||
```
|
||||
|
||||
. **Start the Toolbox server.**
|
||||
Run the Toolbox server, pointing to the `tools.yaml` file you created earlier.
|
||||
\+
|
||||
|
||||
```bash
|
||||
./toolbox --tools-file "tools.yaml"
|
||||
```
|
||||
|
||||
## Step 3: Connect to MCP Inspector
|
||||
|
||||
. **Run the MCP Inspector:**
|
||||
\+
|
||||
|
||||
```bash
|
||||
npx @modelcontextprotocol/inspector
|
||||
```
|
||||
|
||||
. Type `y` when it asks to install the inspector package.
|
||||
. It should show the following when the MCP Inspector is up and running (please take note of `<YOUR_SESSION_TOKEN>`):
|
||||
\+
|
||||
|
||||
```bash
|
||||
Starting MCP inspector...
|
||||
⚙️ Proxy server listening on localhost:6277
|
||||
🔑 Session token: <YOUR_SESSION_TOKEN>
|
||||
Use this token to authenticate requests or set DANGEROUSLY_OMIT_AUTH=true to disable auth
|
||||
|
||||
🚀 MCP Inspector is up and running at:
|
||||
http://localhost:6274/?MCP_PROXY_AUTH_TOKEN=<YOUR_SESSION_TOKEN>
|
||||
```
|
||||
|
||||
1. Open the above link in your browser.
|
||||
|
||||
1. For `Transport Type`, select `Streamable HTTP`.
|
||||
|
||||
1. For `URL`, type in `http://127.0.0.1:5000/mcp`.
|
||||
|
||||
1. For `Configuration` -\> `Proxy Session Token`, make sure `<YOUR_SESSION_TOKEN>` is present.
|
||||
|
||||
1. Click `Connect`.
|
||||
|
||||
1. Select `List Tools`, you will see a list of tools configured in `tools.yaml`.
|
||||
|
||||
1. Test out your tools here\!
|
||||
|
||||
@@ -19,6 +19,7 @@ sources:
|
||||
location: ${BIGQUERY_LOCATION:}
|
||||
useClientOAuth: ${BIGQUERY_USE_CLIENT_OAUTH:false}
|
||||
scopes: ${BIGQUERY_SCOPES:}
|
||||
maxQueryResultRows: ${BIGQUERY_MAX_QUERY_RESULT_ROWS:50}
|
||||
|
||||
tools:
|
||||
analyze_contribution:
|
||||
|
||||
@@ -43,6 +43,12 @@ tools:
|
||||
clone_instance:
|
||||
kind: cloud-sql-clone-instance
|
||||
source: cloud-sql-admin-source
|
||||
create_backup:
|
||||
kind: cloud-sql-create-backup
|
||||
source: cloud-sql-admin-source
|
||||
restore_backup:
|
||||
kind: cloud-sql-restore-backup
|
||||
source: cloud-sql-admin-source
|
||||
|
||||
toolsets:
|
||||
cloud_sql_mssql_admin_tools:
|
||||
@@ -54,3 +60,5 @@ toolsets:
|
||||
- create_user
|
||||
- wait_for_operation
|
||||
- clone_instance
|
||||
- create_backup
|
||||
- restore_backup
|
||||
|
||||
@@ -43,6 +43,12 @@ tools:
|
||||
clone_instance:
|
||||
kind: cloud-sql-clone-instance
|
||||
source: cloud-sql-admin-source
|
||||
create_backup:
|
||||
kind: cloud-sql-create-backup
|
||||
source: cloud-sql-admin-source
|
||||
restore_backup:
|
||||
kind: cloud-sql-restore-backup
|
||||
source: cloud-sql-admin-source
|
||||
|
||||
toolsets:
|
||||
cloud_sql_mysql_admin_tools:
|
||||
@@ -54,3 +60,5 @@ toolsets:
|
||||
- create_user
|
||||
- wait_for_operation
|
||||
- clone_instance
|
||||
- create_backup
|
||||
- restore_backup
|
||||
|
||||
@@ -46,6 +46,12 @@ tools:
|
||||
postgres_upgrade_precheck:
|
||||
kind: postgres-upgrade-precheck
|
||||
source: cloud-sql-admin-source
|
||||
create_backup:
|
||||
kind: cloud-sql-create-backup
|
||||
source: cloud-sql-admin-source
|
||||
restore_backup:
|
||||
kind: cloud-sql-restore-backup
|
||||
source: cloud-sql-admin-source
|
||||
|
||||
toolsets:
|
||||
cloud_sql_postgres_admin_tools:
|
||||
@@ -58,3 +64,5 @@ toolsets:
|
||||
- wait_for_operation
|
||||
- postgres_upgrade_precheck
|
||||
- clone_instance
|
||||
- create_backup
|
||||
- restore_backup
|
||||
|
||||
@@ -27,19 +27,21 @@ import (
|
||||
v20241105 "github.com/googleapis/genai-toolbox/internal/server/mcp/v20241105"
|
||||
v20250326 "github.com/googleapis/genai-toolbox/internal/server/mcp/v20250326"
|
||||
v20250618 "github.com/googleapis/genai-toolbox/internal/server/mcp/v20250618"
|
||||
v20251125 "github.com/googleapis/genai-toolbox/internal/server/mcp/v20251125"
|
||||
"github.com/googleapis/genai-toolbox/internal/server/resources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
)
|
||||
|
||||
// LATEST_PROTOCOL_VERSION is the latest version of the MCP protocol supported.
|
||||
// Update the version used in InitializeResponse when this value is updated.
|
||||
const LATEST_PROTOCOL_VERSION = v20250618.PROTOCOL_VERSION
|
||||
const LATEST_PROTOCOL_VERSION = v20251125.PROTOCOL_VERSION
|
||||
|
||||
// SUPPORTED_PROTOCOL_VERSIONS is the MCP protocol versions that are supported.
|
||||
var SUPPORTED_PROTOCOL_VERSIONS = []string{
|
||||
v20241105.PROTOCOL_VERSION,
|
||||
v20250326.PROTOCOL_VERSION,
|
||||
v20250618.PROTOCOL_VERSION,
|
||||
v20251125.PROTOCOL_VERSION,
|
||||
}
|
||||
|
||||
// InitializeResponse runs capability negotiation and protocol version agreement.
|
||||
@@ -102,6 +104,8 @@ func NotificationHandler(ctx context.Context, body []byte) error {
|
||||
// This is the Operation phase of the lifecycle for MCP client-server connections.
|
||||
func ProcessMethod(ctx context.Context, mcpVersion string, id jsonrpc.RequestId, method string, toolset tools.Toolset, promptset prompts.Promptset, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) {
|
||||
switch mcpVersion {
|
||||
case v20251125.PROTOCOL_VERSION:
|
||||
return v20251125.ProcessMethod(ctx, id, method, toolset, promptset, resourceMgr, body, header)
|
||||
case v20250618.PROTOCOL_VERSION:
|
||||
return v20250618.ProcessMethod(ctx, id, method, toolset, promptset, resourceMgr, body, header)
|
||||
case v20250326.PROTOCOL_VERSION:
|
||||
|
||||
326
internal/server/mcp/v20251125/method.go
Normal file
326
internal/server/mcp/v20251125/method.go
Normal file
@@ -0,0 +1,326 @@
|
||||
// Copyright 2026 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package v20251125
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||
"github.com/googleapis/genai-toolbox/internal/server/resources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
)
|
||||
|
||||
// ProcessMethod returns a response for the request.
|
||||
func ProcessMethod(ctx context.Context, id jsonrpc.RequestId, method string, toolset tools.Toolset, promptset prompts.Promptset, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) {
|
||||
switch method {
|
||||
case PING:
|
||||
return pingHandler(id)
|
||||
case TOOLS_LIST:
|
||||
return toolsListHandler(id, toolset, body)
|
||||
case TOOLS_CALL:
|
||||
return toolsCallHandler(ctx, id, resourceMgr, body, header)
|
||||
case PROMPTS_LIST:
|
||||
return promptsListHandler(ctx, id, promptset, body)
|
||||
case PROMPTS_GET:
|
||||
return promptsGetHandler(ctx, id, resourceMgr, body)
|
||||
default:
|
||||
err := fmt.Errorf("invalid method %s", method)
|
||||
return jsonrpc.NewError(id, jsonrpc.METHOD_NOT_FOUND, err.Error(), nil), err
|
||||
}
|
||||
}
|
||||
|
||||
// pingHandler handles the "ping" method by returning an empty response.
|
||||
func pingHandler(id jsonrpc.RequestId) (any, error) {
|
||||
return jsonrpc.JSONRPCResponse{
|
||||
Jsonrpc: jsonrpc.JSONRPC_VERSION,
|
||||
Id: id,
|
||||
Result: struct{}{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func toolsListHandler(id jsonrpc.RequestId, toolset tools.Toolset, body []byte) (any, error) {
|
||||
var req ListToolsRequest
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
err = fmt.Errorf("invalid mcp tools list request: %w", err)
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
|
||||
result := ListToolsResult{
|
||||
Tools: toolset.McpManifest,
|
||||
}
|
||||
return jsonrpc.JSONRPCResponse{
|
||||
Jsonrpc: jsonrpc.JSONRPC_VERSION,
|
||||
Id: id,
|
||||
Result: result,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// toolsCallHandler generate a response for tools call.
|
||||
func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) {
|
||||
authServices := resourceMgr.GetAuthServiceMap()
|
||||
|
||||
// retrieve logger from context
|
||||
logger, err := util.LoggerFromContext(ctx)
|
||||
if err != nil {
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||
}
|
||||
|
||||
var req CallToolRequest
|
||||
if err = json.Unmarshal(body, &req); err != nil {
|
||||
err = fmt.Errorf("invalid mcp tools call request: %w", err)
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
|
||||
toolName := req.Params.Name
|
||||
toolArgument := req.Params.Arguments
|
||||
logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName))
|
||||
tool, ok := resourceMgr.GetTool(toolName)
|
||||
if !ok {
|
||||
err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
|
||||
}
|
||||
|
||||
// Get access token
|
||||
authTokenHeadername, err := tool.GetAuthTokenHeaderName(resourceMgr)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("error during invocation: %w", err)
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
|
||||
}
|
||||
accessToken := tools.AccessToken(header.Get(authTokenHeadername))
|
||||
|
||||
// Check if this specific tool requires the standard authorization header
|
||||
clientAuth, err := tool.RequiresClientAuthorization(resourceMgr)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("error during invocation: %w", err)
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
|
||||
}
|
||||
if clientAuth {
|
||||
if accessToken == "" {
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
|
||||
}
|
||||
}
|
||||
|
||||
// marshal arguments and decode it using decodeJSON instead to prevent loss between floats/int.
|
||||
aMarshal, err := json.Marshal(toolArgument)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("unable to marshal tools argument: %w", err)
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||
}
|
||||
|
||||
var data map[string]any
|
||||
if err = util.DecodeJSON(bytes.NewBuffer(aMarshal), &data); err != nil {
|
||||
err = fmt.Errorf("unable to decode tools argument: %w", err)
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||
}
|
||||
|
||||
// Tool authentication
|
||||
// claimsFromAuth maps the name of the authservice to the claims retrieved from it.
|
||||
claimsFromAuth := make(map[string]map[string]any)
|
||||
|
||||
// if using stdio, header will be nil and auth will not be supported
|
||||
if header != nil {
|
||||
for _, aS := range authServices {
|
||||
claims, err := aS.GetClaimsFromHeader(ctx, header)
|
||||
if err != nil {
|
||||
logger.DebugContext(ctx, err.Error())
|
||||
continue
|
||||
}
|
||||
if claims == nil {
|
||||
// authService not present in header
|
||||
continue
|
||||
}
|
||||
claimsFromAuth[aS.GetName()] = claims
|
||||
}
|
||||
}
|
||||
|
||||
// Tool authorization check
|
||||
verifiedAuthServices := make([]string, len(claimsFromAuth))
|
||||
i := 0
|
||||
for k := range claimsFromAuth {
|
||||
verifiedAuthServices[i] = k
|
||||
i++
|
||||
}
|
||||
|
||||
// Check if any of the specified auth services is verified
|
||||
isAuthorized := tool.Authorized(verifiedAuthServices)
|
||||
if !isAuthorized {
|
||||
err = fmt.Errorf("unauthorized Tool call: Please make sure your specify correct auth headers: %w", util.ErrUnauthorized)
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
logger.DebugContext(ctx, "tool invocation authorized")
|
||||
|
||||
params, err := tool.ParseParams(data, claimsFromAuth)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("provided parameters were invalid: %w", err)
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
|
||||
}
|
||||
logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))
|
||||
|
||||
// run tool invocation and generate response.
|
||||
results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
|
||||
if err != nil {
|
||||
errStr := err.Error()
|
||||
// Missing authService tokens.
|
||||
if errors.Is(err, util.ErrUnauthorized) {
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
// Upstream auth error
|
||||
if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
|
||||
if clientAuth {
|
||||
// Error with client credentials should pass down to the client
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
// Auth error with ADC should raise internal 500 error
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||
}
|
||||
text := TextContent{
|
||||
Type: "text",
|
||||
Text: err.Error(),
|
||||
}
|
||||
return jsonrpc.JSONRPCResponse{
|
||||
Jsonrpc: jsonrpc.JSONRPC_VERSION,
|
||||
Id: id,
|
||||
Result: CallToolResult{Content: []TextContent{text}, IsError: true},
|
||||
}, nil
|
||||
}
|
||||
|
||||
content := make([]TextContent, 0)
|
||||
|
||||
sliceRes, ok := results.([]any)
|
||||
if !ok {
|
||||
sliceRes = []any{results}
|
||||
}
|
||||
|
||||
for _, d := range sliceRes {
|
||||
text := TextContent{Type: "text"}
|
||||
dM, err := json.Marshal(d)
|
||||
if err != nil {
|
||||
text.Text = fmt.Sprintf("fail to marshal: %s, result: %s", err, d)
|
||||
} else {
|
||||
text.Text = string(dM)
|
||||
}
|
||||
content = append(content, text)
|
||||
}
|
||||
|
||||
return jsonrpc.JSONRPCResponse{
|
||||
Jsonrpc: jsonrpc.JSONRPC_VERSION,
|
||||
Id: id,
|
||||
Result: CallToolResult{Content: content},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// promptsListHandler handles the "prompts/list" method.
|
||||
func promptsListHandler(ctx context.Context, id jsonrpc.RequestId, promptset prompts.Promptset, body []byte) (any, error) {
|
||||
// retrieve logger from context
|
||||
logger, err := util.LoggerFromContext(ctx)
|
||||
if err != nil {
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||
}
|
||||
logger.DebugContext(ctx, "handling prompts/list request")
|
||||
|
||||
var req ListPromptsRequest
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
err = fmt.Errorf("invalid mcp prompts list request: %w", err)
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
|
||||
result := ListPromptsResult{
|
||||
Prompts: promptset.McpManifest,
|
||||
}
|
||||
logger.DebugContext(ctx, fmt.Sprintf("returning %d prompts", len(promptset.McpManifest)))
|
||||
return jsonrpc.JSONRPCResponse{
|
||||
Jsonrpc: jsonrpc.JSONRPC_VERSION,
|
||||
Id: id,
|
||||
Result: result,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// promptsGetHandler handles the "prompts/get" method.
|
||||
func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *resources.ResourceManager, body []byte) (any, error) {
|
||||
// retrieve logger from context
|
||||
logger, err := util.LoggerFromContext(ctx)
|
||||
if err != nil {
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||
}
|
||||
logger.DebugContext(ctx, "handling prompts/get request")
|
||||
|
||||
var req GetPromptRequest
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
err = fmt.Errorf("invalid mcp prompts/get request: %w", err)
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
|
||||
}
|
||||
|
||||
promptName := req.Params.Name
|
||||
logger.DebugContext(ctx, fmt.Sprintf("prompt name: %s", promptName))
|
||||
prompt, ok := resourceMgr.GetPrompt(promptName)
|
||||
if !ok {
|
||||
err := fmt.Errorf("prompt with name %q does not exist", promptName)
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
|
||||
}
|
||||
|
||||
// Parse the arguments provided in the request.
|
||||
argValues, err := prompt.ParseArgs(req.Params.Arguments, nil)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("invalid arguments for prompt %q: %w", promptName, err)
|
||||
return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
|
||||
}
|
||||
logger.DebugContext(ctx, fmt.Sprintf("parsed args: %v", argValues))
|
||||
|
||||
// Substitute the argument values into the prompt's messages.
|
||||
substituted, err := prompt.SubstituteParams(argValues)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("error substituting params for prompt %q: %w", promptName, err)
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||
}
|
||||
|
||||
// Cast the result to the expected []prompts.Message type.
|
||||
substitutedMessages, ok := substituted.([]prompts.Message)
|
||||
if !ok {
|
||||
err = fmt.Errorf("internal error: SubstituteParams returned unexpected type")
|
||||
return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
|
||||
}
|
||||
logger.DebugContext(ctx, "substituted params successfully")
|
||||
|
||||
// Format the response messages into the required structure.
|
||||
promptMessages := make([]PromptMessage, len(substitutedMessages))
|
||||
for i, msg := range substitutedMessages {
|
||||
promptMessages[i] = PromptMessage{
|
||||
Role: msg.Role,
|
||||
Content: TextContent{
|
||||
Type: "text",
|
||||
Text: msg.Content,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
result := GetPromptResult{
|
||||
Description: prompt.Manifest().Description,
|
||||
Messages: promptMessages,
|
||||
}
|
||||
|
||||
return jsonrpc.JSONRPCResponse{
|
||||
Jsonrpc: jsonrpc.JSONRPC_VERSION,
|
||||
Id: id,
|
||||
Result: result,
|
||||
}, nil
|
||||
}
|
||||
219
internal/server/mcp/v20251125/types.go
Normal file
219
internal/server/mcp/v20251125/types.go
Normal file
@@ -0,0 +1,219 @@
|
||||
// Copyright 2026 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package v20251125
|
||||
|
||||
import (
|
||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
)
|
||||
|
||||
// SERVER_NAME is the server name used in Implementation.
|
||||
const SERVER_NAME = "Toolbox"
|
||||
|
||||
// PROTOCOL_VERSION is the version of the MCP protocol in this package.
|
||||
const PROTOCOL_VERSION = "2025-11-25"
|
||||
|
||||
// methods that are supported.
|
||||
const (
|
||||
PING = "ping"
|
||||
TOOLS_LIST = "tools/list"
|
||||
TOOLS_CALL = "tools/call"
|
||||
PROMPTS_LIST = "prompts/list"
|
||||
PROMPTS_GET = "prompts/get"
|
||||
)
|
||||
|
||||
/* Empty result */
|
||||
|
||||
// EmptyResult represents a response that indicates success but carries no data.
|
||||
type EmptyResult jsonrpc.Result
|
||||
|
||||
/* Pagination */
|
||||
|
||||
// Cursor is an opaque token used to represent a cursor for pagination.
|
||||
type Cursor string
|
||||
|
||||
type PaginatedRequest struct {
|
||||
jsonrpc.Request
|
||||
Params struct {
|
||||
// An opaque token representing the current pagination position.
|
||||
// If provided, the server should return results starting after this cursor.
|
||||
Cursor Cursor `json:"cursor,omitempty"`
|
||||
} `json:"params,omitempty"`
|
||||
}
|
||||
|
||||
type PaginatedResult struct {
|
||||
jsonrpc.Result
|
||||
// An opaque token representing the pagination position after the last returned result.
|
||||
// If present, there may be more results available.
|
||||
NextCursor Cursor `json:"nextCursor,omitempty"`
|
||||
}
|
||||
|
||||
/* Tools */
|
||||
|
||||
// Sent from the client to request a list of tools the server has.
|
||||
type ListToolsRequest struct {
|
||||
PaginatedRequest
|
||||
}
|
||||
|
||||
// The server's response to a tools/list request from the client.
|
||||
type ListToolsResult struct {
|
||||
PaginatedResult
|
||||
Tools []tools.McpManifest `json:"tools"`
|
||||
}
|
||||
|
||||
// Used by the client to invoke a tool provided by the server.
|
||||
type CallToolRequest struct {
|
||||
jsonrpc.Request
|
||||
Params struct {
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]any `json:"arguments,omitempty"`
|
||||
} `json:"params,omitempty"`
|
||||
}
|
||||
|
||||
// The sender or recipient of messages and data in a conversation.
|
||||
type Role string
|
||||
|
||||
const (
|
||||
RoleUser Role = "user"
|
||||
RoleAssistant Role = "assistant"
|
||||
)
|
||||
|
||||
// Base for objects that include optional annotations for the client.
|
||||
// The client can use annotations to inform how objects are used or displayed
|
||||
type Annotated struct {
|
||||
Annotations *struct {
|
||||
// Describes who the intended customer of this object or data is.
|
||||
// It can include multiple entries to indicate content useful for multiple
|
||||
// audiences (e.g., `["user", "assistant"]`).
|
||||
Audience []Role `json:"audience,omitempty"`
|
||||
// Describes how important this data is for operating the server.
|
||||
//
|
||||
// A value of 1 means "most important," and indicates that the data is
|
||||
// effectively required, while 0 means "least important," and indicates that
|
||||
// the data is entirely optional.
|
||||
//
|
||||
// @TJS-type number
|
||||
// @minimum 0
|
||||
// @maximum 1
|
||||
Priority float64 `json:"priority,omitempty"`
|
||||
} `json:"annotations,omitempty"`
|
||||
}
|
||||
|
||||
// TextContent represents text provided to or from an LLM.
|
||||
type TextContent struct {
|
||||
Annotated
|
||||
Type string `json:"type"`
|
||||
// The text content of the message.
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
// The server's response to a tool call.
|
||||
//
|
||||
// Any errors that originate from the tool SHOULD be reported inside the result
|
||||
// object, with `isError` set to true, _not_ as an MCP protocol-level error
|
||||
// response. Otherwise, the LLM would not be able to see that an error occurred
|
||||
// and self-correct.
|
||||
//
|
||||
// However, any errors in _finding_ the tool, an error indicating that the
|
||||
// server does not support tool calls, or any other exceptional conditions,
|
||||
// should be reported as an MCP error response.
|
||||
type CallToolResult struct {
|
||||
jsonrpc.Result
|
||||
// Could be either a TextContent, ImageContent, or EmbeddedResources
|
||||
// For Toolbox, we will only be sending TextContent
|
||||
Content []TextContent `json:"content"`
|
||||
// Whether the tool call ended in an error.
|
||||
// If not set, this is assumed to be false (the call was successful).
|
||||
//
|
||||
// Any errors that originate from the tool SHOULD be reported inside the result
|
||||
// object, with `isError` set to true, _not_ as an MCP protocol-level error
|
||||
// response. Otherwise, the LLM would not be able to see that an error occurred
|
||||
// and self-correct.
|
||||
//
|
||||
// However, any errors in _finding_ the tool, an error indicating that the
|
||||
// server does not support tool calls, or any other exceptional conditions,
|
||||
// should be reported as an MCP error response.
|
||||
IsError bool `json:"isError,omitempty"`
|
||||
// An optional JSON object that represents the structured result of the tool call.
|
||||
StructuredContent map[string]any `json:"structuredContent,omitempty"`
|
||||
}
|
||||
|
||||
// Additional properties describing a Tool to clients.
|
||||
//
|
||||
// NOTE: all properties in ToolAnnotations are **hints**.
|
||||
// They are not guaranteed to provide a faithful description of
|
||||
// tool behavior (including descriptive properties like `title`).
|
||||
//
|
||||
// Clients should never make tool use decisions based on ToolAnnotations
|
||||
// received from untrusted servers.
|
||||
type ToolAnnotations struct {
|
||||
// A human-readable title for the tool.
|
||||
Title string `json:"title,omitempty"`
|
||||
// If true, the tool does not modify its environment.
|
||||
// Default: false
|
||||
ReadOnlyHint bool `json:"readOnlyHint,omitempty"`
|
||||
// If true, the tool may perform destructive updates to its environment.
|
||||
// If false, the tool performs only additive updates.
|
||||
// (This property is meaningful only when `readOnlyHint == false`)
|
||||
// Default: true
|
||||
DestructiveHint bool `json:"destructiveHint,omitempty"`
|
||||
// If true, calling the tool repeatedly with the same arguments
|
||||
// will have no additional effect on the its environment.
|
||||
// (This property is meaningful only when `readOnlyHint == false`)
|
||||
// Default: false
|
||||
IdempotentHint bool `json:"idempotentHint,omitempty"`
|
||||
// If true, this tool may interact with an "open world" of external
|
||||
// entities. If false, the tool's domain of interaction is closed.
|
||||
// For example, the world of a web search tool is open, whereas that
|
||||
// of a memory tool is not.
|
||||
// Default: true
|
||||
OpenWorldHint bool `json:"openWorldHint,omitempty"`
|
||||
}
|
||||
|
||||
/* Prompts */
|
||||
|
||||
// Sent from the client to request a list of prompts the server has.
|
||||
type ListPromptsRequest struct {
|
||||
PaginatedRequest
|
||||
}
|
||||
|
||||
// The server's response to a prompts/list request from the client.
|
||||
type ListPromptsResult struct {
|
||||
PaginatedResult
|
||||
Prompts []prompts.McpManifest `json:"prompts"`
|
||||
}
|
||||
|
||||
// Used by the client to get a prompt provided by the server.
|
||||
type GetPromptRequest struct {
|
||||
jsonrpc.Request
|
||||
Params struct {
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]any `json:"arguments,omitempty"`
|
||||
} `json:"params"`
|
||||
}
|
||||
|
||||
// The server's response to a prompts/get request from the client.
|
||||
type GetPromptResult struct {
|
||||
jsonrpc.Result
|
||||
Description string `json:"description,omitempty"`
|
||||
Messages []PromptMessage `json:"messages"`
|
||||
}
|
||||
|
||||
// Describes a message returned as part of a prompt.
|
||||
type PromptMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content TextContent `json:"content"`
|
||||
}
|
||||
@@ -37,6 +37,7 @@ const jsonrpcVersion = "2.0"
|
||||
const protocolVersion20241105 = "2024-11-05"
|
||||
const protocolVersion20250326 = "2025-03-26"
|
||||
const protocolVersion20250618 = "2025-06-18"
|
||||
const protocolVersion20251125 = "2025-11-25"
|
||||
const serverName = "Toolbox"
|
||||
|
||||
var basicInputSchema = map[string]any{
|
||||
@@ -485,6 +486,23 @@ func TestMcpEndpoint(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "version 2025-11-25",
|
||||
protocol: protocolVersion20251125,
|
||||
idHeader: false,
|
||||
initWant: map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": "mcp-initialize",
|
||||
"result": map[string]any{
|
||||
"protocolVersion": "2025-11-25",
|
||||
"capabilities": map[string]any{
|
||||
"tools": map[string]any{"listChanged": false},
|
||||
"prompts": map[string]any{"listChanged": false},
|
||||
},
|
||||
"serverInfo": map[string]any{"name": serverName, "version": fakeVersionString},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, vtc := range versTestCases {
|
||||
t.Run(vtc.name, func(t *testing.T) {
|
||||
@@ -494,8 +512,7 @@ func TestMcpEndpoint(t *testing.T) {
|
||||
if sessionId != "" {
|
||||
header["Mcp-Session-Id"] = sessionId
|
||||
}
|
||||
|
||||
if vtc.protocol == protocolVersion20250618 {
|
||||
if vtc.protocol != protocolVersion20241105 && vtc.protocol != protocolVersion20250326 {
|
||||
header["MCP-Protocol-Version"] = vtc.protocol
|
||||
}
|
||||
|
||||
|
||||
@@ -304,10 +304,14 @@ func hostCheck(allowedHosts map[string]struct{}) func(http.Handler) http.Handler
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, hasWildcard := allowedHosts["*"]
|
||||
_, hostIsAllowed := allowedHosts[r.Host]
|
||||
hostname := r.Host
|
||||
if host, _, err := net.SplitHostPort(r.Host); err == nil {
|
||||
hostname = host
|
||||
}
|
||||
_, hostIsAllowed := allowedHosts[hostname]
|
||||
if !hasWildcard && !hostIsAllowed {
|
||||
// Return 400 Bad Request or 403 Forbidden to block the attack
|
||||
http.Error(w, "Invalid Host header", http.StatusBadRequest)
|
||||
// Return 403 Forbidden to block the attack
|
||||
http.Error(w, "Invalid Host header", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
@@ -406,7 +410,11 @@ func NewServer(ctx context.Context, cfg ServerConfig) (*Server, error) {
|
||||
}
|
||||
allowedHostsMap := make(map[string]struct{}, len(cfg.AllowedHosts))
|
||||
for _, h := range cfg.AllowedHosts {
|
||||
allowedHostsMap[h] = struct{}{}
|
||||
hostname := h
|
||||
if host, _, err := net.SplitHostPort(h); err == nil {
|
||||
hostname = host
|
||||
}
|
||||
allowedHostsMap[hostname] = struct{}{}
|
||||
}
|
||||
r.Use(hostCheck(allowedHostsMap))
|
||||
|
||||
|
||||
@@ -89,6 +89,7 @@ type Config struct {
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
ImpersonateServiceAccount string `yaml:"impersonateServiceAccount"`
|
||||
Scopes StringOrStringSlice `yaml:"scopes"`
|
||||
MaxQueryResultRows int `yaml:"maxQueryResultRows"`
|
||||
}
|
||||
|
||||
// StringOrStringSlice is a custom type that can unmarshal both a single string
|
||||
@@ -127,6 +128,10 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
|
||||
r.WriteMode = WriteModeAllowed
|
||||
}
|
||||
|
||||
if r.MaxQueryResultRows == 0 {
|
||||
r.MaxQueryResultRows = 50
|
||||
}
|
||||
|
||||
if r.WriteMode == WriteModeProtected && r.UseClientOAuth {
|
||||
// The protected mode only allows write operations to the session's temporary datasets.
|
||||
// when using client OAuth, a new session is created every
|
||||
@@ -150,7 +155,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
|
||||
Client: client,
|
||||
RestService: restService,
|
||||
TokenSource: tokenSource,
|
||||
MaxQueryResultRows: 50,
|
||||
MaxQueryResultRows: r.MaxQueryResultRows,
|
||||
ClientCreator: clientCreator,
|
||||
}
|
||||
|
||||
@@ -567,7 +572,7 @@ func (s *Source) RunSQL(ctx context.Context, bqClient *bigqueryapi.Client, state
|
||||
}
|
||||
|
||||
var out []any
|
||||
for {
|
||||
for s.MaxQueryResultRows <= 0 || len(out) < s.MaxQueryResultRows {
|
||||
var val []bigqueryapi.Value
|
||||
err = it.Next(&val)
|
||||
if err == iterator.Done {
|
||||
|
||||
@@ -21,9 +21,12 @@ import (
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"go.opentelemetry.io/otel/trace/noop"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/bigquery"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
)
|
||||
|
||||
func TestParseFromYamlBigQuery(t *testing.T) {
|
||||
@@ -154,6 +157,26 @@ func TestParseFromYamlBigQuery(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "with max query result rows example",
|
||||
in: `
|
||||
sources:
|
||||
my-instance:
|
||||
kind: bigquery
|
||||
project: my-project
|
||||
location: us
|
||||
maxQueryResultRows: 10
|
||||
`,
|
||||
want: server.SourceConfigs{
|
||||
"my-instance": bigquery.Config{
|
||||
Name: "my-instance",
|
||||
Kind: bigquery.SourceKind,
|
||||
Project: "my-project",
|
||||
Location: "us",
|
||||
MaxQueryResultRows: 10,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
@@ -220,6 +243,59 @@ func TestFailParseFromYaml(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestInitialize_MaxQueryResultRows(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
ctx = util.WithUserAgent(ctx, "test-agent")
|
||||
tracer := noop.NewTracerProvider().Tracer("")
|
||||
|
||||
tcs := []struct {
|
||||
desc string
|
||||
cfg bigquery.Config
|
||||
want int
|
||||
}{
|
||||
{
|
||||
desc: "default value",
|
||||
cfg: bigquery.Config{
|
||||
Name: "test-default",
|
||||
Kind: bigquery.SourceKind,
|
||||
Project: "test-project",
|
||||
UseClientOAuth: true,
|
||||
},
|
||||
want: 50,
|
||||
},
|
||||
{
|
||||
desc: "configured value",
|
||||
cfg: bigquery.Config{
|
||||
Name: "test-configured",
|
||||
Kind: bigquery.SourceKind,
|
||||
Project: "test-project",
|
||||
UseClientOAuth: true,
|
||||
MaxQueryResultRows: 100,
|
||||
},
|
||||
want: 100,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
src, err := tc.cfg.Initialize(ctx, tracer)
|
||||
if err != nil {
|
||||
t.Fatalf("Initialize failed: %v", err)
|
||||
}
|
||||
bqSrc, ok := src.(*bigquery.Source)
|
||||
if !ok {
|
||||
t.Fatalf("Expected *bigquery.Source, got %T", src)
|
||||
}
|
||||
if bqSrc.MaxQueryResultRows != tc.want {
|
||||
t.Errorf("MaxQueryResultRows = %d, want %d", bqSrc.MaxQueryResultRows, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeValue(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -16,8 +16,12 @@ package cloudhealthcare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
@@ -255,3 +259,299 @@ func (s *Source) IsDICOMStoreAllowed(storeID string) bool {
|
||||
func (s *Source) UseClientAuthorization() bool {
|
||||
return s.UseClientOAuth
|
||||
}
|
||||
|
||||
func parseResults(resp *http.Response) (any, error) {
|
||||
respBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not read response: %w", err)
|
||||
}
|
||||
if resp.StatusCode > 299 {
|
||||
return nil, fmt.Errorf("status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
|
||||
}
|
||||
var jsonMap map[string]interface{}
|
||||
if err := json.Unmarshal(respBytes, &jsonMap); err != nil {
|
||||
return nil, fmt.Errorf("could not unmarshal response as json: %w", err)
|
||||
}
|
||||
return jsonMap, nil
|
||||
}
|
||||
|
||||
func (s *Source) getService(tokenStr string) (*healthcare.Service, error) {
|
||||
svc := s.Service()
|
||||
var err error
|
||||
// Initialize new service if using user OAuth token
|
||||
if s.UseClientAuthorization() {
|
||||
svc, err = s.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
return svc, nil
|
||||
}
|
||||
|
||||
func (s *Source) FHIRFetchPage(ctx context.Context, url, tokenStr string) (any, error) {
|
||||
var httpClient *http.Client
|
||||
if s.UseClientAuthorization() {
|
||||
ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: tokenStr})
|
||||
httpClient = oauth2.NewClient(ctx, ts)
|
||||
} else {
|
||||
// The source.Service() object holds a client with the default credentials.
|
||||
// However, the client is not exported, so we have to create a new one.
|
||||
var err error
|
||||
httpClient, err = google.DefaultClient(ctx, healthcare.CloudHealthcareScope)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create default http client: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create http request: %w", err)
|
||||
}
|
||||
req.Header.Set("Accept", "application/fhir+json;charset=utf-8")
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get fhir page from %q: %w", url, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
return parseResults(resp)
|
||||
}
|
||||
|
||||
func (s *Source) FHIRPatientEverything(storeID, patientID, tokenStr string, opts []googleapi.CallOption) (any, error) {
|
||||
svc, err := s.getService(tokenStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/Patient/%s", s.Project(), s.Region(), s.DatasetID(), storeID, patientID)
|
||||
resp, err := svc.Projects.Locations.Datasets.FhirStores.Fhir.PatientEverything(name).Do(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to call patient everything for %q: %w", name, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
return parseResults(resp)
|
||||
}
|
||||
|
||||
func (s *Source) FHIRPatientSearch(storeID, tokenStr string, opts []googleapi.CallOption) (any, error) {
|
||||
svc, err := s.getService(tokenStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
|
||||
resp, err := svc.Projects.Locations.Datasets.FhirStores.Fhir.SearchType(name, "Patient", &healthcare.SearchResourcesRequest{ResourceType: "Patient"}).Do(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to search patient resources: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
return parseResults(resp)
|
||||
}
|
||||
|
||||
func (s *Source) GetDataset(tokenStr string) (*healthcare.Dataset, error) {
|
||||
svc, err := s.getService(tokenStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", s.Project(), s.Region(), s.DatasetID())
|
||||
dataset, err := svc.Projects.Locations.Datasets.Get(datasetName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
|
||||
}
|
||||
return dataset, nil
|
||||
}
|
||||
|
||||
func (s *Source) GetFHIRResource(storeID, resType, resID, tokenStr string) (any, error) {
|
||||
svc, err := s.getService(tokenStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/%s/%s", s.Project(), s.Region(), s.DatasetID(), storeID, resType, resID)
|
||||
call := svc.Projects.Locations.Datasets.FhirStores.Fhir.Read(name)
|
||||
call.Header().Set("Content-Type", "application/fhir+json;charset=utf-8")
|
||||
resp, err := call.Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get fhir resource %q: %w", name, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
return parseResults(resp)
|
||||
}
|
||||
|
||||
func (s *Source) GetDICOMStore(storeID, tokenStr string) (*healthcare.DicomStore, error) {
|
||||
svc, err := s.getService(tokenStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
|
||||
store, err := svc.Projects.Locations.Datasets.DicomStores.Get(storeName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get DICOM store %q: %w", storeName, err)
|
||||
}
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func (s *Source) GetFHIRStore(storeID, tokenStr string) (*healthcare.FhirStore, error) {
|
||||
svc, err := s.getService(tokenStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
|
||||
store, err := svc.Projects.Locations.Datasets.FhirStores.Get(storeName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get FHIR store %q: %w", storeName, err)
|
||||
}
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func (s *Source) GetDICOMStoreMetrics(storeID, tokenStr string) (*healthcare.DicomStoreMetrics, error) {
|
||||
svc, err := s.getService(tokenStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
|
||||
store, err := svc.Projects.Locations.Datasets.DicomStores.GetDICOMStoreMetrics(storeName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get metrics for DICOM store %q: %w", storeName, err)
|
||||
}
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func (s *Source) GetFHIRStoreMetrics(storeID, tokenStr string) (*healthcare.FhirStoreMetrics, error) {
|
||||
svc, err := s.getService(tokenStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
|
||||
store, err := svc.Projects.Locations.Datasets.FhirStores.GetFHIRStoreMetrics(storeName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get metrics for FHIR store %q: %w", storeName, err)
|
||||
}
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func (s *Source) ListDICOMStores(tokenStr string) ([]*healthcare.DicomStore, error) {
|
||||
svc, err := s.getService(tokenStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", s.Project(), s.Region(), s.DatasetID())
|
||||
stores, err := svc.Projects.Locations.Datasets.DicomStores.List(datasetName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
|
||||
}
|
||||
var filtered []*healthcare.DicomStore
|
||||
for _, store := range stores.DicomStores {
|
||||
if len(s.AllowedDICOMStores()) == 0 {
|
||||
filtered = append(filtered, store)
|
||||
continue
|
||||
}
|
||||
if len(store.Name) == 0 {
|
||||
continue
|
||||
}
|
||||
parts := strings.Split(store.Name, "/")
|
||||
if _, ok := s.AllowedDICOMStores()[parts[len(parts)-1]]; ok {
|
||||
filtered = append(filtered, store)
|
||||
}
|
||||
}
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
func (s *Source) ListFHIRStores(tokenStr string) ([]*healthcare.FhirStore, error) {
|
||||
svc, err := s.getService(tokenStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", s.Project(), s.Region(), s.DatasetID())
|
||||
stores, err := svc.Projects.Locations.Datasets.FhirStores.List(datasetName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
|
||||
}
|
||||
var filtered []*healthcare.FhirStore
|
||||
for _, store := range stores.FhirStores {
|
||||
if len(s.AllowedFHIRStores()) == 0 {
|
||||
filtered = append(filtered, store)
|
||||
continue
|
||||
}
|
||||
if len(store.Name) == 0 {
|
||||
continue
|
||||
}
|
||||
parts := strings.Split(store.Name, "/")
|
||||
if _, ok := s.AllowedFHIRStores()[parts[len(parts)-1]]; ok {
|
||||
filtered = append(filtered, store)
|
||||
}
|
||||
}
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
func (s *Source) RetrieveRenderedDICOMInstance(storeID, study, series, sop string, frame int, tokenStr string) (any, error) {
|
||||
svc, err := s.getService(tokenStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
|
||||
dicomWebPath := fmt.Sprintf("studies/%s/series/%s/instances/%s/frames/%d/rendered", study, series, sop, frame)
|
||||
call := svc.Projects.Locations.Datasets.DicomStores.Studies.Series.Instances.Frames.RetrieveRendered(name, dicomWebPath)
|
||||
call.Header().Set("Accept", "image/jpeg")
|
||||
resp, err := call.Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to retrieve dicom instance rendered image: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not read response: %w", err)
|
||||
}
|
||||
if resp.StatusCode > 299 {
|
||||
return nil, fmt.Errorf("RetrieveRendered: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
|
||||
}
|
||||
base64String := base64.StdEncoding.EncodeToString(respBytes)
|
||||
return base64String, nil
|
||||
}
|
||||
|
||||
func (s *Source) SearchDICOM(toolKind, storeID, dicomWebPath, tokenStr string, opts []googleapi.CallOption) (any, error) {
|
||||
svc, err := s.getService(tokenStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
|
||||
var resp *http.Response
|
||||
switch toolKind {
|
||||
case "cloud-healthcare-search-dicom-instances":
|
||||
resp, err = svc.Projects.Locations.Datasets.DicomStores.SearchForInstances(name, dicomWebPath).Do(opts...)
|
||||
case "cloud-healthcare-search-dicom-series":
|
||||
resp, err = svc.Projects.Locations.Datasets.DicomStores.SearchForSeries(name, dicomWebPath).Do(opts...)
|
||||
case "cloud-healthcare-search-dicom-studies":
|
||||
resp, err = svc.Projects.Locations.Datasets.DicomStores.SearchForStudies(name, dicomWebPath).Do(opts...)
|
||||
default:
|
||||
return nil, fmt.Errorf("incompatible tool kind: %s", toolKind)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to search dicom series: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not read response: %w", err)
|
||||
}
|
||||
if resp.StatusCode > 299 {
|
||||
return nil, fmt.Errorf("search: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
|
||||
}
|
||||
if len(respBytes) == 0 {
|
||||
return []interface{}{}, nil
|
||||
}
|
||||
var result []interface{}
|
||||
if err := json.Unmarshal(respBytes, &result); err != nil {
|
||||
return nil, fmt.Errorf("could not unmarshal response as list: %w", err)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"text/template"
|
||||
"time"
|
||||
@@ -36,7 +37,10 @@ import (
|
||||
|
||||
const SourceKind string = "cloud-sql-admin"
|
||||
|
||||
var targetLinkRegex = regexp.MustCompile(`/projects/([^/]+)/instances/([^/]+)/databases/([^/]+)`)
|
||||
var (
|
||||
targetLinkRegex = regexp.MustCompile(`/projects/([^/]+)/instances/([^/]+)/databases/([^/]+)`)
|
||||
backupDRRegex = regexp.MustCompile(`^projects/([^/]+)/locations/([^/]+)/backupVaults/([^/]+)/dataSources/([^/]+)/backups/([^/]+)$`)
|
||||
)
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
@@ -352,6 +356,70 @@ func (s *Source) GetWaitForOperations(ctx context.Context, service *sqladmin.Ser
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *Source) InsertBackupRun(ctx context.Context, project, instance, location, backupDescription, accessToken string) (any, error) {
|
||||
backupRun := &sqladmin.BackupRun{}
|
||||
if location != "" {
|
||||
backupRun.Location = location
|
||||
}
|
||||
if backupDescription != "" {
|
||||
backupRun.Description = backupDescription
|
||||
}
|
||||
|
||||
service, err := s.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := service.BackupRuns.Insert(project, instance, backupRun).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating backup: %w", err)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (s *Source) RestoreBackup(ctx context.Context, targetProject, targetInstance, sourceProject, sourceInstance, backupID, accessToken string) (any, error) {
|
||||
request := &sqladmin.InstancesRestoreBackupRequest{}
|
||||
|
||||
// There are 3 scenarios for the backup identifier:
|
||||
// 1. The identifier is an int64 containing the timestamp of the BackupRun.
|
||||
// This is used to restore standard backups, and the RestoreBackupContext
|
||||
// field should be populated with the backup ID and source instance info.
|
||||
// 2. The identifier is a string of the format
|
||||
// 'projects/{project-id}/locations/{location}/backupVaults/{backupvault}/dataSources/{datasource}/backups/{backup-uid}'.
|
||||
// This is used to restore BackupDR backups, and the BackupdrBackup field
|
||||
// should be populated.
|
||||
// 3. The identifer is a string of the format
|
||||
// 'projects/{project-id}/backups/{backup-uid}'. In this case, the Backup
|
||||
// field should be populated.
|
||||
if backupRunID, err := strconv.ParseInt(backupID, 10, 64); err == nil {
|
||||
if sourceProject == "" || targetInstance == "" {
|
||||
return nil, fmt.Errorf("source project and instance are required when restoring via backup ID")
|
||||
}
|
||||
request.RestoreBackupContext = &sqladmin.RestoreBackupContext{
|
||||
Project: sourceProject,
|
||||
InstanceId: sourceInstance,
|
||||
BackupRunId: backupRunID,
|
||||
}
|
||||
} else if backupDRRegex.MatchString(backupID) {
|
||||
request.BackupdrBackup = backupID
|
||||
} else {
|
||||
request.Backup = backupID
|
||||
}
|
||||
|
||||
service, err := s.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := service.Instances.RestoreBackup(targetProject, targetInstance, request).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error restoring backup: %w", err)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func generateCloudSQLConnectionMessage(ctx context.Context, source *Source, logger log.Logger, opResponse map[string]any, connectionMessageTemplate string) (string, bool) {
|
||||
operationType, ok := opResponse["operationType"].(string)
|
||||
if !ok || operationType != "CREATE_DATABASE" {
|
||||
|
||||
@@ -19,6 +19,8 @@ import (
|
||||
"fmt"
|
||||
|
||||
dataplexapi "cloud.google.com/go/dataplex/apiv1"
|
||||
"cloud.google.com/go/dataplex/apiv1/dataplexpb"
|
||||
"github.com/cenkalti/backoff/v5"
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
@@ -121,3 +123,101 @@ func initDataplexConnection(
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (s *Source) LookupEntry(ctx context.Context, name string, view int, aspectTypes []string, entry string) (*dataplexpb.Entry, error) {
|
||||
viewMap := map[int]dataplexpb.EntryView{
|
||||
1: dataplexpb.EntryView_BASIC,
|
||||
2: dataplexpb.EntryView_FULL,
|
||||
3: dataplexpb.EntryView_CUSTOM,
|
||||
4: dataplexpb.EntryView_ALL,
|
||||
}
|
||||
req := &dataplexpb.LookupEntryRequest{
|
||||
Name: name,
|
||||
View: viewMap[view],
|
||||
AspectTypes: aspectTypes,
|
||||
Entry: entry,
|
||||
}
|
||||
result, err := s.CatalogClient().LookupEntry(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *Source) searchRequest(ctx context.Context, query string, pageSize int, orderBy string) (*dataplexapi.SearchEntriesResultIterator, error) {
|
||||
// Create SearchEntriesRequest with the provided parameters
|
||||
req := &dataplexpb.SearchEntriesRequest{
|
||||
Query: query,
|
||||
Name: fmt.Sprintf("projects/%s/locations/global", s.ProjectID()),
|
||||
PageSize: int32(pageSize),
|
||||
OrderBy: orderBy,
|
||||
SemanticSearch: true,
|
||||
}
|
||||
|
||||
// Perform the search using the CatalogClient - this will return an iterator
|
||||
it := s.CatalogClient().SearchEntries(ctx, req)
|
||||
if it == nil {
|
||||
return nil, fmt.Errorf("failed to create search entries iterator for project %q", s.ProjectID())
|
||||
}
|
||||
return it, nil
|
||||
}
|
||||
|
||||
func (s *Source) SearchAspectTypes(ctx context.Context, query string, pageSize int, orderBy string) ([]*dataplexpb.AspectType, error) {
|
||||
q := query + " type=projects/dataplex-types/locations/global/entryTypes/aspecttype"
|
||||
it, err := s.searchRequest(ctx, q, pageSize, orderBy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Iterate through the search results and call GetAspectType for each result using the resource name
|
||||
var results []*dataplexpb.AspectType
|
||||
for {
|
||||
entry, err := it.Next()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
// Create an instance of exponential backoff with default values for retrying GetAspectType calls
|
||||
// InitialInterval, RandomizationFactor, Multiplier, MaxInterval = 500 ms, 0.5, 1.5, 60 s
|
||||
getAspectBackOff := backoff.NewExponentialBackOff()
|
||||
|
||||
resourceName := entry.DataplexEntry.GetEntrySource().Resource
|
||||
getAspectTypeReq := &dataplexpb.GetAspectTypeRequest{
|
||||
Name: resourceName,
|
||||
}
|
||||
|
||||
operation := func() (*dataplexpb.AspectType, error) {
|
||||
aspectType, err := s.CatalogClient().GetAspectType(ctx, getAspectTypeReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get aspect type for entry %q: %w", resourceName, err)
|
||||
}
|
||||
return aspectType, nil
|
||||
}
|
||||
|
||||
// Retry the GetAspectType operation with exponential backoff
|
||||
aspectType, err := backoff.Retry(ctx, operation, backoff.WithBackOff(getAspectBackOff))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get aspect type after retries for entry %q: %w", resourceName, err)
|
||||
}
|
||||
|
||||
results = append(results, aspectType)
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (s *Source) SearchEntries(ctx context.Context, query string, pageSize int, orderBy string) ([]*dataplexpb.SearchEntriesResult, error) {
|
||||
it, err := s.searchRequest(ctx, query, pageSize, orderBy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var results []*dataplexpb.SearchEntriesResult
|
||||
for {
|
||||
entry, err := it.Next()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
results = append(results, entry)
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
@@ -16,7 +16,10 @@ package firestore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cloud.google.com/go/firestore"
|
||||
"github.com/goccy/go-yaml"
|
||||
@@ -25,6 +28,7 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"google.golang.org/api/firebaserules/v1"
|
||||
"google.golang.org/api/option"
|
||||
"google.golang.org/genproto/googleapis/type/latlng"
|
||||
)
|
||||
|
||||
const SourceKind string = "firestore"
|
||||
@@ -113,6 +117,476 @@ func (s *Source) GetDatabaseId() string {
|
||||
return s.Database
|
||||
}
|
||||
|
||||
// FirestoreValueToJSON converts a Firestore value to a simplified JSON representation
|
||||
// This removes type information and returns plain values
|
||||
func FirestoreValueToJSON(value any) any {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch v := value.(type) {
|
||||
case time.Time:
|
||||
return v.Format(time.RFC3339Nano)
|
||||
case *latlng.LatLng:
|
||||
return map[string]any{
|
||||
"latitude": v.Latitude,
|
||||
"longitude": v.Longitude,
|
||||
}
|
||||
case []byte:
|
||||
return base64.StdEncoding.EncodeToString(v)
|
||||
case []any:
|
||||
result := make([]any, len(v))
|
||||
for i, item := range v {
|
||||
result[i] = FirestoreValueToJSON(item)
|
||||
}
|
||||
return result
|
||||
case map[string]any:
|
||||
result := make(map[string]any)
|
||||
for k, val := range v {
|
||||
result[k] = FirestoreValueToJSON(val)
|
||||
}
|
||||
return result
|
||||
case *firestore.DocumentRef:
|
||||
return v.Path
|
||||
default:
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
// BuildQuery constructs the Firestore query from parameters
|
||||
func (s *Source) BuildQuery(collectionPath string, filter firestore.EntityFilter, selectFields []string, field string, direction firestore.Direction, limit int, analyzeQuery bool) (*firestore.Query, error) {
|
||||
collection := s.FirestoreClient().Collection(collectionPath)
|
||||
query := collection.Query
|
||||
|
||||
// Process and apply filters if template is provided
|
||||
if filter != nil {
|
||||
query = query.WhereEntity(filter)
|
||||
}
|
||||
if len(selectFields) > 0 {
|
||||
query = query.Select(selectFields...)
|
||||
}
|
||||
if field != "" {
|
||||
query = query.OrderBy(field, direction)
|
||||
}
|
||||
query = query.Limit(limit)
|
||||
|
||||
// Apply analyze options if enabled
|
||||
if analyzeQuery {
|
||||
query = query.WithRunOptions(firestore.ExplainOptions{
|
||||
Analyze: true,
|
||||
})
|
||||
}
|
||||
|
||||
return &query, nil
|
||||
}
|
||||
|
||||
// QueryResult represents a document result from the query
|
||||
type QueryResult struct {
|
||||
ID string `json:"id"`
|
||||
Path string `json:"path"`
|
||||
Data map[string]any `json:"data"`
|
||||
CreateTime any `json:"createTime,omitempty"`
|
||||
UpdateTime any `json:"updateTime,omitempty"`
|
||||
ReadTime any `json:"readTime,omitempty"`
|
||||
}
|
||||
|
||||
// QueryResponse represents the full response including optional metrics
|
||||
type QueryResponse struct {
|
||||
Documents []QueryResult `json:"documents"`
|
||||
ExplainMetrics map[string]any `json:"explainMetrics,omitempty"`
|
||||
}
|
||||
|
||||
// ExecuteQuery runs the query and formats the results
|
||||
func (s *Source) ExecuteQuery(ctx context.Context, query *firestore.Query, analyzeQuery bool) (any, error) {
|
||||
docIterator := query.Documents(ctx)
|
||||
docs, err := docIterator.GetAll()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to execute query: %w", err)
|
||||
}
|
||||
// Convert results to structured format
|
||||
results := make([]QueryResult, len(docs))
|
||||
for i, doc := range docs {
|
||||
results[i] = QueryResult{
|
||||
ID: doc.Ref.ID,
|
||||
Path: doc.Ref.Path,
|
||||
Data: doc.Data(),
|
||||
CreateTime: doc.CreateTime,
|
||||
UpdateTime: doc.UpdateTime,
|
||||
ReadTime: doc.ReadTime,
|
||||
}
|
||||
}
|
||||
|
||||
// Return with explain metrics if requested
|
||||
if analyzeQuery {
|
||||
explainMetrics, err := getExplainMetrics(docIterator)
|
||||
if err == nil && explainMetrics != nil {
|
||||
response := QueryResponse{
|
||||
Documents: results,
|
||||
ExplainMetrics: explainMetrics,
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// getExplainMetrics extracts explain metrics from the query iterator
|
||||
func getExplainMetrics(docIterator *firestore.DocumentIterator) (map[string]any, error) {
|
||||
explainMetrics, err := docIterator.ExplainMetrics()
|
||||
if err != nil || explainMetrics == nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
metricsData := make(map[string]any)
|
||||
|
||||
// Add plan summary if available
|
||||
if explainMetrics.PlanSummary != nil {
|
||||
planSummary := make(map[string]any)
|
||||
planSummary["indexesUsed"] = explainMetrics.PlanSummary.IndexesUsed
|
||||
metricsData["planSummary"] = planSummary
|
||||
}
|
||||
|
||||
// Add execution stats if available
|
||||
if explainMetrics.ExecutionStats != nil {
|
||||
executionStats := make(map[string]any)
|
||||
executionStats["resultsReturned"] = explainMetrics.ExecutionStats.ResultsReturned
|
||||
executionStats["readOperations"] = explainMetrics.ExecutionStats.ReadOperations
|
||||
|
||||
if explainMetrics.ExecutionStats.ExecutionDuration != nil {
|
||||
executionStats["executionDuration"] = explainMetrics.ExecutionStats.ExecutionDuration.String()
|
||||
}
|
||||
|
||||
if explainMetrics.ExecutionStats.DebugStats != nil {
|
||||
executionStats["debugStats"] = *explainMetrics.ExecutionStats.DebugStats
|
||||
}
|
||||
|
||||
metricsData["executionStats"] = executionStats
|
||||
}
|
||||
|
||||
return metricsData, nil
|
||||
}
|
||||
|
||||
func (s *Source) GetDocuments(ctx context.Context, documentPaths []string) ([]any, error) {
|
||||
// Create document references from paths
|
||||
docRefs := make([]*firestore.DocumentRef, len(documentPaths))
|
||||
for i, path := range documentPaths {
|
||||
docRefs[i] = s.FirestoreClient().Doc(path)
|
||||
}
|
||||
|
||||
// Get all documents
|
||||
snapshots, err := s.FirestoreClient().GetAll(ctx, docRefs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get documents: %w", err)
|
||||
}
|
||||
|
||||
// Convert snapshots to response data
|
||||
results := make([]any, len(snapshots))
|
||||
for i, snapshot := range snapshots {
|
||||
docData := make(map[string]any)
|
||||
docData["path"] = documentPaths[i]
|
||||
docData["exists"] = snapshot.Exists()
|
||||
|
||||
if snapshot.Exists() {
|
||||
docData["data"] = snapshot.Data()
|
||||
docData["createTime"] = snapshot.CreateTime
|
||||
docData["updateTime"] = snapshot.UpdateTime
|
||||
docData["readTime"] = snapshot.ReadTime
|
||||
}
|
||||
|
||||
results[i] = docData
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (s *Source) AddDocuments(ctx context.Context, collectionPath string, documentData any, returnData bool) (map[string]any, error) {
|
||||
// Get the collection reference
|
||||
collection := s.FirestoreClient().Collection(collectionPath)
|
||||
|
||||
// Add the document to the collection
|
||||
docRef, writeResult, err := collection.Add(ctx, documentData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add document: %w", err)
|
||||
}
|
||||
// Build the response
|
||||
response := map[string]any{
|
||||
"documentPath": docRef.Path,
|
||||
"createTime": writeResult.UpdateTime.Format("2006-01-02T15:04:05.999999999Z"),
|
||||
}
|
||||
// Add document data if requested
|
||||
if returnData {
|
||||
// Fetch the updated document to return the current state
|
||||
snapshot, err := docRef.Get(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve updated document: %w", err)
|
||||
}
|
||||
// Convert the document data back to simple JSON format
|
||||
simplifiedData := FirestoreValueToJSON(snapshot.Data())
|
||||
response["documentData"] = simplifiedData
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *Source) UpdateDocument(ctx context.Context, documentPath string, updates []firestore.Update, documentData any, returnData bool) (map[string]any, error) {
|
||||
// Get the document reference
|
||||
docRef := s.FirestoreClient().Doc(documentPath)
|
||||
|
||||
// Prepare update data
|
||||
var writeResult *firestore.WriteResult
|
||||
var writeErr error
|
||||
|
||||
if len(updates) > 0 {
|
||||
writeResult, writeErr = docRef.Update(ctx, updates)
|
||||
} else {
|
||||
writeResult, writeErr = docRef.Set(ctx, documentData, firestore.MergeAll)
|
||||
}
|
||||
|
||||
if writeErr != nil {
|
||||
return nil, fmt.Errorf("failed to update document: %w", writeErr)
|
||||
}
|
||||
|
||||
// Build the response
|
||||
response := map[string]any{
|
||||
"documentPath": docRef.Path,
|
||||
"updateTime": writeResult.UpdateTime.Format("2006-01-02T15:04:05.999999999Z"),
|
||||
}
|
||||
|
||||
// Add document data if requested
|
||||
if returnData {
|
||||
// Fetch the updated document to return the current state
|
||||
snapshot, err := docRef.Get(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve updated document: %w", err)
|
||||
}
|
||||
// Convert the document data to simple JSON format
|
||||
simplifiedData := FirestoreValueToJSON(snapshot.Data())
|
||||
response["documentData"] = simplifiedData
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (s *Source) DeleteDocuments(ctx context.Context, documentPaths []string) ([]any, error) {
|
||||
// Create a BulkWriter to handle multiple deletions efficiently
|
||||
bulkWriter := s.FirestoreClient().BulkWriter(ctx)
|
||||
|
||||
// Keep track of jobs for each document
|
||||
jobs := make([]*firestore.BulkWriterJob, len(documentPaths))
|
||||
|
||||
// Add all delete operations to the BulkWriter
|
||||
for i, path := range documentPaths {
|
||||
docRef := s.FirestoreClient().Doc(path)
|
||||
job, err := bulkWriter.Delete(docRef)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add delete operation for document %q: %w", path, err)
|
||||
}
|
||||
jobs[i] = job
|
||||
}
|
||||
|
||||
// End the BulkWriter to execute all operations
|
||||
bulkWriter.End()
|
||||
|
||||
// Collect results
|
||||
results := make([]any, len(documentPaths))
|
||||
for i, job := range jobs {
|
||||
docData := make(map[string]any)
|
||||
docData["path"] = documentPaths[i]
|
||||
|
||||
// Wait for the job to complete and get the result
|
||||
_, err := job.Results()
|
||||
if err != nil {
|
||||
docData["success"] = false
|
||||
docData["error"] = err.Error()
|
||||
} else {
|
||||
docData["success"] = true
|
||||
}
|
||||
|
||||
results[i] = docData
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (s *Source) ListCollections(ctx context.Context, parentPath string) ([]any, error) {
|
||||
var collectionRefs []*firestore.CollectionRef
|
||||
var err error
|
||||
if parentPath != "" {
|
||||
// List subcollections of the specified document
|
||||
docRef := s.FirestoreClient().Doc(parentPath)
|
||||
collectionRefs, err = docRef.Collections(ctx).GetAll()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list subcollections of document %q: %w", parentPath, err)
|
||||
}
|
||||
} else {
|
||||
// List root collections
|
||||
collectionRefs, err = s.FirestoreClient().Collections(ctx).GetAll()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list root collections: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Convert collection references to response data
|
||||
results := make([]any, len(collectionRefs))
|
||||
for i, collRef := range collectionRefs {
|
||||
collData := make(map[string]any)
|
||||
collData["id"] = collRef.ID
|
||||
collData["path"] = collRef.Path
|
||||
|
||||
// If this is a subcollection, include parent information
|
||||
if collRef.Parent != nil {
|
||||
collData["parent"] = collRef.Parent.Path
|
||||
}
|
||||
results[i] = collData
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (s *Source) GetRules(ctx context.Context) (any, error) {
|
||||
// Get the latest release for Firestore
|
||||
releaseName := fmt.Sprintf("projects/%s/releases/cloud.firestore/%s", s.GetProjectId(), s.GetDatabaseId())
|
||||
release, err := s.FirebaseRulesClient().Projects.Releases.Get(releaseName).Context(ctx).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get latest Firestore release: %w", err)
|
||||
}
|
||||
|
||||
if release.RulesetName == "" {
|
||||
return nil, fmt.Errorf("no active Firestore rules were found in project '%s' and database '%s'", s.GetProjectId(), s.GetDatabaseId())
|
||||
}
|
||||
|
||||
// Get the ruleset content
|
||||
ruleset, err := s.FirebaseRulesClient().Projects.Rulesets.Get(release.RulesetName).Context(ctx).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get ruleset content: %w", err)
|
||||
}
|
||||
|
||||
if ruleset.Source == nil || len(ruleset.Source.Files) == 0 {
|
||||
return nil, fmt.Errorf("no rules files found in ruleset")
|
||||
}
|
||||
|
||||
return ruleset, nil
|
||||
}
|
||||
|
||||
// SourcePosition represents the location of an issue in the source
|
||||
type SourcePosition struct {
|
||||
FileName string `json:"fileName,omitempty"`
|
||||
Line int64 `json:"line"` // 1-based
|
||||
Column int64 `json:"column"` // 1-based
|
||||
CurrentOffset int64 `json:"currentOffset"` // 0-based, inclusive start
|
||||
EndOffset int64 `json:"endOffset"` // 0-based, exclusive end
|
||||
}
|
||||
|
||||
// Issue represents a validation issue in the rules
|
||||
type Issue struct {
|
||||
SourcePosition SourcePosition `json:"sourcePosition"`
|
||||
Description string `json:"description"`
|
||||
Severity string `json:"severity"`
|
||||
}
|
||||
|
||||
// ValidationResult represents the result of rules validation
|
||||
type ValidationResult struct {
|
||||
Valid bool `json:"valid"`
|
||||
IssueCount int `json:"issueCount"`
|
||||
FormattedIssues string `json:"formattedIssues,omitempty"`
|
||||
RawIssues []Issue `json:"rawIssues,omitempty"`
|
||||
}
|
||||
|
||||
func (s *Source) ValidateRules(ctx context.Context, sourceParam string) (any, error) {
|
||||
// Create test request
|
||||
testRequest := &firebaserules.TestRulesetRequest{
|
||||
Source: &firebaserules.Source{
|
||||
Files: []*firebaserules.File{
|
||||
{
|
||||
Name: "firestore.rules",
|
||||
Content: sourceParam,
|
||||
},
|
||||
},
|
||||
},
|
||||
// We don't need test cases for validation only
|
||||
TestSuite: &firebaserules.TestSuite{
|
||||
TestCases: []*firebaserules.TestCase{},
|
||||
},
|
||||
}
|
||||
// Call the test API
|
||||
projectName := fmt.Sprintf("projects/%s", s.GetProjectId())
|
||||
response, err := s.FirebaseRulesClient().Projects.Test(projectName, testRequest).Context(ctx).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to validate rules: %w", err)
|
||||
}
|
||||
|
||||
// Process the response
|
||||
if len(response.Issues) == 0 {
|
||||
return ValidationResult{
|
||||
Valid: true,
|
||||
IssueCount: 0,
|
||||
FormattedIssues: "✓ No errors detected. Rules are valid.",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Convert issues to our format
|
||||
issues := make([]Issue, len(response.Issues))
|
||||
for i, issue := range response.Issues {
|
||||
issues[i] = Issue{
|
||||
Description: issue.Description,
|
||||
Severity: issue.Severity,
|
||||
SourcePosition: SourcePosition{
|
||||
FileName: issue.SourcePosition.FileName,
|
||||
Line: issue.SourcePosition.Line,
|
||||
Column: issue.SourcePosition.Column,
|
||||
CurrentOffset: issue.SourcePosition.CurrentOffset,
|
||||
EndOffset: issue.SourcePosition.EndOffset,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Format issues
|
||||
sourceLines := strings.Split(sourceParam, "\n")
|
||||
var formattedOutput []string
|
||||
|
||||
formattedOutput = append(formattedOutput, fmt.Sprintf("Found %d issue(s) in rules source:\n", len(issues)))
|
||||
|
||||
for _, issue := range issues {
|
||||
issueString := fmt.Sprintf("%s: %s [Ln %d, Col %d]",
|
||||
issue.Severity,
|
||||
issue.Description,
|
||||
issue.SourcePosition.Line,
|
||||
issue.SourcePosition.Column)
|
||||
|
||||
if issue.SourcePosition.Line > 0 {
|
||||
lineIndex := int(issue.SourcePosition.Line - 1) // 0-based index
|
||||
if lineIndex >= 0 && lineIndex < len(sourceLines) {
|
||||
errorLine := sourceLines[lineIndex]
|
||||
issueString += fmt.Sprintf("\n```\n%s", errorLine)
|
||||
|
||||
// Add carets if we have column and offset information
|
||||
if issue.SourcePosition.Column > 0 &&
|
||||
issue.SourcePosition.CurrentOffset >= 0 &&
|
||||
issue.SourcePosition.EndOffset > issue.SourcePosition.CurrentOffset {
|
||||
|
||||
startColumn := int(issue.SourcePosition.Column - 1) // 0-based
|
||||
errorTokenLength := int(issue.SourcePosition.EndOffset - issue.SourcePosition.CurrentOffset)
|
||||
|
||||
if startColumn >= 0 && errorTokenLength > 0 && startColumn <= len(errorLine) {
|
||||
padding := strings.Repeat(" ", startColumn)
|
||||
carets := strings.Repeat("^", errorTokenLength)
|
||||
issueString += fmt.Sprintf("\n%s%s", padding, carets)
|
||||
}
|
||||
}
|
||||
issueString += "\n```"
|
||||
}
|
||||
}
|
||||
|
||||
formattedOutput = append(formattedOutput, issueString)
|
||||
}
|
||||
|
||||
formattedIssues := strings.Join(formattedOutput, "\n\n")
|
||||
|
||||
return ValidationResult{
|
||||
Valid: false,
|
||||
IssueCount: len(issues),
|
||||
FormattedIssues: formattedIssues,
|
||||
RawIssues: issues,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func initFirestoreConnection(
|
||||
ctx context.Context,
|
||||
tracer trace.Tracer,
|
||||
|
||||
@@ -16,6 +16,7 @@ package firestore_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
@@ -128,3 +129,37 @@ func TestFailParseFromYamlFirestore(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFirestoreValueToJSON_RoundTrip(t *testing.T) {
|
||||
// Test round-trip conversion
|
||||
original := map[string]any{
|
||||
"name": "Test",
|
||||
"count": int64(42),
|
||||
"price": 19.99,
|
||||
"active": true,
|
||||
"tags": []any{"tag1", "tag2"},
|
||||
"metadata": map[string]any{
|
||||
"created": time.Now(),
|
||||
},
|
||||
"nullField": nil,
|
||||
}
|
||||
|
||||
// Convert to JSON representation
|
||||
jsonRepresentation := firestore.FirestoreValueToJSON(original)
|
||||
|
||||
// Verify types are simplified
|
||||
jsonMap, ok := jsonRepresentation.(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("Expected map, got %T", jsonRepresentation)
|
||||
}
|
||||
|
||||
// Time should be converted to string
|
||||
metadata, ok := jsonMap["metadata"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("metadata should be a map, got %T", jsonMap["metadata"])
|
||||
}
|
||||
_, ok = metadata["created"].(string)
|
||||
if !ok {
|
||||
t.Errorf("created should be a string, got %T", metadata["created"])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,7 +16,9 @@ package http
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
@@ -143,3 +145,28 @@ func (s *Source) HttpQueryParams() map[string]string {
|
||||
func (s *Source) Client() *http.Client {
|
||||
return s.client
|
||||
}
|
||||
|
||||
func (s *Source) RunRequest(req *http.Request) (any, error) {
|
||||
// Make request and fetch response
|
||||
resp, err := s.Client().Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error making HTTP request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var body []byte
|
||||
body, err = io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode > 299 {
|
||||
return nil, fmt.Errorf("unexpected status code: %d, response body: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var data any
|
||||
if err = json.Unmarshal(body, &data); err != nil {
|
||||
// if unable to unmarshal data, return result as string.
|
||||
return string(body), nil
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
@@ -15,7 +15,9 @@ package looker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -208,6 +210,49 @@ func (s *Source) LookerSessionLength() int64 {
|
||||
return s.SessionLength
|
||||
}
|
||||
|
||||
// Make types for RoundTripper
|
||||
type transportWithAuthHeader struct {
|
||||
Base http.RoundTripper
|
||||
AuthToken string
|
||||
}
|
||||
|
||||
func (t *transportWithAuthHeader) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
req.Header.Set("x-looker-appid", "go-sdk")
|
||||
req.Header.Set("Authorization", t.AuthToken)
|
||||
return t.Base.RoundTrip(req)
|
||||
}
|
||||
|
||||
func (s *Source) GetLookerSDK(accessToken string) (*v4.LookerSDK, error) {
|
||||
if s.UseClientAuthorization() {
|
||||
if accessToken == "" {
|
||||
return nil, fmt.Errorf("no access token supplied with request")
|
||||
}
|
||||
|
||||
session := rtl.NewAuthSession(*s.LookerApiSettings())
|
||||
// Configure base transport with TLS
|
||||
transport := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: !s.LookerApiSettings().VerifySsl,
|
||||
},
|
||||
}
|
||||
|
||||
// Build transport for end user token
|
||||
session.Client = http.Client{
|
||||
Transport: &transportWithAuthHeader{
|
||||
Base: transport,
|
||||
AuthToken: accessToken,
|
||||
},
|
||||
}
|
||||
// return SDK with new Transport
|
||||
return v4.NewLookerSDK(session), nil
|
||||
}
|
||||
|
||||
if s.LookerClient() == nil {
|
||||
return nil, fmt.Errorf("client id or client secret not valid")
|
||||
}
|
||||
return s.LookerClient(), nil
|
||||
}
|
||||
|
||||
func initGoogleCloudConnection(ctx context.Context) (oauth2.TokenSource, error) {
|
||||
cred, err := google.FindDefaultCredentials(ctx, geminidataanalytics.DefaultAuthScopes()...)
|
||||
if err != nil {
|
||||
|
||||
@@ -16,11 +16,14 @@ package mongodb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
@@ -93,6 +96,201 @@ func (s *Source) MongoClient() *mongo.Client {
|
||||
return s.Client
|
||||
}
|
||||
|
||||
func parseData(ctx context.Context, cur *mongo.Cursor) ([]any, error) {
|
||||
var data = []any{}
|
||||
err := cur.All(ctx, &data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var final []any
|
||||
for _, item := range data {
|
||||
tmp, _ := bson.MarshalExtJSON(item, false, false)
|
||||
var tmp2 any
|
||||
err = json.Unmarshal(tmp, &tmp2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
final = append(final, tmp2)
|
||||
}
|
||||
return final, err
|
||||
}
|
||||
|
||||
func (s *Source) Aggregate(ctx context.Context, pipelineString string, canonical, readOnly bool, database, collection string) ([]any, error) {
|
||||
var pipeline = []bson.M{}
|
||||
err := bson.UnmarshalExtJSON([]byte(pipelineString), canonical, &pipeline)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if readOnly {
|
||||
//fail if we do a merge or an out
|
||||
for _, stage := range pipeline {
|
||||
for key := range stage {
|
||||
if key == "$merge" || key == "$out" {
|
||||
return nil, fmt.Errorf("this is not a read-only pipeline: %+v", stage)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cur, err := s.MongoClient().Database(database).Collection(collection).Aggregate(ctx, pipeline)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer cur.Close(ctx)
|
||||
res, err := parseData(ctx, cur)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if res == nil {
|
||||
return []any{}, nil
|
||||
}
|
||||
return res, err
|
||||
}
|
||||
|
||||
func (s *Source) Find(ctx context.Context, filterString, database, collection string, opts *options.FindOptions) ([]any, error) {
|
||||
var filter = bson.D{}
|
||||
err := bson.UnmarshalExtJSON([]byte(filterString), false, &filter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cur, err := s.MongoClient().Database(database).Collection(collection).Find(ctx, filter, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer cur.Close(ctx)
|
||||
return parseData(ctx, cur)
|
||||
}
|
||||
|
||||
func (s *Source) FindOne(ctx context.Context, filterString, database, collection string, opts *options.FindOneOptions) ([]any, error) {
|
||||
var filter = bson.D{}
|
||||
err := bson.UnmarshalExtJSON([]byte(filterString), false, &filter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res := s.MongoClient().Database(database).Collection(collection).FindOne(ctx, filter, opts)
|
||||
if res.Err() != nil {
|
||||
return nil, res.Err()
|
||||
}
|
||||
|
||||
var data any
|
||||
err = res.Decode(&data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var final []any
|
||||
tmp, _ := bson.MarshalExtJSON(data, false, false)
|
||||
var tmp2 any
|
||||
err = json.Unmarshal(tmp, &tmp2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
final = append(final, tmp2)
|
||||
|
||||
return final, err
|
||||
}
|
||||
|
||||
func (s *Source) InsertMany(ctx context.Context, jsonData string, canonical bool, database, collection string) ([]any, error) {
|
||||
var data = []any{}
|
||||
err := bson.UnmarshalExtJSON([]byte(jsonData), canonical, &data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res, err := s.MongoClient().Database(database).Collection(collection).InsertMany(ctx, data, options.InsertMany())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return res.InsertedIDs, nil
|
||||
}
|
||||
|
||||
func (s *Source) InsertOne(ctx context.Context, jsonData string, canonical bool, database, collection string) (any, error) {
|
||||
var data any
|
||||
err := bson.UnmarshalExtJSON([]byte(jsonData), canonical, &data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res, err := s.MongoClient().Database(database).Collection(collection).InsertOne(ctx, data, options.InsertOne())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return res.InsertedID, nil
|
||||
}
|
||||
|
||||
func (s *Source) UpdateMany(ctx context.Context, filterString string, canonical bool, updateString, database, collection string, upsert bool) ([]any, error) {
|
||||
var filter = bson.D{}
|
||||
err := bson.UnmarshalExtJSON([]byte(filterString), canonical, &filter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal filter string: %w", err)
|
||||
}
|
||||
var update = bson.D{}
|
||||
err = bson.UnmarshalExtJSON([]byte(updateString), false, &update)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal update string: %w", err)
|
||||
}
|
||||
|
||||
res, err := s.MongoClient().Database(database).Collection(collection).UpdateMany(ctx, filter, update, options.Update().SetUpsert(upsert))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error updating collection: %w", err)
|
||||
}
|
||||
return []any{res.ModifiedCount, res.UpsertedCount, res.MatchedCount}, nil
|
||||
}
|
||||
|
||||
func (s *Source) UpdateOne(ctx context.Context, filterString string, canonical bool, updateString, database, collection string, upsert bool) (any, error) {
|
||||
var filter = bson.D{}
|
||||
err := bson.UnmarshalExtJSON([]byte(filterString), false, &filter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal filter string: %w", err)
|
||||
}
|
||||
var update = bson.D{}
|
||||
err = bson.UnmarshalExtJSON([]byte(updateString), canonical, &update)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal update string: %w", err)
|
||||
}
|
||||
|
||||
res, err := s.MongoClient().Database(database).Collection(collection).UpdateOne(ctx, filter, update, options.Update().SetUpsert(upsert))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error updating collection: %w", err)
|
||||
}
|
||||
return res.ModifiedCount, nil
|
||||
}
|
||||
|
||||
func (s *Source) DeleteMany(ctx context.Context, filterString, database, collection string) (any, error) {
|
||||
var filter = bson.D{}
|
||||
err := bson.UnmarshalExtJSON([]byte(filterString), false, &filter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res, err := s.MongoClient().Database(database).Collection(collection).DeleteMany(ctx, filter, options.Delete())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if res.DeletedCount == 0 {
|
||||
return nil, errors.New("no document found")
|
||||
}
|
||||
return res.DeletedCount, nil
|
||||
}
|
||||
|
||||
func (s *Source) DeleteOne(ctx context.Context, filterString, database, collection string) (any, error) {
|
||||
var filter = bson.D{}
|
||||
err := bson.UnmarshalExtJSON([]byte(filterString), false, &filter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res, err := s.MongoClient().Database(database).Collection(collection).DeleteOne(ctx, filter, options.Delete())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return res.DeletedCount, nil
|
||||
}
|
||||
|
||||
func initMongoDBClient(ctx context.Context, tracer trace.Tracer, name, uri string) (*mongo.Client, error) {
|
||||
// Start a tracing span
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
|
||||
@@ -16,15 +16,21 @@ package serverlessspark
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
dataproc "cloud.google.com/go/dataproc/v2/apiv1"
|
||||
"cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
|
||||
longrunning "cloud.google.com/go/longrunning/autogen"
|
||||
"cloud.google.com/go/longrunning/autogen/longrunningpb"
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"google.golang.org/api/iterator"
|
||||
"google.golang.org/api/option"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
)
|
||||
|
||||
const SourceKind string = "serverless-spark"
|
||||
@@ -121,3 +127,168 @@ func (s *Source) Close() error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Source) CancelOperation(ctx context.Context, operation string) (any, error) {
|
||||
req := &longrunningpb.CancelOperationRequest{
|
||||
Name: fmt.Sprintf("projects/%s/locations/%s/operations/%s", s.GetProject(), s.GetLocation(), operation),
|
||||
}
|
||||
client, err := s.GetOperationsClient(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get operations client: %w", err)
|
||||
}
|
||||
err = client.CancelOperation(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to cancel operation: %w", err)
|
||||
}
|
||||
return fmt.Sprintf("Cancelled [%s].", operation), nil
|
||||
}
|
||||
|
||||
func (s *Source) CreateBatch(ctx context.Context, batch *dataprocpb.Batch) (map[string]any, error) {
|
||||
req := &dataprocpb.CreateBatchRequest{
|
||||
Parent: fmt.Sprintf("projects/%s/locations/%s", s.GetProject(), s.GetLocation()),
|
||||
Batch: batch,
|
||||
}
|
||||
|
||||
client := s.GetBatchControllerClient()
|
||||
op, err := client.CreateBatch(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create batch: %w", err)
|
||||
}
|
||||
meta, err := op.Metadata()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get create batch op metadata: %w", err)
|
||||
}
|
||||
|
||||
projectID, location, batchID, err := ExtractBatchDetails(meta.GetBatch())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error extracting batch details from name %q: %v", meta.GetBatch(), err)
|
||||
}
|
||||
consoleUrl := BatchConsoleURL(projectID, location, batchID)
|
||||
logsUrl := BatchLogsURL(projectID, location, batchID, meta.GetCreateTime().AsTime(), time.Time{})
|
||||
|
||||
wrappedResult := map[string]any{
|
||||
"opMetadata": meta,
|
||||
"consoleUrl": consoleUrl,
|
||||
"logsUrl": logsUrl,
|
||||
}
|
||||
return wrappedResult, nil
|
||||
}
|
||||
|
||||
// ListBatchesResponse is the response from the list batches API.
|
||||
type ListBatchesResponse struct {
|
||||
Batches []Batch `json:"batches"`
|
||||
NextPageToken string `json:"nextPageToken"`
|
||||
}
|
||||
|
||||
// Batch represents a single batch job.
|
||||
type Batch struct {
|
||||
Name string `json:"name"`
|
||||
UUID string `json:"uuid"`
|
||||
State string `json:"state"`
|
||||
Creator string `json:"creator"`
|
||||
CreateTime string `json:"createTime"`
|
||||
Operation string `json:"operation"`
|
||||
ConsoleURL string `json:"consoleUrl"`
|
||||
LogsURL string `json:"logsUrl"`
|
||||
}
|
||||
|
||||
func (s *Source) ListBatches(ctx context.Context, ps *int, pt, filter string) (any, error) {
|
||||
client := s.GetBatchControllerClient()
|
||||
parent := fmt.Sprintf("projects/%s/locations/%s", s.GetProject(), s.GetLocation())
|
||||
req := &dataprocpb.ListBatchesRequest{
|
||||
Parent: parent,
|
||||
OrderBy: "create_time desc",
|
||||
}
|
||||
|
||||
if ps != nil {
|
||||
req.PageSize = int32(*ps)
|
||||
}
|
||||
if pt != "" {
|
||||
req.PageToken = pt
|
||||
}
|
||||
if filter != "" {
|
||||
req.Filter = filter
|
||||
}
|
||||
|
||||
it := client.ListBatches(ctx, req)
|
||||
pager := iterator.NewPager(it, int(req.PageSize), req.PageToken)
|
||||
|
||||
var batchPbs []*dataprocpb.Batch
|
||||
nextPageToken, err := pager.NextPage(&batchPbs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list batches: %w", err)
|
||||
}
|
||||
|
||||
batches, err := ToBatches(batchPbs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ListBatchesResponse{Batches: batches, NextPageToken: nextPageToken}, nil
|
||||
}
|
||||
|
||||
// ToBatches converts a slice of protobuf Batch messages to a slice of Batch structs.
|
||||
func ToBatches(batchPbs []*dataprocpb.Batch) ([]Batch, error) {
|
||||
batches := make([]Batch, 0, len(batchPbs))
|
||||
for _, batchPb := range batchPbs {
|
||||
consoleUrl, err := BatchConsoleURLFromProto(batchPb)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error generating console url: %v", err)
|
||||
}
|
||||
logsUrl, err := BatchLogsURLFromProto(batchPb)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error generating logs url: %v", err)
|
||||
}
|
||||
batch := Batch{
|
||||
Name: batchPb.Name,
|
||||
UUID: batchPb.Uuid,
|
||||
State: batchPb.State.Enum().String(),
|
||||
Creator: batchPb.Creator,
|
||||
CreateTime: batchPb.CreateTime.AsTime().Format(time.RFC3339),
|
||||
Operation: batchPb.Operation,
|
||||
ConsoleURL: consoleUrl,
|
||||
LogsURL: logsUrl,
|
||||
}
|
||||
batches = append(batches, batch)
|
||||
}
|
||||
return batches, nil
|
||||
}
|
||||
|
||||
func (s *Source) GetBatch(ctx context.Context, name string) (map[string]any, error) {
|
||||
client := s.GetBatchControllerClient()
|
||||
req := &dataprocpb.GetBatchRequest{
|
||||
Name: fmt.Sprintf("projects/%s/locations/%s/batches/%s", s.GetProject(), s.GetLocation(), name),
|
||||
}
|
||||
|
||||
batchPb, err := client.GetBatch(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get batch: %w", err)
|
||||
}
|
||||
|
||||
jsonBytes, err := protojson.Marshal(batchPb)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal batch to JSON: %w", err)
|
||||
}
|
||||
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal(jsonBytes, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal batch JSON: %w", err)
|
||||
}
|
||||
|
||||
consoleUrl, err := BatchConsoleURLFromProto(batchPb)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error generating console url: %v", err)
|
||||
}
|
||||
logsUrl, err := BatchLogsURLFromProto(batchPb)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error generating logs url: %v", err)
|
||||
}
|
||||
|
||||
wrappedResult := map[string]any{
|
||||
"consoleUrl": consoleUrl,
|
||||
"logsUrl": logsUrl,
|
||||
"batch": result,
|
||||
}
|
||||
|
||||
return wrappedResult, nil
|
||||
}
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
// Copyright 2025 Google LLC
|
||||
// Copyright 2026 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
@@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package common
|
||||
package serverlessspark
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -23,13 +23,13 @@ import (
|
||||
"cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
|
||||
)
|
||||
|
||||
var batchFullNameRegex = regexp.MustCompile(`projects/(?P<project>[^/]+)/locations/(?P<location>[^/]+)/batches/(?P<batch_id>[^/]+)`)
|
||||
|
||||
const (
|
||||
logTimeBufferBefore = 1 * time.Minute
|
||||
logTimeBufferAfter = 10 * time.Minute
|
||||
)
|
||||
|
||||
var batchFullNameRegex = regexp.MustCompile(`projects/(?P<project>[^/]+)/locations/(?P<location>[^/]+)/batches/(?P<batch_id>[^/]+)`)
|
||||
|
||||
// Extract BatchDetails extracts the project ID, location, and batch ID from a fully qualified batch name.
|
||||
func ExtractBatchDetails(batchName string) (projectID, location, batchID string, err error) {
|
||||
matches := batchFullNameRegex.FindStringSubmatch(batchName)
|
||||
@@ -39,26 +39,6 @@ func ExtractBatchDetails(batchName string) (projectID, location, batchID string,
|
||||
return matches[1], matches[2], matches[3], nil
|
||||
}
|
||||
|
||||
// BatchConsoleURLFromProto builds a URL to the Google Cloud Console linking to the batch summary page.
|
||||
func BatchConsoleURLFromProto(batchPb *dataprocpb.Batch) (string, error) {
|
||||
projectID, location, batchID, err := ExtractBatchDetails(batchPb.GetName())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return BatchConsoleURL(projectID, location, batchID), nil
|
||||
}
|
||||
|
||||
// BatchLogsURLFromProto builds a URL to the Google Cloud Console showing Cloud Logging for the given batch and time range.
|
||||
func BatchLogsURLFromProto(batchPb *dataprocpb.Batch) (string, error) {
|
||||
projectID, location, batchID, err := ExtractBatchDetails(batchPb.GetName())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
createTime := batchPb.GetCreateTime().AsTime()
|
||||
stateTime := batchPb.GetStateTime().AsTime()
|
||||
return BatchLogsURL(projectID, location, batchID, createTime, stateTime), nil
|
||||
}
|
||||
|
||||
// BatchConsoleURL builds a URL to the Google Cloud Console linking to the batch summary page.
|
||||
func BatchConsoleURL(projectID, location, batchID string) string {
|
||||
return fmt.Sprintf("https://console.cloud.google.com/dataproc/batches/%s/%s/summary?project=%s", location, batchID, projectID)
|
||||
@@ -89,3 +69,23 @@ resource.labels.batch_id="%s"`
|
||||
|
||||
return "https://console.cloud.google.com/logs/viewer?" + v.Encode()
|
||||
}
|
||||
|
||||
// BatchConsoleURLFromProto builds a URL to the Google Cloud Console linking to the batch summary page.
|
||||
func BatchConsoleURLFromProto(batchPb *dataprocpb.Batch) (string, error) {
|
||||
projectID, location, batchID, err := ExtractBatchDetails(batchPb.GetName())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return BatchConsoleURL(projectID, location, batchID), nil
|
||||
}
|
||||
|
||||
// BatchLogsURLFromProto builds a URL to the Google Cloud Console showing Cloud Logging for the given batch and time range.
|
||||
func BatchLogsURLFromProto(batchPb *dataprocpb.Batch) (string, error) {
|
||||
projectID, location, batchID, err := ExtractBatchDetails(batchPb.GetName())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
createTime := batchPb.GetCreateTime().AsTime()
|
||||
stateTime := batchPb.GetStateTime().AsTime()
|
||||
return BatchLogsURL(projectID, location, batchID, createTime, stateTime), nil
|
||||
}
|
||||
@@ -1,10 +1,10 @@
|
||||
// Copyright 2025 Google LLC
|
||||
// Copyright 2026 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
@@ -12,19 +12,20 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package common
|
||||
package serverlessspark_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/serverlessspark"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
func TestExtractBatchDetails_Success(t *testing.T) {
|
||||
batchName := "projects/my-project/locations/us-central1/batches/my-batch"
|
||||
projectID, location, batchID, err := ExtractBatchDetails(batchName)
|
||||
projectID, location, batchID, err := serverlessspark.ExtractBatchDetails(batchName)
|
||||
if err != nil {
|
||||
t.Errorf("ExtractBatchDetails() error = %v, want no error", err)
|
||||
return
|
||||
@@ -45,7 +46,7 @@ func TestExtractBatchDetails_Success(t *testing.T) {
|
||||
|
||||
func TestExtractBatchDetails_Failure(t *testing.T) {
|
||||
batchName := "invalid-name"
|
||||
_, _, _, err := ExtractBatchDetails(batchName)
|
||||
_, _, _, err := serverlessspark.ExtractBatchDetails(batchName)
|
||||
wantErr := "failed to parse batch name: invalid-name"
|
||||
if err == nil || err.Error() != wantErr {
|
||||
t.Errorf("ExtractBatchDetails() error = %v, want %v", err, wantErr)
|
||||
@@ -53,7 +54,7 @@ func TestExtractBatchDetails_Failure(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBatchConsoleURL(t *testing.T) {
|
||||
got := BatchConsoleURL("my-project", "us-central1", "my-batch")
|
||||
got := serverlessspark.BatchConsoleURL("my-project", "us-central1", "my-batch")
|
||||
want := "https://console.cloud.google.com/dataproc/batches/us-central1/my-batch/summary?project=my-project"
|
||||
if got != want {
|
||||
t.Errorf("BatchConsoleURL() = %v, want %v", got, want)
|
||||
@@ -63,7 +64,7 @@ func TestBatchConsoleURL(t *testing.T) {
|
||||
func TestBatchLogsURL(t *testing.T) {
|
||||
startTime := time.Date(2025, 10, 1, 5, 0, 0, 0, time.UTC)
|
||||
endTime := time.Date(2025, 10, 1, 6, 0, 0, 0, time.UTC)
|
||||
got := BatchLogsURL("my-project", "us-central1", "my-batch", startTime, endTime)
|
||||
got := serverlessspark.BatchLogsURL("my-project", "us-central1", "my-batch", startTime, endTime)
|
||||
want := "https://console.cloud.google.com/logs/viewer?advancedFilter=" +
|
||||
"resource.type%3D%22cloud_dataproc_batch%22" +
|
||||
"%0Aresource.labels.project_id%3D%22my-project%22" +
|
||||
@@ -82,7 +83,7 @@ func TestBatchConsoleURLFromProto(t *testing.T) {
|
||||
batchPb := &dataprocpb.Batch{
|
||||
Name: "projects/my-project/locations/us-central1/batches/my-batch",
|
||||
}
|
||||
got, err := BatchConsoleURLFromProto(batchPb)
|
||||
got, err := serverlessspark.BatchConsoleURLFromProto(batchPb)
|
||||
if err != nil {
|
||||
t.Fatalf("BatchConsoleURLFromProto() error = %v", err)
|
||||
}
|
||||
@@ -100,7 +101,7 @@ func TestBatchLogsURLFromProto(t *testing.T) {
|
||||
CreateTime: timestamppb.New(createTime),
|
||||
StateTime: timestamppb.New(stateTime),
|
||||
}
|
||||
got, err := BatchLogsURLFromProto(batchPb)
|
||||
got, err := serverlessspark.BatchLogsURLFromProto(batchPb)
|
||||
if err != nil {
|
||||
t.Fatalf("BatchLogsURLFromProto() error = %v", err)
|
||||
}
|
||||
@@ -28,6 +28,21 @@ import (
|
||||
|
||||
const kind string = "cloud-gemini-data-analytics-query"
|
||||
|
||||
// Guidance is the tool guidance string.
|
||||
const Guidance = `Tool guidance:
|
||||
Inputs:
|
||||
1. query: A natural language formulation of a database query.
|
||||
Outputs: (all optional)
|
||||
1. disambiguation_question: Clarification questions or comments where the tool needs the users' input.
|
||||
2. generated_query: The generated query for the user query.
|
||||
3. intent_explanation: An explanation for why the tool produced ` + "`generated_query`" + `.
|
||||
4. query_result: The result of executing ` + "`generated_query`" + `.
|
||||
5. natural_language_answer: The natural language answer that summarizes the ` + "`query`" + ` and ` + "`query_result`" + `.
|
||||
|
||||
Usage guidance:
|
||||
1. If ` + "`disambiguation_question`" + ` is produced, then solicit the needed inputs from the user and try the tool with a new ` + "`query`" + ` that has the needed clarification.
|
||||
2. If ` + "`natural_language_answer`" + ` is produced, use ` + "`intent_explanation`" + ` and ` + "`generated_query`" + ` to see if you need to clarify any assumptions for the user.`
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
||||
@@ -68,11 +83,18 @@ func (cfg Config) ToolConfigKind() string {
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// Define the parameters for the Gemini Data Analytics Query API
|
||||
// The prompt is the only input parameter.
|
||||
// The query is the only input parameter.
|
||||
allParameters := parameters.Parameters{
|
||||
parameters.NewStringParameterWithRequired("prompt", "The natural language question to ask.", true),
|
||||
parameters.NewStringParameterWithRequired("query", "A natural language formulation of a database query.", true),
|
||||
}
|
||||
// The input and outputs are for tool guidance, usage guidance is for multi-turn interaction.
|
||||
guidance := Guidance
|
||||
|
||||
if cfg.Description != "" {
|
||||
cfg.Description += "\n\n" + guidance
|
||||
} else {
|
||||
cfg.Description = guidance
|
||||
}
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil)
|
||||
|
||||
return Tool{
|
||||
@@ -105,9 +127,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
prompt, ok := paramsMap["prompt"].(string)
|
||||
query, ok := paramsMap["query"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("prompt parameter not found or not a string")
|
||||
return nil, fmt.Errorf("query parameter not found or not a string")
|
||||
}
|
||||
|
||||
// Parse the access token if provided
|
||||
@@ -125,7 +147,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
|
||||
payload := &QueryDataRequest{
|
||||
Parent: payloadParent,
|
||||
Prompt: prompt,
|
||||
Prompt: query,
|
||||
Context: t.Context,
|
||||
GenerationOptions: t.GenerationOptions,
|
||||
}
|
||||
|
||||
@@ -328,9 +328,9 @@ func TestInvoke(t *testing.T) {
|
||||
t.Fatalf("failed to initialize tool: %v", err)
|
||||
}
|
||||
|
||||
// Prepare parameters for invocation - ONLY prompt
|
||||
// Prepare parameters for invocation - ONLY query
|
||||
params := parameters.ParamValues{
|
||||
{Name: "prompt", Value: "How many accounts who have region in Prague are eligible for loans?"},
|
||||
{Name: "query", Value: "How many accounts who have region in Prague are eligible for loans?"},
|
||||
}
|
||||
|
||||
resourceMgr := resources.NewResourceManager(srcs, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
@@ -16,22 +16,13 @@ package fhirfetchpage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/healthcare/v1"
|
||||
|
||||
"net/http"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/google"
|
||||
)
|
||||
|
||||
const kind string = "cloud-healthcare-fhir-fetch-page"
|
||||
@@ -54,13 +45,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
Project() string
|
||||
Region() string
|
||||
DatasetID() string
|
||||
AllowedFHIRStores() map[string]struct{}
|
||||
Service() *healthcare.Service
|
||||
ServiceCreator() healthcareds.HealthcareServiceCreator
|
||||
UseClientAuthorization() bool
|
||||
FHIRFetchPage(context.Context, string, string) (any, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -117,49 +103,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", pageURLKey)
|
||||
}
|
||||
|
||||
var httpClient *http.Client
|
||||
var tokenStr string
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
tokenStr, err = accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: tokenStr})
|
||||
httpClient = oauth2.NewClient(ctx, ts)
|
||||
} else {
|
||||
// The source.Service() object holds a client with the default credentials.
|
||||
// However, the client is not exported, so we have to create a new one.
|
||||
var err error
|
||||
httpClient, err = google.DefaultClient(ctx, healthcare.CloudHealthcareScope)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create default http client: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create http request: %w", err)
|
||||
}
|
||||
req.Header.Set("Accept", "application/fhir+json;charset=utf-8")
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get fhir page from %q: %w", url, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not read response: %w", err)
|
||||
}
|
||||
if resp.StatusCode > 299 {
|
||||
return nil, fmt.Errorf("read: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
|
||||
}
|
||||
var jsonMap map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(string(respBytes)), &jsonMap); err != nil {
|
||||
return nil, fmt.Errorf("could not unmarshal response as json: %w", err)
|
||||
}
|
||||
return jsonMap, nil
|
||||
return source.FHIRFetchPage(ctx, url, tokenStr)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -16,20 +16,16 @@ package fhirpatienteverything
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/googleapi"
|
||||
"google.golang.org/api/healthcare/v1"
|
||||
)
|
||||
|
||||
const kind string = "cloud-healthcare-fhir-patient-everything"
|
||||
@@ -54,13 +50,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
Project() string
|
||||
Region() string
|
||||
DatasetID() string
|
||||
AllowedFHIRStores() map[string]struct{}
|
||||
Service() *healthcare.Service
|
||||
ServiceCreator() healthcareds.HealthcareServiceCreator
|
||||
UseClientAuthorization() bool
|
||||
FHIRPatientEverything(string, string, string, []googleapi.CallOption) (any, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -139,20 +131,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", patientIDKey)
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
var tokenStr string
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
tokenStr, err = accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/Patient/%s", source.Project(), source.Region(), source.DatasetID(), storeID, patientID)
|
||||
var opts []googleapi.CallOption
|
||||
if val, ok := params.AsMap()[typeFilterKey]; ok {
|
||||
types, ok := val.([]any)
|
||||
@@ -176,25 +162,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
opts = append(opts, googleapi.QueryParameter("_since", sinceStr))
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := svc.Projects.Locations.Datasets.FhirStores.Fhir.PatientEverything(name).Do(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to call patient everything for %q: %w", name, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not read response: %w", err)
|
||||
}
|
||||
if resp.StatusCode > 299 {
|
||||
return nil, fmt.Errorf("patient-everything: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
|
||||
}
|
||||
var jsonMap map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(string(respBytes)), &jsonMap); err != nil {
|
||||
return nil, fmt.Errorf("could not unmarshal response as json: %w", err)
|
||||
}
|
||||
return jsonMap, nil
|
||||
return source.FHIRPatientEverything(storeID, patientID, tokenStr, opts)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -16,20 +16,16 @@ package fhirpatientsearch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/googleapi"
|
||||
"google.golang.org/api/healthcare/v1"
|
||||
)
|
||||
|
||||
const kind string = "cloud-healthcare-fhir-patient-search"
|
||||
@@ -70,13 +66,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
Project() string
|
||||
Region() string
|
||||
DatasetID() string
|
||||
AllowedFHIRStores() map[string]struct{}
|
||||
Service() *healthcare.Service
|
||||
ServiceCreator() healthcareds.HealthcareServiceCreator
|
||||
UseClientAuthorization() bool
|
||||
FHIRPatientSearch(string, string, []googleapi.CallOption) (any, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -169,17 +161,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
var tokenStr string
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
tokenStr, err = accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
var summary bool
|
||||
@@ -248,26 +235,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if summary {
|
||||
opts = append(opts, googleapi.QueryParameter("_summary", "text"))
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
|
||||
resp, err := svc.Projects.Locations.Datasets.FhirStores.Fhir.SearchType(name, "Patient", &healthcare.SearchResourcesRequest{ResourceType: "Patient"}).Do(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to search patient resources: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not read response: %w", err)
|
||||
}
|
||||
if resp.StatusCode > 299 {
|
||||
return nil, fmt.Errorf("search: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
|
||||
}
|
||||
var jsonMap map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(string(respBytes)), &jsonMap); err != nil {
|
||||
return nil, fmt.Errorf("could not unmarshal response as json: %w", err)
|
||||
}
|
||||
return jsonMap, nil
|
||||
return source.FHIRPatientSearch(storeID, tokenStr, opts)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/healthcare/v1"
|
||||
@@ -44,12 +43,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
Project() string
|
||||
Region() string
|
||||
DatasetID() string
|
||||
Service() *healthcare.Service
|
||||
ServiceCreator() healthcareds.HealthcareServiceCreator
|
||||
UseClientAuthorization() bool
|
||||
GetDataset(string) (*healthcare.Dataset, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -100,27 +95,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
|
||||
// Initialize new service if using user OAuth token
|
||||
var tokenStr string
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
tokenStr, err = accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", source.Project(), source.Region(), source.DatasetID())
|
||||
dataset, err := svc.Projects.Locations.Datasets.Get(datasetName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
|
||||
}
|
||||
return dataset, nil
|
||||
return source.GetDataset(tokenStr)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
@@ -45,13 +44,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
Project() string
|
||||
Region() string
|
||||
DatasetID() string
|
||||
AllowedDICOMStores() map[string]struct{}
|
||||
Service() *healthcare.Service
|
||||
ServiceCreator() healthcareds.HealthcareServiceCreator
|
||||
UseClientAuthorization() bool
|
||||
GetDICOMStore(string, string) (*healthcare.DicomStore, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -117,31 +112,18 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
var tokenStr string
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
tokenStr, err = accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
|
||||
store, err := svc.Projects.Locations.Datasets.DicomStores.Get(storeName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get DICOM store %q: %w", storeName, err)
|
||||
}
|
||||
return store, nil
|
||||
return source.GetDICOMStore(storeID, tokenStr)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
@@ -45,13 +44,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
Project() string
|
||||
Region() string
|
||||
DatasetID() string
|
||||
AllowedDICOMStores() map[string]struct{}
|
||||
Service() *healthcare.Service
|
||||
ServiceCreator() healthcareds.HealthcareServiceCreator
|
||||
UseClientAuthorization() bool
|
||||
GetDICOMStoreMetrics(string, string) (*healthcare.DicomStoreMetrics, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -117,31 +112,18 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
var tokenStr string
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
tokenStr, err = accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
|
||||
store, err := svc.Projects.Locations.Datasets.DicomStores.GetDICOMStoreMetrics(storeName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get metrics for DICOM store %q: %w", storeName, err)
|
||||
}
|
||||
return store, nil
|
||||
return source.GetDICOMStoreMetrics(storeID, tokenStr)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -16,18 +16,14 @@ package getfhirresource
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/healthcare/v1"
|
||||
)
|
||||
|
||||
const kind string = "cloud-healthcare-get-fhir-resource"
|
||||
@@ -51,13 +47,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
Project() string
|
||||
Region() string
|
||||
DatasetID() string
|
||||
AllowedFHIRStores() map[string]struct{}
|
||||
Service() *healthcare.Service
|
||||
ServiceCreator() healthcareds.HealthcareServiceCreator
|
||||
UseClientAuthorization() bool
|
||||
GetFHIRResource(string, string, string, string) (any, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -134,46 +126,18 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", typeKey)
|
||||
}
|
||||
|
||||
resID, ok := params.AsMap()[idKey].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", idKey)
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
var tokenStr string
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
tokenStr, err = accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/%s/%s", source.Project(), source.Region(), source.DatasetID(), storeID, resType, resID)
|
||||
call := svc.Projects.Locations.Datasets.FhirStores.Fhir.Read(name)
|
||||
call.Header().Set("Content-Type", "application/fhir+json;charset=utf-8")
|
||||
resp, err := call.Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get fhir resource %q: %w", name, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not read response: %w", err)
|
||||
}
|
||||
if resp.StatusCode > 299 {
|
||||
return nil, fmt.Errorf("read: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
|
||||
}
|
||||
var jsonMap map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(string(respBytes)), &jsonMap); err != nil {
|
||||
return nil, fmt.Errorf("could not unmarshal response as json: %w", err)
|
||||
}
|
||||
return jsonMap, nil
|
||||
return source.GetFHIRResource(storeID, resType, resID, tokenStr)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
@@ -45,13 +44,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
Project() string
|
||||
Region() string
|
||||
DatasetID() string
|
||||
AllowedFHIRStores() map[string]struct{}
|
||||
Service() *healthcare.Service
|
||||
ServiceCreator() healthcareds.HealthcareServiceCreator
|
||||
UseClientAuthorization() bool
|
||||
GetFHIRStore(string, string) (*healthcare.FhirStore, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -117,31 +112,18 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
var tokenStr string
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
tokenStr, err = accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
|
||||
store, err := svc.Projects.Locations.Datasets.FhirStores.Get(storeName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get FHIR store %q: %w", storeName, err)
|
||||
}
|
||||
return store, nil
|
||||
return source.GetFHIRStore(storeID, tokenStr)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
@@ -45,13 +44,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
Project() string
|
||||
Region() string
|
||||
DatasetID() string
|
||||
AllowedFHIRStores() map[string]struct{}
|
||||
Service() *healthcare.Service
|
||||
ServiceCreator() healthcareds.HealthcareServiceCreator
|
||||
UseClientAuthorization() bool
|
||||
GetFHIRStoreMetrics(string, string) (*healthcare.FhirStoreMetrics, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -117,31 +112,18 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
var tokenStr string
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
tokenStr, err = accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
|
||||
store, err := svc.Projects.Locations.Datasets.FhirStores.GetFHIRStoreMetrics(storeName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get metrics for FHIR store %q: %w", storeName, err)
|
||||
}
|
||||
return store, nil
|
||||
return source.GetFHIRStoreMetrics(storeID, tokenStr)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -17,12 +17,10 @@ package listdicomstores
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/healthcare/v1"
|
||||
@@ -45,13 +43,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
Project() string
|
||||
Region() string
|
||||
DatasetID() string
|
||||
AllowedDICOMStores() map[string]struct{}
|
||||
Service() *healthcare.Service
|
||||
ServiceCreator() healthcareds.HealthcareServiceCreator
|
||||
UseClientAuthorization() bool
|
||||
ListDICOMStores(tokenStr string) ([]*healthcare.DicomStore, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -102,41 +95,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
|
||||
// Initialize new service if using user OAuth token
|
||||
var tokenStr string
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
tokenStr, err = accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", source.Project(), source.Region(), source.DatasetID())
|
||||
stores, err := svc.Projects.Locations.Datasets.DicomStores.List(datasetName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
|
||||
}
|
||||
var filtered []*healthcare.DicomStore
|
||||
for _, store := range stores.DicomStores {
|
||||
if len(source.AllowedDICOMStores()) == 0 {
|
||||
filtered = append(filtered, store)
|
||||
continue
|
||||
}
|
||||
if len(store.Name) == 0 {
|
||||
continue
|
||||
}
|
||||
parts := strings.Split(store.Name, "/")
|
||||
if _, ok := source.AllowedDICOMStores()[parts[len(parts)-1]]; ok {
|
||||
filtered = append(filtered, store)
|
||||
}
|
||||
}
|
||||
return filtered, nil
|
||||
return source.ListDICOMStores(tokenStr)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -17,12 +17,10 @@ package listfhirstores
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/healthcare/v1"
|
||||
@@ -45,13 +43,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
Project() string
|
||||
Region() string
|
||||
DatasetID() string
|
||||
AllowedFHIRStores() map[string]struct{}
|
||||
Service() *healthcare.Service
|
||||
ServiceCreator() healthcareds.HealthcareServiceCreator
|
||||
UseClientAuthorization() bool
|
||||
ListFHIRStores(string) ([]*healthcare.FhirStore, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -102,41 +95,14 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
|
||||
// Initialize new service if using user OAuth token
|
||||
var tokenStr string
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
tokenStr, err = accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", source.Project(), source.Region(), source.DatasetID())
|
||||
stores, err := svc.Projects.Locations.Datasets.FhirStores.List(datasetName).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
|
||||
}
|
||||
var filtered []*healthcare.FhirStore
|
||||
for _, store := range stores.FhirStores {
|
||||
if len(source.AllowedFHIRStores()) == 0 {
|
||||
filtered = append(filtered, store)
|
||||
continue
|
||||
}
|
||||
if len(store.Name) == 0 {
|
||||
continue
|
||||
}
|
||||
parts := strings.Split(store.Name, "/")
|
||||
if _, ok := source.AllowedFHIRStores()[parts[len(parts)-1]]; ok {
|
||||
filtered = append(filtered, store)
|
||||
}
|
||||
}
|
||||
return filtered, nil
|
||||
return source.ListFHIRStores(tokenStr)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -16,18 +16,14 @@ package retrieverendereddicominstance
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/healthcare/v1"
|
||||
)
|
||||
|
||||
const kind string = "cloud-healthcare-retrieve-rendered-dicom-instance"
|
||||
@@ -53,13 +49,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
Project() string
|
||||
Region() string
|
||||
DatasetID() string
|
||||
AllowedDICOMStores() map[string]struct{}
|
||||
Service() *healthcare.Service
|
||||
ServiceCreator() healthcareds.HealthcareServiceCreator
|
||||
UseClientAuthorization() bool
|
||||
RetrieveRenderedDICOMInstance(string, string, string, string, int, string) (any, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -135,20 +127,13 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
var tokenStr string
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
tokenStr, err = accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
study, ok := params.AsMap()[studyInstanceUIDKey].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid '%s' parameter; expected a string", studyInstanceUIDKey)
|
||||
@@ -165,25 +150,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid '%s' parameter; expected an integer", frameNumberKey)
|
||||
}
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
|
||||
dicomWebPath := fmt.Sprintf("studies/%s/series/%s/instances/%s/frames/%d/rendered", study, series, sop, frame)
|
||||
call := svc.Projects.Locations.Datasets.DicomStores.Studies.Series.Instances.Frames.RetrieveRendered(name, dicomWebPath)
|
||||
call.Header().Set("Accept", "image/jpeg")
|
||||
resp, err := call.Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to retrieve dicom instance rendered image: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not read response: %w", err)
|
||||
}
|
||||
if resp.StatusCode > 299 {
|
||||
return nil, fmt.Errorf("RetrieveRendered: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
|
||||
}
|
||||
base64String := base64.StdEncoding.EncodeToString(respBytes)
|
||||
return base64String, nil
|
||||
return source.RetrieveRenderedDICOMInstance(storeID, study, series, sop, frame, tokenStr)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -16,20 +16,16 @@ package searchdicominstances
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/googleapi"
|
||||
"google.golang.org/api/healthcare/v1"
|
||||
)
|
||||
|
||||
const kind string = "cloud-healthcare-search-dicom-instances"
|
||||
@@ -60,13 +56,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
Project() string
|
||||
Region() string
|
||||
DatasetID() string
|
||||
AllowedDICOMStores() map[string]struct{}
|
||||
Service() *healthcare.Service
|
||||
ServiceCreator() healthcareds.HealthcareServiceCreator
|
||||
UseClientAuthorization() bool
|
||||
SearchDICOM(string, string, string, string, []googleapi.CallOption) (any, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -144,23 +136,16 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
var tokenStr string
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
tokenStr, err = accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
opts, err := common.ParseDICOMSearchParameters(params, []string{sopInstanceUIDKey, patientNameKey, patientIDKey, accessionNumberKey, referringPhysicianNameKey, studyDateKey, modalityKey})
|
||||
@@ -191,29 +176,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
|
||||
resp, err := svc.Projects.Locations.Datasets.DicomStores.SearchForInstances(name, dicomWebPath).Do(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to search dicom instances: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not read response: %w", err)
|
||||
}
|
||||
if resp.StatusCode > 299 {
|
||||
return nil, fmt.Errorf("search: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
|
||||
}
|
||||
if len(respBytes) == 0 {
|
||||
return []interface{}{}, nil
|
||||
}
|
||||
var result []interface{}
|
||||
if err := json.Unmarshal([]byte(string(respBytes)), &result); err != nil {
|
||||
return nil, fmt.Errorf("could not unmarshal response as list: %w", err)
|
||||
}
|
||||
return result, nil
|
||||
return source.SearchDICOM(t.Kind, storeID, dicomWebPath, tokenStr, opts)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -16,18 +16,15 @@ package searchdicomseries
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/healthcare/v1"
|
||||
"google.golang.org/api/googleapi"
|
||||
)
|
||||
|
||||
const kind string = "cloud-healthcare-search-dicom-series"
|
||||
@@ -57,13 +54,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
Project() string
|
||||
Region() string
|
||||
DatasetID() string
|
||||
AllowedDICOMStores() map[string]struct{}
|
||||
Service() *healthcare.Service
|
||||
ServiceCreator() healthcareds.HealthcareServiceCreator
|
||||
UseClientAuthorization() bool
|
||||
SearchDICOM(string, string, string, string, []googleapi.CallOption) (any, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -145,18 +138,12 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
var tokenStr string
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
tokenStr, err = accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
opts, err := common.ParseDICOMSearchParameters(params, []string{seriesInstanceUIDKey, patientNameKey, patientIDKey, accessionNumberKey, referringPhysicianNameKey, studyDateKey, modalityKey})
|
||||
@@ -174,29 +161,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
dicomWebPath = fmt.Sprintf("studies/%s/series", id)
|
||||
}
|
||||
}
|
||||
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
|
||||
resp, err := svc.Projects.Locations.Datasets.DicomStores.SearchForSeries(name, dicomWebPath).Do(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to search dicom series: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not read response: %w", err)
|
||||
}
|
||||
if resp.StatusCode > 299 {
|
||||
return nil, fmt.Errorf("search: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
|
||||
}
|
||||
if len(respBytes) == 0 {
|
||||
return []interface{}{}, nil
|
||||
}
|
||||
var result []interface{}
|
||||
if err := json.Unmarshal([]byte(string(respBytes)), &result); err != nil {
|
||||
return nil, fmt.Errorf("could not unmarshal response as list: %w", err)
|
||||
}
|
||||
return result, nil
|
||||
return source.SearchDICOM(t.Kind, storeID, dicomWebPath, tokenStr, opts)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -16,18 +16,15 @@ package searchdicomstudies
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/healthcare/v1"
|
||||
"google.golang.org/api/googleapi"
|
||||
)
|
||||
|
||||
const kind string = "cloud-healthcare-search-dicom-studies"
|
||||
@@ -55,13 +52,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
Project() string
|
||||
Region() string
|
||||
DatasetID() string
|
||||
AllowedDICOMStores() map[string]struct{}
|
||||
Service() *healthcare.Service
|
||||
ServiceCreator() healthcareds.HealthcareServiceCreator
|
||||
UseClientAuthorization() bool
|
||||
SearchDICOM(string, string, string, string, []googleapi.CallOption) (any, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -136,51 +129,23 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := source.Service()
|
||||
// Initialize new service if using user OAuth token
|
||||
var tokenStr string
|
||||
if source.UseClientAuthorization() {
|
||||
tokenStr, err := accessToken.ParseBearerToken()
|
||||
tokenStr, err = accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
svc, err = source.ServiceCreator()(tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
opts, err := common.ParseDICOMSearchParameters(params, []string{studyInstanceUIDKey, patientNameKey, patientIDKey, accessionNumberKey, referringPhysicianNameKey, studyDateKey})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
|
||||
resp, err := svc.Projects.Locations.Datasets.DicomStores.SearchForStudies(name, "studies").Do(opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to search dicom studies: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not read response: %w", err)
|
||||
}
|
||||
if resp.StatusCode > 299 {
|
||||
return nil, fmt.Errorf("search: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
|
||||
}
|
||||
if len(respBytes) == 0 {
|
||||
return []interface{}{}, nil
|
||||
}
|
||||
var result []interface{}
|
||||
if err := json.Unmarshal([]byte(string(respBytes)), &result); err != nil {
|
||||
return nil, fmt.Errorf("could not unmarshal response as list: %w", err)
|
||||
}
|
||||
return result, nil
|
||||
dicomWebPath := "studies"
|
||||
return source.SearchDICOM(t.Kind, storeID, dicomWebPath, tokenStr, opts)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -0,0 +1,180 @@
|
||||
// Copyright 2026 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cloudsqlcreatebackup
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/sqladmin/v1"
|
||||
)
|
||||
|
||||
const kind string = "cloud-sql-create-backup"
|
||||
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
GetService(context.Context, string) (*sqladmin.Service, error)
|
||||
UseClientAuthorization() bool
|
||||
InsertBackupRun(ctx context.Context, project, instance, location, backupDescription, accessToken string) (any, error)
|
||||
}
|
||||
|
||||
// Config defines the configuration for the create-backup tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Description string `yaml:"description"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
}
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
// ToolConfigKind returns the kind of the tool.
|
||||
func (cfg Config) ToolConfigKind() string {
|
||||
return kind
|
||||
}
|
||||
|
||||
// Initialize initializes the tool from the configuration.
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
project := s.GetDefaultProject()
|
||||
var projectParam parameters.Parameter
|
||||
if project != "" {
|
||||
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
} else {
|
||||
projectParam = parameters.NewStringParameter("project", "The project ID")
|
||||
}
|
||||
|
||||
allParameters := parameters.Parameters{
|
||||
projectParam,
|
||||
parameters.NewStringParameter("instance", "Cloud SQL instance ID. This does not include the project ID."),
|
||||
// Location and backup_description are optional.
|
||||
parameters.NewStringParameterWithRequired("location", "Location of the backup run.", false),
|
||||
parameters.NewStringParameterWithRequired("backup_description", "The description of this backup run.", false),
|
||||
}
|
||||
paramManifest := allParameters.Manifest()
|
||||
|
||||
description := cfg.Description
|
||||
if description == "" {
|
||||
description = "Creates a backup on a Cloud SQL instance."
|
||||
}
|
||||
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, allParameters, nil)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Tool represents the create-backup tool.
|
||||
type Tool struct {
|
||||
Config
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
return t.Config
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("error casting 'project' parameter: %v", paramsMap["project"])
|
||||
}
|
||||
instance, ok := paramsMap["instance"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("error casting 'instance' parameter: %v", paramsMap["instance"])
|
||||
}
|
||||
|
||||
location, _ := paramsMap["location"].(string)
|
||||
description, _ := paramsMap["backup_description"].(string)
|
||||
|
||||
return source.InsertBackupRun(ctx, project, instance, location, description, string(accessToken))
|
||||
}
|
||||
|
||||
// ParseParams parses the parameters for the tool.
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
return parameters.ParseParams(t.AllParams, data, claims)
|
||||
}
|
||||
|
||||
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
||||
return parameters.EmbedParams(ctx, t.AllParams, paramValues, embeddingModelsMap, nil)
|
||||
}
|
||||
|
||||
// Manifest returns the tool's manifest.
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
// McpManifest returns the tool's MCP manifest.
|
||||
func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
// Authorized checks if the tool is authorized.
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
@@ -0,0 +1,72 @@
|
||||
// Copyright 2026 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cloudsqlcreatebackup_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcreatebackup"
|
||||
)
|
||||
|
||||
func TestParseFromYaml(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
create-backup-tool:
|
||||
kind: cloud-sql-create-backup
|
||||
description: a test description
|
||||
source: a-source
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"create-backup-tool": cloudsqlcreatebackup.Config{
|
||||
Name: "create-backup-tool",
|
||||
Kind: "cloud-sql-create-backup",
|
||||
Description: "a test description",
|
||||
Source: "a-source",
|
||||
AuthRequired: []string{},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,183 @@
|
||||
// Copyright 2026 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cloudsqlrestorebackup
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/api/sqladmin/v1"
|
||||
)
|
||||
|
||||
const kind string = "cloud-sql-restore-backup"
|
||||
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
type compatibleSource interface {
|
||||
GetDefaultProject() string
|
||||
GetService(context.Context, string) (*sqladmin.Service, error)
|
||||
UseClientAuthorization() bool
|
||||
RestoreBackup(ctx context.Context, targetProject, targetInstance, sourceProject, sourceInstance, backupID, accessToken string) (any, error)
|
||||
}
|
||||
|
||||
// Config defines the configuration for the restore-backup tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Description string `yaml:"description"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
}
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
// ToolConfigKind returns the kind of the tool.
|
||||
func (cfg Config) ToolConfigKind() string {
|
||||
return kind
|
||||
}
|
||||
|
||||
// Initialize initializes the tool from the configuration.
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
|
||||
}
|
||||
|
||||
project := s.GetDefaultProject()
|
||||
var targetProjectParam parameters.Parameter
|
||||
if project != "" {
|
||||
targetProjectParam = parameters.NewStringParameterWithDefault("target_project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
|
||||
} else {
|
||||
targetProjectParam = parameters.NewStringParameter("target_project", "The project ID")
|
||||
}
|
||||
|
||||
allParameters := parameters.Parameters{
|
||||
targetProjectParam,
|
||||
parameters.NewStringParameter("target_instance", "Cloud SQL instance ID of the target instance. This does not include the project ID."),
|
||||
parameters.NewStringParameter("backup_id", "Identifier of the backup being restored. Can be a BackupRun ID, backup name, or BackupDR backup name. Use the full backup ID as provided, do not try to parse it"),
|
||||
parameters.NewStringParameterWithRequired("source_project", "GCP project ID of the instance that the backup belongs to. Only required if the backup_id is a BackupRun ID.", false),
|
||||
parameters.NewStringParameterWithRequired("source_instance", "Cloud SQL instance ID of the instance that the backup belongs to. Only required if the backup_id is a BackupRun ID.", false),
|
||||
}
|
||||
paramManifest := allParameters.Manifest()
|
||||
|
||||
description := cfg.Description
|
||||
if description == "" {
|
||||
description = "Restores a backup on a Cloud SQL instance."
|
||||
}
|
||||
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, allParameters, nil)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Tool represents the restore-backup tool.
|
||||
type Tool struct {
|
||||
Config
|
||||
AllParams parameters.Parameters `yaml:"allParams"`
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
return t.Config
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
targetProject, ok := paramsMap["target_project"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("error casting 'target_project' parameter: %v", paramsMap["target_project"])
|
||||
}
|
||||
targetInstance, ok := paramsMap["target_instance"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("error casting 'target_instance' parameter: %v", paramsMap["target_instance"])
|
||||
}
|
||||
backupID, ok := paramsMap["backup_id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("error casting 'backup_id' parameter: %v", paramsMap["backup_id"])
|
||||
}
|
||||
sourceProject, _ := paramsMap["source_project"].(string)
|
||||
sourceInstance, _ := paramsMap["source_instance"].(string)
|
||||
|
||||
return source.RestoreBackup(ctx, targetProject, targetInstance, sourceProject, sourceInstance, backupID, string(accessToken))
|
||||
}
|
||||
|
||||
// ParseParams parses the parameters for the tool.
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
return parameters.ParseParams(t.AllParams, data, claims)
|
||||
}
|
||||
|
||||
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
|
||||
return parameters.EmbedParams(ctx, t.AllParams, paramValues, embeddingModelsMap, nil)
|
||||
}
|
||||
|
||||
// Manifest returns the tool's manifest.
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
// McpManifest returns the tool's MCP manifest.
|
||||
func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
// Authorized checks if the tool is authorized.
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return source.UseClientAuthorization(), nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
@@ -0,0 +1,71 @@
|
||||
// Copyright 2026 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cloudsqlrestorebackup_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlrestorebackup"
|
||||
)
|
||||
|
||||
func TestParseFromYaml(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
restore-backup-tool:
|
||||
kind: cloud-sql-restore-backup
|
||||
description: a test description
|
||||
source: a-source
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"restore-backup-tool": cloudsqlrestorebackup.Config{
|
||||
Name: "restore-backup-tool",
|
||||
Kind: "cloud-sql-restore-backup",
|
||||
Description: "a test description",
|
||||
Source: "a-source",
|
||||
AuthRequired: []string{},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -18,7 +18,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
dataplexapi "cloud.google.com/go/dataplex/apiv1"
|
||||
dataplexpb "cloud.google.com/go/dataplex/apiv1/dataplexpb"
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
@@ -44,7 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
CatalogClient() *dataplexapi.CatalogClient
|
||||
LookupEntry(context.Context, string, int, []string, string) (*dataplexpb.Entry, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -118,12 +117,6 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
viewMap := map[int]dataplexpb.EntryView{
|
||||
1: dataplexpb.EntryView_BASIC,
|
||||
2: dataplexpb.EntryView_FULL,
|
||||
3: dataplexpb.EntryView_CUSTOM,
|
||||
4: dataplexpb.EntryView_ALL,
|
||||
}
|
||||
name, _ := paramsMap["name"].(string)
|
||||
entry, _ := paramsMap["entry"].(string)
|
||||
view, _ := paramsMap["view"].(int)
|
||||
@@ -132,19 +125,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("can't convert aspectTypes to array of strings: %s", err)
|
||||
}
|
||||
aspectTypes := aspectTypeSlice.([]string)
|
||||
|
||||
req := &dataplexpb.LookupEntryRequest{
|
||||
Name: name,
|
||||
View: viewMap[view],
|
||||
AspectTypes: aspectTypes,
|
||||
Entry: entry,
|
||||
}
|
||||
|
||||
result, err := source.CatalogClient().LookupEntry(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
return source.LookupEntry(ctx, name, view, aspectTypes, entry)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -18,9 +18,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
dataplexapi "cloud.google.com/go/dataplex/apiv1"
|
||||
dataplexpb "cloud.google.com/go/dataplex/apiv1/dataplexpb"
|
||||
"github.com/cenkalti/backoff/v5"
|
||||
"cloud.google.com/go/dataplex/apiv1/dataplexpb"
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
@@ -45,8 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
CatalogClient() *dataplexapi.CatalogClient
|
||||
ProjectID() string
|
||||
SearchAspectTypes(context.Context, string, int, string) ([]*dataplexpb.AspectType, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -101,61 +98,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Invoke the tool with the provided parameters
|
||||
paramsMap := params.AsMap()
|
||||
query, _ := paramsMap["query"].(string)
|
||||
pageSize := int32(paramsMap["pageSize"].(int))
|
||||
pageSize, _ := paramsMap["pageSize"].(int)
|
||||
orderBy, _ := paramsMap["orderBy"].(string)
|
||||
|
||||
// Create SearchEntriesRequest with the provided parameters
|
||||
req := &dataplexpb.SearchEntriesRequest{
|
||||
Query: query + " type=projects/dataplex-types/locations/global/entryTypes/aspecttype",
|
||||
Name: fmt.Sprintf("projects/%s/locations/global", source.ProjectID()),
|
||||
PageSize: pageSize,
|
||||
OrderBy: orderBy,
|
||||
SemanticSearch: true,
|
||||
}
|
||||
|
||||
// Perform the search using the CatalogClient - this will return an iterator
|
||||
it := source.CatalogClient().SearchEntries(ctx, req)
|
||||
if it == nil {
|
||||
return nil, fmt.Errorf("failed to create search entries iterator for project %q", source.ProjectID())
|
||||
}
|
||||
|
||||
// Create an instance of exponential backoff with default values for retrying GetAspectType calls
|
||||
// InitialInterval, RandomizationFactor, Multiplier, MaxInterval = 500 ms, 0.5, 1.5, 60 s
|
||||
getAspectBackOff := backoff.NewExponentialBackOff()
|
||||
|
||||
// Iterate through the search results and call GetAspectType for each result using the resource name
|
||||
var results []*dataplexpb.AspectType
|
||||
for {
|
||||
entry, err := it.Next()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
resourceName := entry.DataplexEntry.GetEntrySource().Resource
|
||||
getAspectTypeReq := &dataplexpb.GetAspectTypeRequest{
|
||||
Name: resourceName,
|
||||
}
|
||||
|
||||
operation := func() (*dataplexpb.AspectType, error) {
|
||||
aspectType, err := source.CatalogClient().GetAspectType(ctx, getAspectTypeReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get aspect type for entry %q: %w", resourceName, err)
|
||||
}
|
||||
return aspectType, nil
|
||||
}
|
||||
|
||||
// Retry the GetAspectType operation with exponential backoff
|
||||
aspectType, err := backoff.Retry(ctx, operation, backoff.WithBackOff(getAspectBackOff))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get aspect type after retries for entry %q: %w", resourceName, err)
|
||||
}
|
||||
|
||||
results = append(results, aspectType)
|
||||
}
|
||||
return results, nil
|
||||
return source.SearchAspectTypes(ctx, query, pageSize, orderBy)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -18,8 +18,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
dataplexapi "cloud.google.com/go/dataplex/apiv1"
|
||||
dataplexpb "cloud.google.com/go/dataplex/apiv1/dataplexpb"
|
||||
"cloud.google.com/go/dataplex/apiv1/dataplexpb"
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
@@ -44,8 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
CatalogClient() *dataplexapi.CatalogClient
|
||||
ProjectID() string
|
||||
SearchEntries(context.Context, string, int, string) ([]*dataplexpb.SearchEntriesResult, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -100,34 +98,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
query, _ := paramsMap["query"].(string)
|
||||
pageSize := int32(paramsMap["pageSize"].(int))
|
||||
pageSize, _ := paramsMap["pageSize"].(int)
|
||||
orderBy, _ := paramsMap["orderBy"].(string)
|
||||
|
||||
req := &dataplexpb.SearchEntriesRequest{
|
||||
Query: query,
|
||||
Name: fmt.Sprintf("projects/%s/locations/global", source.ProjectID()),
|
||||
PageSize: pageSize,
|
||||
OrderBy: orderBy,
|
||||
SemanticSearch: true,
|
||||
}
|
||||
|
||||
it := source.CatalogClient().SearchEntries(ctx, req)
|
||||
if it == nil {
|
||||
return nil, fmt.Errorf("failed to create search entries iterator for project %q", source.ProjectID())
|
||||
}
|
||||
|
||||
var results []*dataplexpb.SearchEntriesResult
|
||||
for {
|
||||
entry, err := it.Next()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
results = append(results, entry)
|
||||
}
|
||||
return results, nil
|
||||
return source.SearchEntries(ctx, query, pageSize, orderBy)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -48,6 +48,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
|
||||
type compatibleSource interface {
|
||||
FirestoreClient() *firestoreapi.Client
|
||||
AddDocuments(context.Context, string, any, bool) (map[string]any, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -134,24 +135,20 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
|
||||
mapParams := params.AsMap()
|
||||
|
||||
// Get collection path
|
||||
collectionPath, ok := mapParams[collectionPathKey].(string)
|
||||
if !ok || collectionPath == "" {
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter", collectionPathKey)
|
||||
}
|
||||
|
||||
// Validate collection path
|
||||
if err := util.ValidateCollectionPath(collectionPath); err != nil {
|
||||
return nil, fmt.Errorf("invalid collection path: %w", err)
|
||||
}
|
||||
|
||||
// Get document data
|
||||
documentDataRaw, ok := mapParams[documentDataKey]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter", documentDataKey)
|
||||
}
|
||||
|
||||
// Convert the document data from JSON format to Firestore format
|
||||
// The client is passed to handle referenceValue types
|
||||
documentData, err := util.JSONToFirestoreValue(documentDataRaw, source.FirestoreClient())
|
||||
@@ -164,30 +161,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if val, ok := mapParams[returnDocumentDataKey].(bool); ok {
|
||||
returnData = val
|
||||
}
|
||||
|
||||
// Get the collection reference
|
||||
collection := source.FirestoreClient().Collection(collectionPath)
|
||||
|
||||
// Add the document to the collection
|
||||
docRef, writeResult, err := collection.Add(ctx, documentData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add document: %w", err)
|
||||
}
|
||||
|
||||
// Build the response
|
||||
response := map[string]any{
|
||||
"documentPath": docRef.Path,
|
||||
"createTime": writeResult.UpdateTime.Format("2006-01-02T15:04:05.999999999Z"),
|
||||
}
|
||||
|
||||
// Add document data if requested
|
||||
if returnData {
|
||||
// Convert the document data back to simple JSON format
|
||||
simplifiedData := util.FirestoreValueToJSON(documentData)
|
||||
response["documentData"] = simplifiedData
|
||||
}
|
||||
|
||||
return response, nil
|
||||
return source.AddDocuments(ctx, collectionPath, documentData, returnData)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -46,6 +46,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
|
||||
type compatibleSource interface {
|
||||
FirestoreClient() *firestoreapi.Client
|
||||
DeleteDocuments(context.Context, []string) ([]any, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -104,7 +105,6 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected an array", documentPathsKey)
|
||||
}
|
||||
|
||||
if len(documentPathsRaw) == 0 {
|
||||
return nil, fmt.Errorf("'%s' parameter cannot be empty", documentPathsKey)
|
||||
}
|
||||
@@ -126,45 +126,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid document path at index %d: %w", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create a BulkWriter to handle multiple deletions efficiently
|
||||
bulkWriter := source.FirestoreClient().BulkWriter(ctx)
|
||||
|
||||
// Keep track of jobs for each document
|
||||
jobs := make([]*firestoreapi.BulkWriterJob, len(documentPaths))
|
||||
|
||||
// Add all delete operations to the BulkWriter
|
||||
for i, path := range documentPaths {
|
||||
docRef := source.FirestoreClient().Doc(path)
|
||||
job, err := bulkWriter.Delete(docRef)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add delete operation for document %q: %w", path, err)
|
||||
}
|
||||
jobs[i] = job
|
||||
}
|
||||
|
||||
// End the BulkWriter to execute all operations
|
||||
bulkWriter.End()
|
||||
|
||||
// Collect results
|
||||
results := make([]any, len(documentPaths))
|
||||
for i, job := range jobs {
|
||||
docData := make(map[string]any)
|
||||
docData["path"] = documentPaths[i]
|
||||
|
||||
// Wait for the job to complete and get the result
|
||||
_, err := job.Results()
|
||||
if err != nil {
|
||||
docData["success"] = false
|
||||
docData["error"] = err.Error()
|
||||
} else {
|
||||
docData["success"] = true
|
||||
}
|
||||
|
||||
results[i] = docData
|
||||
}
|
||||
|
||||
return results, nil
|
||||
return source.DeleteDocuments(ctx, documentPaths)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -46,6 +46,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
|
||||
type compatibleSource interface {
|
||||
FirestoreClient() *firestoreapi.Client
|
||||
GetDocuments(context.Context, []string) ([]any, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -126,37 +127,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("invalid document path at index %d: %w", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create document references from paths
|
||||
docRefs := make([]*firestoreapi.DocumentRef, len(documentPaths))
|
||||
for i, path := range documentPaths {
|
||||
docRefs[i] = source.FirestoreClient().Doc(path)
|
||||
}
|
||||
|
||||
// Get all documents
|
||||
snapshots, err := source.FirestoreClient().GetAll(ctx, docRefs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get documents: %w", err)
|
||||
}
|
||||
|
||||
// Convert snapshots to response data
|
||||
results := make([]any, len(snapshots))
|
||||
for i, snapshot := range snapshots {
|
||||
docData := make(map[string]any)
|
||||
docData["path"] = documentPaths[i]
|
||||
docData["exists"] = snapshot.Exists()
|
||||
|
||||
if snapshot.Exists() {
|
||||
docData["data"] = snapshot.Data()
|
||||
docData["createTime"] = snapshot.CreateTime
|
||||
docData["updateTime"] = snapshot.UpdateTime
|
||||
docData["readTime"] = snapshot.ReadTime
|
||||
}
|
||||
|
||||
results[i] = docData
|
||||
}
|
||||
|
||||
return results, nil
|
||||
return source.GetDocuments(ctx, documentPaths)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -44,8 +44,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
|
||||
type compatibleSource interface {
|
||||
FirebaseRulesClient() *firebaserules.Service
|
||||
GetProjectId() string
|
||||
GetDatabaseId() string
|
||||
GetRules(context.Context) (any, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -98,29 +97,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get the latest release for Firestore
|
||||
releaseName := fmt.Sprintf("projects/%s/releases/cloud.firestore/%s", source.GetProjectId(), source.GetDatabaseId())
|
||||
release, err := source.FirebaseRulesClient().Projects.Releases.Get(releaseName).Context(ctx).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get latest Firestore release: %w", err)
|
||||
}
|
||||
|
||||
if release.RulesetName == "" {
|
||||
return nil, fmt.Errorf("no active Firestore rules were found in project '%s' and database '%s'", source.GetProjectId(), source.GetDatabaseId())
|
||||
}
|
||||
|
||||
// Get the ruleset content
|
||||
ruleset, err := source.FirebaseRulesClient().Projects.Rulesets.Get(release.RulesetName).Context(ctx).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get ruleset content: %w", err)
|
||||
}
|
||||
|
||||
if ruleset.Source == nil || len(ruleset.Source.Files) == 0 {
|
||||
return nil, fmt.Errorf("no rules files found in ruleset")
|
||||
}
|
||||
|
||||
return ruleset, nil
|
||||
return source.GetRules(ctx)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -46,6 +46,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
|
||||
type compatibleSource interface {
|
||||
FirestoreClient() *firestoreapi.Client
|
||||
ListCollections(context.Context, string) ([]any, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -102,47 +103,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
|
||||
mapParams := params.AsMap()
|
||||
|
||||
var collectionRefs []*firestoreapi.CollectionRef
|
||||
|
||||
// Check if parentPath is provided
|
||||
parentPath, hasParent := mapParams[parentPathKey].(string)
|
||||
|
||||
if hasParent && parentPath != "" {
|
||||
parentPath, _ := mapParams[parentPathKey].(string)
|
||||
if parentPath != "" {
|
||||
// Validate parent document path
|
||||
if err := util.ValidateDocumentPath(parentPath); err != nil {
|
||||
return nil, fmt.Errorf("invalid parent document path: %w", err)
|
||||
}
|
||||
|
||||
// List subcollections of the specified document
|
||||
docRef := source.FirestoreClient().Doc(parentPath)
|
||||
collectionRefs, err = docRef.Collections(ctx).GetAll()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list subcollections of document %q: %w", parentPath, err)
|
||||
}
|
||||
} else {
|
||||
// List root collections
|
||||
collectionRefs, err = source.FirestoreClient().Collections(ctx).GetAll()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list root collections: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Convert collection references to response data
|
||||
results := make([]any, len(collectionRefs))
|
||||
for i, collRef := range collectionRefs {
|
||||
collData := make(map[string]any)
|
||||
collData["id"] = collRef.ID
|
||||
collData["path"] = collRef.Path
|
||||
|
||||
// If this is a subcollection, include parent information
|
||||
if collRef.Parent != nil {
|
||||
collData["parent"] = collRef.Parent.Path
|
||||
}
|
||||
|
||||
results[i] = collData
|
||||
}
|
||||
|
||||
return results, nil
|
||||
return source.ListCollections(ctx, parentPath)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -36,27 +36,6 @@ const (
|
||||
defaultLimit = 100
|
||||
)
|
||||
|
||||
// Firestore operators
|
||||
var validOperators = map[string]bool{
|
||||
"<": true,
|
||||
"<=": true,
|
||||
">": true,
|
||||
">=": true,
|
||||
"==": true,
|
||||
"!=": true,
|
||||
"array-contains": true,
|
||||
"array-contains-any": true,
|
||||
"in": true,
|
||||
"not-in": true,
|
||||
}
|
||||
|
||||
// Error messages
|
||||
const (
|
||||
errFilterParseFailed = "failed to parse filters: %w"
|
||||
errQueryExecutionFailed = "failed to execute query: %w"
|
||||
errLimitParseFailed = "failed to parse limit value '%s': %w"
|
||||
)
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
||||
@@ -74,6 +53,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
// compatibleSource defines the interface for sources that can provide a Firestore client
|
||||
type compatibleSource interface {
|
||||
FirestoreClient() *firestoreapi.Client
|
||||
BuildQuery(string, firestoreapi.EntityFilter, []string, string, firestoreapi.Direction, int, bool) (*firestoreapi.Query, error)
|
||||
ExecuteQuery(context.Context, *firestoreapi.Query, bool) (any, error)
|
||||
}
|
||||
|
||||
// Config represents the configuration for the Firestore query tool
|
||||
@@ -139,15 +120,6 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
return t.Config
|
||||
}
|
||||
|
||||
// SimplifiedFilter represents the simplified filter format
|
||||
type SimplifiedFilter struct {
|
||||
And []SimplifiedFilter `json:"and,omitempty"`
|
||||
Or []SimplifiedFilter `json:"or,omitempty"`
|
||||
Field string `json:"field,omitempty"`
|
||||
Op string `json:"op,omitempty"`
|
||||
Value interface{} `json:"value,omitempty"`
|
||||
}
|
||||
|
||||
// OrderByConfig represents ordering configuration
|
||||
type OrderByConfig struct {
|
||||
Field string `json:"field"`
|
||||
@@ -162,20 +134,27 @@ func (o *OrderByConfig) GetDirection() firestoreapi.Direction {
|
||||
return firestoreapi.Asc
|
||||
}
|
||||
|
||||
// QueryResult represents a document result from the query
|
||||
type QueryResult struct {
|
||||
ID string `json:"id"`
|
||||
Path string `json:"path"`
|
||||
Data map[string]any `json:"data"`
|
||||
CreateTime interface{} `json:"createTime,omitempty"`
|
||||
UpdateTime interface{} `json:"updateTime,omitempty"`
|
||||
ReadTime interface{} `json:"readTime,omitempty"`
|
||||
// SimplifiedFilter represents the simplified filter format
|
||||
type SimplifiedFilter struct {
|
||||
And []SimplifiedFilter `json:"and,omitempty"`
|
||||
Or []SimplifiedFilter `json:"or,omitempty"`
|
||||
Field string `json:"field,omitempty"`
|
||||
Op string `json:"op,omitempty"`
|
||||
Value interface{} `json:"value,omitempty"`
|
||||
}
|
||||
|
||||
// QueryResponse represents the full response including optional metrics
|
||||
type QueryResponse struct {
|
||||
Documents []QueryResult `json:"documents"`
|
||||
ExplainMetrics map[string]any `json:"explainMetrics,omitempty"`
|
||||
// Firestore operators
|
||||
var validOperators = map[string]bool{
|
||||
"<": true,
|
||||
"<=": true,
|
||||
">": true,
|
||||
">=": true,
|
||||
"==": true,
|
||||
"!=": true,
|
||||
"array-contains": true,
|
||||
"array-contains-any": true,
|
||||
"in": true,
|
||||
"not-in": true,
|
||||
}
|
||||
|
||||
// Invoke executes the Firestore query based on the provided parameters
|
||||
@@ -184,34 +163,18 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
// Process collection path with template substitution
|
||||
collectionPath, err := parameters.PopulateTemplate("collectionPath", t.CollectionPath, paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to process collection path: %w", err)
|
||||
}
|
||||
|
||||
// Build the query
|
||||
query, err := t.buildQuery(source, collectionPath, paramsMap)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Execute the query and return results
|
||||
return t.executeQuery(ctx, query)
|
||||
}
|
||||
|
||||
// buildQuery constructs the Firestore query from parameters
|
||||
func (t Tool) buildQuery(source compatibleSource, collectionPath string, params map[string]any) (*firestoreapi.Query, error) {
|
||||
collection := source.FirestoreClient().Collection(collectionPath)
|
||||
query := collection.Query
|
||||
|
||||
var filter firestoreapi.EntityFilter
|
||||
// Process and apply filters if template is provided
|
||||
if t.Filters != "" {
|
||||
// Apply template substitution to filters
|
||||
filtersJSON, err := parameters.PopulateTemplateWithJSON("filters", t.Filters, params)
|
||||
filtersJSON, err := parameters.PopulateTemplateWithJSON("filters", t.Filters, paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to process filters template: %w", err)
|
||||
}
|
||||
@@ -219,48 +182,43 @@ func (t Tool) buildQuery(source compatibleSource, collectionPath string, params
|
||||
// Parse the simplified filter format
|
||||
var simplifiedFilter SimplifiedFilter
|
||||
if err := json.Unmarshal([]byte(filtersJSON), &simplifiedFilter); err != nil {
|
||||
return nil, fmt.Errorf(errFilterParseFailed, err)
|
||||
return nil, fmt.Errorf("failed to parse filters: %w", err)
|
||||
}
|
||||
|
||||
// Convert simplified filter to Firestore filter
|
||||
if filter := t.convertToFirestoreFilter(source, simplifiedFilter); filter != nil {
|
||||
query = query.WhereEntity(filter)
|
||||
}
|
||||
filter = t.convertToFirestoreFilter(source, simplifiedFilter)
|
||||
}
|
||||
|
||||
// Process select fields
|
||||
selectFields, err := t.processSelectFields(params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(selectFields) > 0 {
|
||||
query = query.Select(selectFields...)
|
||||
}
|
||||
|
||||
// Process and apply ordering
|
||||
orderBy, err := t.getOrderBy(params)
|
||||
orderBy, err := t.getOrderBy(paramsMap)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if orderBy != nil {
|
||||
query = query.OrderBy(orderBy.Field, orderBy.GetDirection())
|
||||
// Process select fields
|
||||
selectFields, err := t.processSelectFields(paramsMap)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Process and apply limit
|
||||
limit, err := t.getLimit(params)
|
||||
limit, err := t.getLimit(paramsMap)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
query = query.Limit(limit)
|
||||
|
||||
// Apply analyze options if enabled
|
||||
if t.AnalyzeQuery {
|
||||
query = query.WithRunOptions(firestoreapi.ExplainOptions{
|
||||
Analyze: true,
|
||||
})
|
||||
// prevent panic when accessing orderBy incase it is nil
|
||||
var orderByField string
|
||||
var orderByDirection firestoreapi.Direction
|
||||
if orderBy != nil {
|
||||
orderByField = orderBy.Field
|
||||
orderByDirection = orderBy.GetDirection()
|
||||
}
|
||||
|
||||
return &query, nil
|
||||
// Build the query
|
||||
query, err := source.BuildQuery(collectionPath, filter, selectFields, orderByField, orderByDirection, limit, t.AnalyzeQuery)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Execute the query and return results
|
||||
return source.ExecuteQuery(ctx, query, t.AnalyzeQuery)
|
||||
}
|
||||
|
||||
// convertToFirestoreFilter converts simplified filter format to Firestore EntityFilter
|
||||
@@ -409,7 +367,7 @@ func (t Tool) getLimit(params map[string]any) (int, error) {
|
||||
if processedValue != "" {
|
||||
parsedLimit, err := strconv.Atoi(processedValue)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf(errLimitParseFailed, processedValue, err)
|
||||
return 0, fmt.Errorf("failed to parse limit value '%s': %w", processedValue, err)
|
||||
}
|
||||
limit = parsedLimit
|
||||
}
|
||||
@@ -417,78 +375,6 @@ func (t Tool) getLimit(params map[string]any) (int, error) {
|
||||
return limit, nil
|
||||
}
|
||||
|
||||
// executeQuery runs the query and formats the results
|
||||
func (t Tool) executeQuery(ctx context.Context, query *firestoreapi.Query) (any, error) {
|
||||
docIterator := query.Documents(ctx)
|
||||
docs, err := docIterator.GetAll()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(errQueryExecutionFailed, err)
|
||||
}
|
||||
|
||||
// Convert results to structured format
|
||||
results := make([]QueryResult, len(docs))
|
||||
for i, doc := range docs {
|
||||
results[i] = QueryResult{
|
||||
ID: doc.Ref.ID,
|
||||
Path: doc.Ref.Path,
|
||||
Data: doc.Data(),
|
||||
CreateTime: doc.CreateTime,
|
||||
UpdateTime: doc.UpdateTime,
|
||||
ReadTime: doc.ReadTime,
|
||||
}
|
||||
}
|
||||
|
||||
// Return with explain metrics if requested
|
||||
if t.AnalyzeQuery {
|
||||
explainMetrics, err := t.getExplainMetrics(docIterator)
|
||||
if err == nil && explainMetrics != nil {
|
||||
response := QueryResponse{
|
||||
Documents: results,
|
||||
ExplainMetrics: explainMetrics,
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// getExplainMetrics extracts explain metrics from the query iterator
|
||||
func (t Tool) getExplainMetrics(docIterator *firestoreapi.DocumentIterator) (map[string]any, error) {
|
||||
explainMetrics, err := docIterator.ExplainMetrics()
|
||||
if err != nil || explainMetrics == nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
metricsData := make(map[string]any)
|
||||
|
||||
// Add plan summary if available
|
||||
if explainMetrics.PlanSummary != nil {
|
||||
planSummary := make(map[string]any)
|
||||
planSummary["indexesUsed"] = explainMetrics.PlanSummary.IndexesUsed
|
||||
metricsData["planSummary"] = planSummary
|
||||
}
|
||||
|
||||
// Add execution stats if available
|
||||
if explainMetrics.ExecutionStats != nil {
|
||||
executionStats := make(map[string]any)
|
||||
executionStats["resultsReturned"] = explainMetrics.ExecutionStats.ResultsReturned
|
||||
executionStats["readOperations"] = explainMetrics.ExecutionStats.ReadOperations
|
||||
|
||||
if explainMetrics.ExecutionStats.ExecutionDuration != nil {
|
||||
executionStats["executionDuration"] = explainMetrics.ExecutionStats.ExecutionDuration.String()
|
||||
}
|
||||
|
||||
if explainMetrics.ExecutionStats.DebugStats != nil {
|
||||
executionStats["debugStats"] = *explainMetrics.ExecutionStats.DebugStats
|
||||
}
|
||||
|
||||
metricsData["executionStats"] = executionStats
|
||||
}
|
||||
|
||||
return metricsData, nil
|
||||
}
|
||||
|
||||
// ParseParams parses and validates input parameters
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
return parameters.ParseParams(t.Parameters, data, claims)
|
||||
|
||||
@@ -69,7 +69,6 @@ const (
|
||||
errInvalidOperator = "unsupported operator: %s. Valid operators are: %v"
|
||||
errMissingFilterValue = "no value specified for filter on field '%s'"
|
||||
errOrderByParseFailed = "failed to parse orderBy: %w"
|
||||
errQueryExecutionFailed = "failed to execute query: %w"
|
||||
errTooManyFilters = "too many filters provided: %d (maximum: %d)"
|
||||
)
|
||||
|
||||
@@ -90,6 +89,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
// compatibleSource defines the interface for sources that can provide a Firestore client
|
||||
type compatibleSource interface {
|
||||
FirestoreClient() *firestoreapi.Client
|
||||
BuildQuery(string, firestoreapi.EntityFilter, []string, string, firestoreapi.Direction, int, bool) (*firestoreapi.Query, error)
|
||||
ExecuteQuery(context.Context, *firestoreapi.Query, bool) (any, error)
|
||||
}
|
||||
|
||||
// Config represents the configuration for the Firestore query collection tool
|
||||
@@ -228,22 +229,6 @@ func (o *OrderByConfig) GetDirection() firestoreapi.Direction {
|
||||
return firestoreapi.Asc
|
||||
}
|
||||
|
||||
// QueryResult represents a document result from the query
|
||||
type QueryResult struct {
|
||||
ID string `json:"id"`
|
||||
Path string `json:"path"`
|
||||
Data map[string]any `json:"data"`
|
||||
CreateTime interface{} `json:"createTime,omitempty"`
|
||||
UpdateTime interface{} `json:"updateTime,omitempty"`
|
||||
ReadTime interface{} `json:"readTime,omitempty"`
|
||||
}
|
||||
|
||||
// QueryResponse represents the full response including optional metrics
|
||||
type QueryResponse struct {
|
||||
Documents []QueryResult `json:"documents"`
|
||||
ExplainMetrics map[string]any `json:"explainMetrics,omitempty"`
|
||||
}
|
||||
|
||||
// Invoke executes the Firestore query based on the provided parameters
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
@@ -257,14 +242,37 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var filter firestoreapi.EntityFilter
|
||||
// Apply filters
|
||||
if len(queryParams.Filters) > 0 {
|
||||
filterConditions := make([]firestoreapi.EntityFilter, 0, len(queryParams.Filters))
|
||||
for _, filter := range queryParams.Filters {
|
||||
filterConditions = append(filterConditions, firestoreapi.PropertyFilter{
|
||||
Path: filter.Field,
|
||||
Operator: filter.Op,
|
||||
Value: filter.Value,
|
||||
})
|
||||
}
|
||||
|
||||
filter = firestoreapi.AndFilter{
|
||||
Filters: filterConditions,
|
||||
}
|
||||
}
|
||||
|
||||
// prevent panic incase queryParams.OrderBy is nil
|
||||
var orderByField string
|
||||
var orderByDirection firestoreapi.Direction
|
||||
if queryParams.OrderBy != nil {
|
||||
orderByField = queryParams.OrderBy.Field
|
||||
orderByDirection = queryParams.OrderBy.GetDirection()
|
||||
}
|
||||
|
||||
// Build the query
|
||||
query, err := t.buildQuery(source, queryParams)
|
||||
query, err := source.BuildQuery(queryParams.CollectionPath, filter, nil, orderByField, orderByDirection, queryParams.Limit, queryParams.AnalyzeQuery)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Execute the query and return results
|
||||
return t.executeQuery(ctx, query, queryParams.AnalyzeQuery)
|
||||
return source.ExecuteQuery(ctx, query, queryParams.AnalyzeQuery)
|
||||
}
|
||||
|
||||
// queryParameters holds all parsed query parameters
|
||||
@@ -380,122 +388,6 @@ func (t Tool) parseOrderBy(orderByRaw interface{}) (*OrderByConfig, error) {
|
||||
return &orderBy, nil
|
||||
}
|
||||
|
||||
// buildQuery constructs the Firestore query from parameters
|
||||
func (t Tool) buildQuery(source compatibleSource, params *queryParameters) (*firestoreapi.Query, error) {
|
||||
collection := source.FirestoreClient().Collection(params.CollectionPath)
|
||||
query := collection.Query
|
||||
|
||||
// Apply filters
|
||||
if len(params.Filters) > 0 {
|
||||
filterConditions := make([]firestoreapi.EntityFilter, 0, len(params.Filters))
|
||||
for _, filter := range params.Filters {
|
||||
filterConditions = append(filterConditions, firestoreapi.PropertyFilter{
|
||||
Path: filter.Field,
|
||||
Operator: filter.Op,
|
||||
Value: filter.Value,
|
||||
})
|
||||
}
|
||||
|
||||
query = query.WhereEntity(firestoreapi.AndFilter{
|
||||
Filters: filterConditions,
|
||||
})
|
||||
}
|
||||
|
||||
// Apply ordering
|
||||
if params.OrderBy != nil {
|
||||
query = query.OrderBy(params.OrderBy.Field, params.OrderBy.GetDirection())
|
||||
}
|
||||
|
||||
// Apply limit
|
||||
query = query.Limit(params.Limit)
|
||||
|
||||
// Apply analyze options
|
||||
if params.AnalyzeQuery {
|
||||
query = query.WithRunOptions(firestoreapi.ExplainOptions{
|
||||
Analyze: true,
|
||||
})
|
||||
}
|
||||
|
||||
return &query, nil
|
||||
}
|
||||
|
||||
// executeQuery runs the query and formats the results
|
||||
func (t Tool) executeQuery(ctx context.Context, query *firestoreapi.Query, analyzeQuery bool) (any, error) {
|
||||
docIterator := query.Documents(ctx)
|
||||
docs, err := docIterator.GetAll()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(errQueryExecutionFailed, err)
|
||||
}
|
||||
|
||||
// Convert results to structured format
|
||||
results := make([]QueryResult, len(docs))
|
||||
for i, doc := range docs {
|
||||
results[i] = QueryResult{
|
||||
ID: doc.Ref.ID,
|
||||
Path: doc.Ref.Path,
|
||||
Data: doc.Data(),
|
||||
CreateTime: doc.CreateTime,
|
||||
UpdateTime: doc.UpdateTime,
|
||||
ReadTime: doc.ReadTime,
|
||||
}
|
||||
}
|
||||
|
||||
// Return with explain metrics if requested
|
||||
if analyzeQuery {
|
||||
explainMetrics, err := t.getExplainMetrics(docIterator)
|
||||
if err == nil && explainMetrics != nil {
|
||||
response := QueryResponse{
|
||||
Documents: results,
|
||||
ExplainMetrics: explainMetrics,
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Return just the documents
|
||||
resultsAny := make([]any, len(results))
|
||||
for i, r := range results {
|
||||
resultsAny[i] = r
|
||||
}
|
||||
return resultsAny, nil
|
||||
}
|
||||
|
||||
// getExplainMetrics extracts explain metrics from the query iterator
|
||||
func (t Tool) getExplainMetrics(docIterator *firestoreapi.DocumentIterator) (map[string]any, error) {
|
||||
explainMetrics, err := docIterator.ExplainMetrics()
|
||||
if err != nil || explainMetrics == nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
metricsData := make(map[string]any)
|
||||
|
||||
// Add plan summary if available
|
||||
if explainMetrics.PlanSummary != nil {
|
||||
planSummary := make(map[string]any)
|
||||
planSummary["indexesUsed"] = explainMetrics.PlanSummary.IndexesUsed
|
||||
metricsData["planSummary"] = planSummary
|
||||
}
|
||||
|
||||
// Add execution stats if available
|
||||
if explainMetrics.ExecutionStats != nil {
|
||||
executionStats := make(map[string]any)
|
||||
executionStats["resultsReturned"] = explainMetrics.ExecutionStats.ResultsReturned
|
||||
executionStats["readOperations"] = explainMetrics.ExecutionStats.ReadOperations
|
||||
|
||||
if explainMetrics.ExecutionStats.ExecutionDuration != nil {
|
||||
executionStats["executionDuration"] = explainMetrics.ExecutionStats.ExecutionDuration.String()
|
||||
}
|
||||
|
||||
if explainMetrics.ExecutionStats.DebugStats != nil {
|
||||
executionStats["debugStats"] = *explainMetrics.ExecutionStats.DebugStats
|
||||
}
|
||||
|
||||
metricsData["executionStats"] = executionStats
|
||||
}
|
||||
|
||||
return metricsData, nil
|
||||
}
|
||||
|
||||
// ParseParams parses and validates input parameters
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
return parameters.ParseParams(t.Parameters, data, claims)
|
||||
|
||||
@@ -50,6 +50,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
|
||||
type compatibleSource interface {
|
||||
FirestoreClient() *firestoreapi.Client
|
||||
UpdateDocument(context.Context, string, []firestoreapi.Update, any, bool) (map[string]any, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -177,23 +178,10 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get return document data flag
|
||||
returnData := false
|
||||
if val, ok := mapParams[returnDocumentDataKey].(bool); ok {
|
||||
returnData = val
|
||||
}
|
||||
|
||||
// Get the document reference
|
||||
docRef := source.FirestoreClient().Doc(documentPath)
|
||||
|
||||
// Prepare update data
|
||||
var writeResult *firestoreapi.WriteResult
|
||||
var writeErr error
|
||||
|
||||
// Use selective field update with update mask
|
||||
updates := make([]firestoreapi.Update, 0, len(updatePaths))
|
||||
var documentData any
|
||||
if len(updatePaths) > 0 {
|
||||
// Use selective field update with update mask
|
||||
updates := make([]firestoreapi.Update, 0, len(updatePaths))
|
||||
|
||||
// Convert document data without delete markers
|
||||
dataMap, err := util.JSONToFirestoreValue(documentDataRaw, source.FirestoreClient())
|
||||
@@ -220,41 +208,20 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
Value: value,
|
||||
})
|
||||
}
|
||||
|
||||
writeResult, writeErr = docRef.Update(ctx, updates)
|
||||
} else {
|
||||
// Update all fields in the document data (merge)
|
||||
documentData, err := util.JSONToFirestoreValue(documentDataRaw, source.FirestoreClient())
|
||||
documentData, err = util.JSONToFirestoreValue(documentDataRaw, source.FirestoreClient())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert document data: %w", err)
|
||||
}
|
||||
writeResult, writeErr = docRef.Set(ctx, documentData, firestoreapi.MergeAll)
|
||||
}
|
||||
|
||||
if writeErr != nil {
|
||||
return nil, fmt.Errorf("failed to update document: %w", writeErr)
|
||||
// Get return document data flag
|
||||
returnData := false
|
||||
if val, ok := mapParams[returnDocumentDataKey].(bool); ok {
|
||||
returnData = val
|
||||
}
|
||||
|
||||
// Build the response
|
||||
response := map[string]any{
|
||||
"documentPath": docRef.Path,
|
||||
"updateTime": writeResult.UpdateTime.Format("2006-01-02T15:04:05.999999999Z"),
|
||||
}
|
||||
|
||||
// Add document data if requested
|
||||
if returnData {
|
||||
// Fetch the updated document to return the current state
|
||||
snapshot, err := docRef.Get(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve updated document: %w", err)
|
||||
}
|
||||
|
||||
// Convert the document data to simple JSON format
|
||||
simplifiedData := util.FirestoreValueToJSON(snapshot.Data())
|
||||
response["documentData"] = simplifiedData
|
||||
}
|
||||
|
||||
return response, nil
|
||||
return source.UpdateDocument(ctx, documentPath, updates, documentData, returnData)
|
||||
}
|
||||
|
||||
// getFieldValue retrieves a value from a nested map using a dot-separated path
|
||||
|
||||
@@ -17,7 +17,6 @@ package firestorevalidaterules
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
@@ -50,7 +49,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
|
||||
type compatibleSource interface {
|
||||
FirebaseRulesClient() *firebaserules.Service
|
||||
GetProjectId() string
|
||||
ValidateRules(context.Context, string) (any, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -107,30 +106,6 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
||||
return t.Config
|
||||
}
|
||||
|
||||
// Issue represents a validation issue in the rules
|
||||
type Issue struct {
|
||||
SourcePosition SourcePosition `json:"sourcePosition"`
|
||||
Description string `json:"description"`
|
||||
Severity string `json:"severity"`
|
||||
}
|
||||
|
||||
// SourcePosition represents the location of an issue in the source
|
||||
type SourcePosition struct {
|
||||
FileName string `json:"fileName,omitempty"`
|
||||
Line int64 `json:"line"` // 1-based
|
||||
Column int64 `json:"column"` // 1-based
|
||||
CurrentOffset int64 `json:"currentOffset"` // 0-based, inclusive start
|
||||
EndOffset int64 `json:"endOffset"` // 0-based, exclusive end
|
||||
}
|
||||
|
||||
// ValidationResult represents the result of rules validation
|
||||
type ValidationResult struct {
|
||||
Valid bool `json:"valid"`
|
||||
IssueCount int `json:"issueCount"`
|
||||
FormattedIssues string `json:"formattedIssues,omitempty"`
|
||||
RawIssues []Issue `json:"rawIssues,omitempty"`
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
|
||||
if err != nil {
|
||||
@@ -144,114 +119,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if !ok || sourceParam == "" {
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter", sourceKey)
|
||||
}
|
||||
|
||||
// Create test request
|
||||
testRequest := &firebaserules.TestRulesetRequest{
|
||||
Source: &firebaserules.Source{
|
||||
Files: []*firebaserules.File{
|
||||
{
|
||||
Name: "firestore.rules",
|
||||
Content: sourceParam,
|
||||
},
|
||||
},
|
||||
},
|
||||
// We don't need test cases for validation only
|
||||
TestSuite: &firebaserules.TestSuite{
|
||||
TestCases: []*firebaserules.TestCase{},
|
||||
},
|
||||
}
|
||||
|
||||
// Call the test API
|
||||
projectName := fmt.Sprintf("projects/%s", source.GetProjectId())
|
||||
response, err := source.FirebaseRulesClient().Projects.Test(projectName, testRequest).Context(ctx).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to validate rules: %w", err)
|
||||
}
|
||||
|
||||
// Process the response
|
||||
result := t.processValidationResponse(response, sourceParam)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (t Tool) processValidationResponse(response *firebaserules.TestRulesetResponse, source string) ValidationResult {
|
||||
if len(response.Issues) == 0 {
|
||||
return ValidationResult{
|
||||
Valid: true,
|
||||
IssueCount: 0,
|
||||
FormattedIssues: "✓ No errors detected. Rules are valid.",
|
||||
}
|
||||
}
|
||||
|
||||
// Convert issues to our format
|
||||
issues := make([]Issue, len(response.Issues))
|
||||
for i, issue := range response.Issues {
|
||||
issues[i] = Issue{
|
||||
Description: issue.Description,
|
||||
Severity: issue.Severity,
|
||||
SourcePosition: SourcePosition{
|
||||
FileName: issue.SourcePosition.FileName,
|
||||
Line: issue.SourcePosition.Line,
|
||||
Column: issue.SourcePosition.Column,
|
||||
CurrentOffset: issue.SourcePosition.CurrentOffset,
|
||||
EndOffset: issue.SourcePosition.EndOffset,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Format issues
|
||||
formattedIssues := t.formatRulesetIssues(issues, source)
|
||||
|
||||
return ValidationResult{
|
||||
Valid: false,
|
||||
IssueCount: len(issues),
|
||||
FormattedIssues: formattedIssues,
|
||||
RawIssues: issues,
|
||||
}
|
||||
}
|
||||
|
||||
// formatRulesetIssues formats validation issues into a human-readable string with code snippets
|
||||
func (t Tool) formatRulesetIssues(issues []Issue, rulesSource string) string {
|
||||
sourceLines := strings.Split(rulesSource, "\n")
|
||||
var formattedOutput []string
|
||||
|
||||
formattedOutput = append(formattedOutput, fmt.Sprintf("Found %d issue(s) in rules source:\n", len(issues)))
|
||||
|
||||
for _, issue := range issues {
|
||||
issueString := fmt.Sprintf("%s: %s [Ln %d, Col %d]",
|
||||
issue.Severity,
|
||||
issue.Description,
|
||||
issue.SourcePosition.Line,
|
||||
issue.SourcePosition.Column)
|
||||
|
||||
if issue.SourcePosition.Line > 0 {
|
||||
lineIndex := int(issue.SourcePosition.Line - 1) // 0-based index
|
||||
if lineIndex >= 0 && lineIndex < len(sourceLines) {
|
||||
errorLine := sourceLines[lineIndex]
|
||||
issueString += fmt.Sprintf("\n```\n%s", errorLine)
|
||||
|
||||
// Add carets if we have column and offset information
|
||||
if issue.SourcePosition.Column > 0 &&
|
||||
issue.SourcePosition.CurrentOffset >= 0 &&
|
||||
issue.SourcePosition.EndOffset > issue.SourcePosition.CurrentOffset {
|
||||
|
||||
startColumn := int(issue.SourcePosition.Column - 1) // 0-based
|
||||
errorTokenLength := int(issue.SourcePosition.EndOffset - issue.SourcePosition.CurrentOffset)
|
||||
|
||||
if startColumn >= 0 && errorTokenLength > 0 && startColumn <= len(errorLine) {
|
||||
padding := strings.Repeat(" ", startColumn)
|
||||
carets := strings.Repeat("^", errorTokenLength)
|
||||
issueString += fmt.Sprintf("\n%s%s", padding, carets)
|
||||
}
|
||||
}
|
||||
issueString += "\n```"
|
||||
}
|
||||
}
|
||||
|
||||
formattedOutput = append(formattedOutput, issueString)
|
||||
}
|
||||
|
||||
return strings.Join(formattedOutput, "\n\n")
|
||||
return source.ValidateRules(ctx, sourceParam)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -28,13 +28,13 @@ import (
|
||||
// JSONToFirestoreValue converts a JSON value with type information to a Firestore-compatible value
|
||||
// The input should be a map with a single key indicating the type (e.g., "stringValue", "integerValue")
|
||||
// If a client is provided, referenceValue types will be converted to *firestore.DocumentRef
|
||||
func JSONToFirestoreValue(value interface{}, client *firestore.Client) (interface{}, error) {
|
||||
func JSONToFirestoreValue(value any, client *firestore.Client) (any, error) {
|
||||
if value == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
switch v := value.(type) {
|
||||
case map[string]interface{}:
|
||||
case map[string]any:
|
||||
// Check for typed values
|
||||
if len(v) == 1 {
|
||||
for key, val := range v {
|
||||
@@ -92,7 +92,7 @@ func JSONToFirestoreValue(value interface{}, client *firestore.Client) (interfac
|
||||
return nil, fmt.Errorf("timestamp value must be a string")
|
||||
case "geoPointValue":
|
||||
// Convert to LatLng
|
||||
if geoMap, ok := val.(map[string]interface{}); ok {
|
||||
if geoMap, ok := val.(map[string]any); ok {
|
||||
lat, latOk := geoMap["latitude"].(float64)
|
||||
lng, lngOk := geoMap["longitude"].(float64)
|
||||
if latOk && lngOk {
|
||||
@@ -105,9 +105,9 @@ func JSONToFirestoreValue(value interface{}, client *firestore.Client) (interfac
|
||||
return nil, fmt.Errorf("invalid geopoint value format")
|
||||
case "arrayValue":
|
||||
// Convert array
|
||||
if arrayMap, ok := val.(map[string]interface{}); ok {
|
||||
if values, ok := arrayMap["values"].([]interface{}); ok {
|
||||
result := make([]interface{}, len(values))
|
||||
if arrayMap, ok := val.(map[string]any); ok {
|
||||
if values, ok := arrayMap["values"].([]any); ok {
|
||||
result := make([]any, len(values))
|
||||
for i, item := range values {
|
||||
converted, err := JSONToFirestoreValue(item, client)
|
||||
if err != nil {
|
||||
@@ -121,9 +121,9 @@ func JSONToFirestoreValue(value interface{}, client *firestore.Client) (interfac
|
||||
return nil, fmt.Errorf("invalid array value format")
|
||||
case "mapValue":
|
||||
// Convert map
|
||||
if mapMap, ok := val.(map[string]interface{}); ok {
|
||||
if fields, ok := mapMap["fields"].(map[string]interface{}); ok {
|
||||
result := make(map[string]interface{})
|
||||
if mapMap, ok := val.(map[string]any); ok {
|
||||
if fields, ok := mapMap["fields"].(map[string]any); ok {
|
||||
result := make(map[string]any)
|
||||
for k, v := range fields {
|
||||
converted, err := JSONToFirestoreValue(v, client)
|
||||
if err != nil {
|
||||
@@ -160,8 +160,8 @@ func JSONToFirestoreValue(value interface{}, client *firestore.Client) (interfac
|
||||
}
|
||||
|
||||
// convertPlainMap converts a plain map to Firestore format
|
||||
func convertPlainMap(m map[string]interface{}, client *firestore.Client) (map[string]interface{}, error) {
|
||||
result := make(map[string]interface{})
|
||||
func convertPlainMap(m map[string]any, client *firestore.Client) (map[string]any, error) {
|
||||
result := make(map[string]any)
|
||||
for k, v := range m {
|
||||
converted, err := JSONToFirestoreValue(v, client)
|
||||
if err != nil {
|
||||
@@ -172,42 +172,6 @@ func convertPlainMap(m map[string]interface{}, client *firestore.Client) (map[st
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FirestoreValueToJSON converts a Firestore value to a simplified JSON representation
|
||||
// This removes type information and returns plain values
|
||||
func FirestoreValueToJSON(value interface{}) interface{} {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch v := value.(type) {
|
||||
case time.Time:
|
||||
return v.Format(time.RFC3339Nano)
|
||||
case *latlng.LatLng:
|
||||
return map[string]interface{}{
|
||||
"latitude": v.Latitude,
|
||||
"longitude": v.Longitude,
|
||||
}
|
||||
case []byte:
|
||||
return base64.StdEncoding.EncodeToString(v)
|
||||
case []interface{}:
|
||||
result := make([]interface{}, len(v))
|
||||
for i, item := range v {
|
||||
result[i] = FirestoreValueToJSON(item)
|
||||
}
|
||||
return result
|
||||
case map[string]interface{}:
|
||||
result := make(map[string]interface{})
|
||||
for k, val := range v {
|
||||
result[k] = FirestoreValueToJSON(val)
|
||||
}
|
||||
return result
|
||||
case *firestore.DocumentRef:
|
||||
return v.Path
|
||||
default:
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
// isValidDocumentPath checks if a string is a valid Firestore document path
|
||||
// Valid paths have an even number of segments (collection/doc/collection/doc...)
|
||||
func isValidDocumentPath(path string) bool {
|
||||
|
||||
@@ -312,40 +312,6 @@ func TestJSONToFirestoreValue_IntegerFromString(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestFirestoreValueToJSON_RoundTrip(t *testing.T) {
|
||||
// Test round-trip conversion
|
||||
original := map[string]interface{}{
|
||||
"name": "Test",
|
||||
"count": int64(42),
|
||||
"price": 19.99,
|
||||
"active": true,
|
||||
"tags": []interface{}{"tag1", "tag2"},
|
||||
"metadata": map[string]interface{}{
|
||||
"created": time.Now(),
|
||||
},
|
||||
"nullField": nil,
|
||||
}
|
||||
|
||||
// Convert to JSON representation
|
||||
jsonRepresentation := FirestoreValueToJSON(original)
|
||||
|
||||
// Verify types are simplified
|
||||
jsonMap, ok := jsonRepresentation.(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("Expected map, got %T", jsonRepresentation)
|
||||
}
|
||||
|
||||
// Time should be converted to string
|
||||
metadata, ok := jsonMap["metadata"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("metadata should be a map, got %T", jsonMap["metadata"])
|
||||
}
|
||||
_, ok = metadata["created"].(string)
|
||||
if !ok {
|
||||
t.Errorf("created should be a string, got %T", metadata["created"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONToFirestoreValue_InvalidFormats(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -16,9 +16,7 @@ package http
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"slices"
|
||||
@@ -54,7 +52,7 @@ type compatibleSource interface {
|
||||
HttpDefaultHeaders() map[string]string
|
||||
HttpBaseURL() string
|
||||
HttpQueryParams() map[string]string
|
||||
Client() *http.Client
|
||||
RunRequest(*http.Request) (any, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -259,29 +257,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
for k, v := range allHeaders {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
// Make request and fetch response
|
||||
resp, err := source.Client().Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error making HTTP request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var body []byte
|
||||
body, err = io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode > 299 {
|
||||
return nil, fmt.Errorf("unexpected status code: %d, response body: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var data any
|
||||
if err = json.Unmarshal(body, &data); err != nil {
|
||||
// if unable to unmarshal data, return result as string.
|
||||
return string(body), nil
|
||||
}
|
||||
return data, nil
|
||||
return source.RunRequest(req)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -48,8 +48,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
GetAuthTokenHeaderName() string
|
||||
LookerClient() *v4.LookerSDK
|
||||
LookerApiSettings() *rtl.ApiSettings
|
||||
GetLookerSDK(string) (*v4.LookerSDK, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -159,7 +159,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
visConfig := paramsMap["vis_config"].(map[string]any)
|
||||
wq.VisConfig = &visConfig
|
||||
|
||||
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
|
||||
sdk, err := source.GetLookerSDK(string(accessToken))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting sdk: %w", err)
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
|
||||
@@ -48,8 +47,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
GetAuthTokenHeaderName() string
|
||||
LookerClient() *v4.LookerSDK
|
||||
LookerApiSettings() *rtl.ApiSettings
|
||||
GetLookerSDK(string) (*v4.LookerSDK, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -192,7 +191,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
req.Dimension = &dimension
|
||||
}
|
||||
|
||||
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
|
||||
sdk, err := source.GetLookerSDK(string(accessToken))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting sdk: %w", err)
|
||||
}
|
||||
|
||||
@@ -15,65 +15,17 @@ package lookercommon
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
rtl "github.com/looker-open-source/sdk-codegen/go/rtl"
|
||||
"github.com/looker-open-source/sdk-codegen/go/rtl"
|
||||
v4 "github.com/looker-open-source/sdk-codegen/go/sdk/v4"
|
||||
"github.com/thlib/go-timezone-local/tzlocal"
|
||||
)
|
||||
|
||||
// Make types for RoundTripper
|
||||
type transportWithAuthHeader struct {
|
||||
Base http.RoundTripper
|
||||
AuthToken tools.AccessToken
|
||||
}
|
||||
|
||||
func (t *transportWithAuthHeader) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
req.Header.Set("x-looker-appid", "go-sdk")
|
||||
req.Header.Set("Authorization", string(t.AuthToken))
|
||||
return t.Base.RoundTrip(req)
|
||||
}
|
||||
|
||||
func GetLookerSDK(useClientOAuth bool, config *rtl.ApiSettings, client *v4.LookerSDK, accessToken tools.AccessToken) (*v4.LookerSDK, error) {
|
||||
|
||||
if useClientOAuth {
|
||||
if accessToken == "" {
|
||||
return nil, fmt.Errorf("no access token supplied with request")
|
||||
}
|
||||
|
||||
session := rtl.NewAuthSession(*config)
|
||||
// Configure base transport with TLS
|
||||
transport := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: !config.VerifySsl,
|
||||
},
|
||||
}
|
||||
|
||||
// Build transport for end user token
|
||||
session.Client = http.Client{
|
||||
Transport: &transportWithAuthHeader{
|
||||
Base: transport,
|
||||
AuthToken: accessToken,
|
||||
},
|
||||
}
|
||||
|
||||
// return SDK with new Transport
|
||||
return v4.NewLookerSDK(session), nil
|
||||
}
|
||||
|
||||
if client == nil {
|
||||
return nil, fmt.Errorf("client id or client secret not valid")
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
const (
|
||||
DimensionsFields = "fields(dimensions(name,type,label,label_short,description,synonyms,tags,hidden,suggestable,suggestions,suggest_dimension,suggest_explore))"
|
||||
FiltersFields = "fields(filters(name,type,label,label_short,description,synonyms,tags,hidden,suggestable,suggestions,suggest_dimension,suggest_explore))"
|
||||
|
||||
@@ -47,8 +47,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
GetAuthTokenHeaderName() string
|
||||
LookerClient() *v4.LookerSDK
|
||||
LookerApiSettings() *rtl.ApiSettings
|
||||
GetLookerSDK(string) (*v4.LookerSDK, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -116,7 +116,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
|
||||
sdk, err := source.GetLookerSDK(string(accessToken))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting sdk: %w", err)
|
||||
}
|
||||
|
||||
@@ -47,8 +47,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
GetAuthTokenHeaderName() string
|
||||
LookerClient() *v4.LookerSDK
|
||||
LookerApiSettings() *rtl.ApiSettings
|
||||
GetLookerSDK(string) (*v4.LookerSDK, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -117,7 +117,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
|
||||
sdk, err := source.GetLookerSDK(string(accessToken))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting sdk: %w", err)
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
|
||||
@@ -48,8 +47,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
GetAuthTokenHeaderName() string
|
||||
LookerClient() *v4.LookerSDK
|
||||
LookerApiSettings() *rtl.ApiSettings
|
||||
GetLookerSDK(string) (*v4.LookerSDK, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -125,7 +124,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("'devMode' must be a boolean, got %T", mapParams["devMode"])
|
||||
}
|
||||
|
||||
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
|
||||
sdk, err := source.GetLookerSDK(string(accessToken))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting sdk: %w", err)
|
||||
}
|
||||
|
||||
@@ -22,7 +22,6 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
|
||||
@@ -49,8 +48,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
GetAuthTokenHeaderName() string
|
||||
LookerClient() *v4.LookerSDK
|
||||
LookerApiSettings() *rtl.ApiSettings
|
||||
GetLookerSDK(string) (*v4.LookerSDK, error)
|
||||
LookerSessionLength() int64
|
||||
}
|
||||
|
||||
@@ -137,7 +136,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
contentId_ptr = nil
|
||||
}
|
||||
|
||||
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
|
||||
sdk, err := source.GetLookerSDK(string(accessToken))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting sdk: %w", err)
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
|
||||
"github.com/looker-open-source/sdk-codegen/go/rtl"
|
||||
@@ -47,8 +46,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
GetAuthTokenHeaderName() string
|
||||
LookerClient() *v4.LookerSDK
|
||||
LookerApiSettings() *rtl.ApiSettings
|
||||
GetLookerSDK(string) (*v4.LookerSDK, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -120,7 +119,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("'conn' must be a string, got %T", mapParams["conn"])
|
||||
}
|
||||
|
||||
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
|
||||
sdk, err := source.GetLookerSDK(string(accessToken))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting sdk: %w", err)
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
|
||||
@@ -48,8 +47,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
GetAuthTokenHeaderName() string
|
||||
LookerClient() *v4.LookerSDK
|
||||
LookerApiSettings() *rtl.ApiSettings
|
||||
GetLookerSDK(string) (*v4.LookerSDK, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -119,7 +118,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("unable to get logger from ctx: %s", err)
|
||||
}
|
||||
|
||||
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
|
||||
sdk, err := source.GetLookerSDK(string(accessToken))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting sdk: %w", err)
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
|
||||
"github.com/looker-open-source/sdk-codegen/go/rtl"
|
||||
@@ -47,8 +46,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
GetAuthTokenHeaderName() string
|
||||
LookerClient() *v4.LookerSDK
|
||||
LookerApiSettings() *rtl.ApiSettings
|
||||
GetLookerSDK(string) (*v4.LookerSDK, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -122,7 +121,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
db, _ := mapParams["db"].(string)
|
||||
|
||||
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
|
||||
sdk, err := source.GetLookerSDK(string(accessToken))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting sdk: %w", err)
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
|
||||
@@ -48,8 +47,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
GetAuthTokenHeaderName() string
|
||||
LookerClient() *v4.LookerSDK
|
||||
LookerApiSettings() *rtl.ApiSettings
|
||||
GetLookerSDK(string) (*v4.LookerSDK, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -137,7 +136,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("'tables' must be a string, got %T", mapParams["tables"])
|
||||
}
|
||||
|
||||
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
|
||||
sdk, err := source.GetLookerSDK(string(accessToken))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting sdk: %w", err)
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
|
||||
@@ -48,8 +47,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
GetAuthTokenHeaderName() string
|
||||
LookerClient() *v4.LookerSDK
|
||||
LookerApiSettings() *rtl.ApiSettings
|
||||
GetLookerSDK(string) (*v4.LookerSDK, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -132,7 +131,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("'schema' must be a string, got %T", mapParams["schema"])
|
||||
}
|
||||
|
||||
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
|
||||
sdk, err := source.GetLookerSDK(string(accessToken))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting sdk: %w", err)
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
|
||||
@@ -48,8 +47,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
GetAuthTokenHeaderName() string
|
||||
LookerClient() *v4.LookerSDK
|
||||
LookerApiSettings() *rtl.ApiSettings
|
||||
GetLookerSDK(string) (*v4.LookerSDK, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -141,7 +140,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
limit := int64(paramsMap["limit"].(int))
|
||||
offset := int64(paramsMap["offset"].(int))
|
||||
|
||||
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
|
||||
sdk, err := source.GetLookerSDK(string(accessToken))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting sdk: %w", err)
|
||||
}
|
||||
|
||||
@@ -48,8 +48,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
GetAuthTokenHeaderName() string
|
||||
LookerClient() *v4.LookerSDK
|
||||
LookerApiSettings() *rtl.ApiSettings
|
||||
GetLookerSDK(string) (*v4.LookerSDK, error)
|
||||
LookerShowHiddenFields() bool
|
||||
}
|
||||
|
||||
@@ -124,7 +124,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("error processing model or explore: %w", err)
|
||||
}
|
||||
|
||||
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
|
||||
sdk, err := source.GetLookerSDK(string(accessToken))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting sdk: %w", err)
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
|
||||
@@ -48,8 +47,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
GetAuthTokenHeaderName() string
|
||||
LookerClient() *v4.LookerSDK
|
||||
LookerApiSettings() *rtl.ApiSettings
|
||||
GetLookerSDK(string) (*v4.LookerSDK, error)
|
||||
LookerShowHiddenExplores() bool
|
||||
}
|
||||
|
||||
@@ -126,7 +125,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("'model' must be a string, got %T", mapParams["model"])
|
||||
}
|
||||
|
||||
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
|
||||
sdk, err := source.GetLookerSDK(string(accessToken))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting sdk: %w", err)
|
||||
}
|
||||
|
||||
@@ -48,8 +48,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
GetAuthTokenHeaderName() string
|
||||
LookerClient() *v4.LookerSDK
|
||||
LookerApiSettings() *rtl.ApiSettings
|
||||
GetLookerSDK(string) (*v4.LookerSDK, error)
|
||||
LookerShowHiddenFields() bool
|
||||
}
|
||||
|
||||
@@ -125,7 +125,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
|
||||
fields := lookercommon.FiltersFields
|
||||
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
|
||||
sdk, err := source.GetLookerSDK(string(accessToken))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting sdk: %w", err)
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
|
||||
@@ -48,8 +47,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
GetAuthTokenHeaderName() string
|
||||
LookerClient() *v4.LookerSDK
|
||||
LookerApiSettings() *rtl.ApiSettings
|
||||
GetLookerSDK(string) (*v4.LookerSDK, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -141,7 +140,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
limit := int64(paramsMap["limit"].(int))
|
||||
offset := int64(paramsMap["offset"].(int))
|
||||
|
||||
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
|
||||
sdk, err := source.GetLookerSDK(string(accessToken))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting sdk: %w", err)
|
||||
}
|
||||
|
||||
@@ -48,8 +48,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
GetAuthTokenHeaderName() string
|
||||
LookerClient() *v4.LookerSDK
|
||||
LookerApiSettings() *rtl.ApiSettings
|
||||
GetLookerSDK(string) (*v4.LookerSDK, error)
|
||||
LookerShowHiddenFields() bool
|
||||
}
|
||||
|
||||
@@ -125,7 +125,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
|
||||
fields := lookercommon.MeasuresFields
|
||||
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
|
||||
sdk, err := source.GetLookerSDK(string(accessToken))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting sdk: %w", err)
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
|
||||
@@ -48,8 +47,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
GetAuthTokenHeaderName() string
|
||||
LookerClient() *v4.LookerSDK
|
||||
LookerApiSettings() *rtl.ApiSettings
|
||||
GetLookerSDK(string) (*v4.LookerSDK, error)
|
||||
LookerShowHiddenModels() bool
|
||||
}
|
||||
|
||||
@@ -124,7 +123,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
excludeHidden := !source.LookerShowHiddenModels()
|
||||
includeInternal := true
|
||||
|
||||
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
|
||||
sdk, err := source.GetLookerSDK(string(accessToken))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting sdk: %w", err)
|
||||
}
|
||||
|
||||
@@ -48,8 +48,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
GetAuthTokenHeaderName() string
|
||||
LookerClient() *v4.LookerSDK
|
||||
LookerApiSettings() *rtl.ApiSettings
|
||||
GetLookerSDK(string) (*v4.LookerSDK, error)
|
||||
LookerShowHiddenFields() bool
|
||||
}
|
||||
|
||||
@@ -125,7 +125,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
|
||||
fields := lookercommon.ParametersFields
|
||||
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
|
||||
sdk, err := source.GetLookerSDK(string(accessToken))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting sdk: %w", err)
|
||||
}
|
||||
|
||||
@@ -48,8 +48,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
GetAuthTokenHeaderName() string
|
||||
LookerClient() *v4.LookerSDK
|
||||
LookerApiSettings() *rtl.ApiSettings
|
||||
GetLookerSDK(string) (*v4.LookerSDK, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -121,7 +121,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("unable to get logger from ctx: %s", err)
|
||||
}
|
||||
|
||||
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
|
||||
sdk, err := source.GetLookerSDK(string(accessToken))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting sdk: %w", err)
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
|
||||
@@ -48,8 +47,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
GetAuthTokenHeaderName() string
|
||||
LookerClient() *v4.LookerSDK
|
||||
LookerApiSettings() *rtl.ApiSettings
|
||||
GetLookerSDK(string) (*v4.LookerSDK, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -120,7 +119,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("unable to get logger from ctx: %s", err)
|
||||
}
|
||||
|
||||
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
|
||||
sdk, err := source.GetLookerSDK(string(accessToken))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting sdk: %w", err)
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
|
||||
@@ -48,8 +47,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
type compatibleSource interface {
|
||||
UseClientAuthorization() bool
|
||||
GetAuthTokenHeaderName() string
|
||||
LookerClient() *v4.LookerSDK
|
||||
LookerApiSettings() *rtl.ApiSettings
|
||||
GetLookerSDK(string) (*v4.LookerSDK, error)
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -119,7 +118,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
return nil, fmt.Errorf("unable to get logger from ctx: %s", err)
|
||||
}
|
||||
|
||||
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
|
||||
sdk, err := source.GetLookerSDK(string(accessToken))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting sdk: %w", err)
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user