Mirror of https://github.com/googleapis/genai-toolbox.git
Synced 2026-01-17 03:18:00 -05:00

Comparing 19 commits: spanner-cr...dgraph-doc
Commits:

- 1a989bab10
- 00c3e6d8cb
- d00b6fdf18
- 4d23a3bbf2
- 5e0999ebf5
- 6b02591703
- 8e0fb03483
- 2817dd1e5d
- 68a218407e
- d69792d843
- 647b04d3a7
- 030df9766f
- 5dbf207162
- 9c3720e31d
- 3cd3c39d66
- 0691a6f715
- 467b96a23b
- 4abf0c39e7
- dd7b9de623
@@ -59,6 +59,13 @@ You can manually trigger the bot by commenting on your Pull Request:
 * `/gemini summary`: Posts a summary of the changes in the pull request.
 * `/gemini help`: Overview of the available commands
 
+## Guidelines for Pull Requests
+
+1. Please keep your PR small for more thorough review and easier updates. In case of regression, it also allows us to roll back a single feature instead of multiple ones.
+1. For non-trivial changes, consider opening an issue and discussing it with the code owners first.
+1. Provide a good PR description as a record of what change is being made and why it was made. Link to a GitHub issue if it exists.
+1. Make sure your code is thoroughly tested with unit tests and integration tests. Remember to clean up the test instances properly in your code to avoid memory leaks.
+
 ## Adding a New Database Source or Tool
 
 Please create an
@@ -110,6 +117,8 @@ implementation](https://github.com/googleapis/genai-toolbox/blob/main/internal/s
 We recommend looking at an [example tool
 implementation](https://github.com/googleapis/genai-toolbox/tree/main/internal/tools/postgres/postgressql).
 
+Remember to keep your PRs small. For example, if you are contributing a new Source, only include one or two core Tools within the same PR, the rest of the Tools can come in subsequent PRs.
+
 * **Create a new directory** under `internal/tools` for your tool type (e.g., `internal/tools/newdb/newdbtool`).
 * **Define a configuration struct** for your tool in a file named `newdbtool.go`.
 Create a `Config` struct and a `Tool` struct to store necessary parameters for
@@ -163,6 +172,8 @@ tools.
 parameters][temp-param-doc]. Only run this test if template
 parameters apply to your tool.
 
+* **Add additional tests** for the tools that are not covered by the predefined tests. Every tool must be tested!
+
 * **Add the new database to the integration test workflow** in
 [integration.cloudbuild.yaml](.ci/integration.cloudbuild.yaml).
 
@@ -244,4 +255,4 @@ resources.
 * **PR Description:** PR description should **always** be included. It should
 include a concise description of the changes, it's impact, along with a
 summary of the solution. If the PR is related to a specific issue, the issue
 number should be mentioned in the PR description (e.g. `Fixes #1`).
DEVELOPER.md (17 changes)
@@ -379,6 +379,23 @@ to approve PRs for main. TeamSync is used to create this team from the MDB
 Group `toolbox-contributors`. Googlers who are developing for MCP-Toolbox
 but aren't part of the core team should join this group.
 
+### Issue/PR Triage and SLO
+After an issue is created, maintainers will assign the following labels:
+* `Priority` (defaulted to P0)
+* `Type` (if applicable)
+* `Product` (if applicable)
+
+All incoming issues and PRs will follow the following SLO:
+| Type            | Priority | Objective                                                              |
+|-----------------|----------|------------------------------------------------------------------------|
+| Feature Request | P0       | Must respond within **5 days**                                         |
+| Process         | P0       | Must respond within **5 days**                                         |
+| Bugs            | P0       | Must respond within **5 days**, and resolve/closure within **14 days** |
+| Bugs            | P1       | Must respond within **7 days**, and resolve/closure within **90 days** |
+| Bugs            | P2       | Must respond within **30 days**                                        |
+
+_Types that are not listed in the table do not adhere to any SLO._
+
 ### Releasing
 
 Toolbox has two types of releases: versioned and continuous. It uses Google
@@ -272,7 +272,7 @@ To run Toolbox from binary:
 To run the server after pulling the [container image](#installing-the-server):
 
 ```sh
-export VERSION=0.11.0 # Use the version you pulled
+export VERSION=0.24.0 # Use the version you pulled
 docker run -p 5000:5000 \
 -v $(pwd)/tools.yaml:/app/tools.yaml \
 us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:$VERSION \
@@ -92,11 +92,13 @@ import (
 	_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaresearchdicomstudies"
 	_ "github.com/googleapis/genai-toolbox/internal/tools/cloudmonitoring"
 	_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcloneinstance"
+	_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcreatebackup"
 	_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcreatedatabase"
 	_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcreateusers"
 	_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlgetinstances"
 	_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqllistdatabases"
 	_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqllistinstances"
+	_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlrestorebackup"
 	_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlwaitforoperation"
 	_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsqlmssql/cloudsqlmssqlcreateinstance"
 	_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsqlmysql/cloudsqlmysqlcreateinstance"
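The blank (`_`) imports above exist only for their side effects: each tool package registers its kind in an `init()` function. The sketch below illustrates that registration pattern in a self-contained way; the registry shape and constructor signatures here are illustrative, not the toolbox's actual internal API.

```go
package main

import "fmt"

// registry maps a tool "kind" string to a constructor. In the real toolbox
// each tool package contributes to a registry like this from its own init().
var registry = map[string]func() string{}

// A blank import (`_ "path/to/pkg"`) compiles the package and runs its
// init(), which is all the import list above needs from each tool package.
// Here we inline two hypothetical registrations to show the mechanism.
func init() {
	registry["cloud-sql-create-backup"] = func() string { return "create_backup" }
	registry["cloud-sql-restore-backup"] = func() string { return "restore_backup" }
}

func main() {
	// Look up tools by kind, as a config loader would after all inits ran.
	for _, kind := range []string{"cloud-sql-create-backup", "cloud-sql-restore-backup"} {
		fmt.Println(registry[kind]())
	}
}
```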
@@ -1493,7 +1493,7 @@ func TestPrebuiltTools(t *testing.T) {
 wantToolset: server.ToolsetConfigs{
 "cloud_sql_postgres_admin_tools": tools.ToolsetConfig{
 Name: "cloud_sql_postgres_admin_tools",
-ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "postgres_upgrade_precheck", "clone_instance"},
+ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "postgres_upgrade_precheck", "clone_instance", "create_backup", "restore_backup"},
 },
 },
 },
@@ -1503,7 +1503,7 @@ func TestPrebuiltTools(t *testing.T) {
 wantToolset: server.ToolsetConfigs{
 "cloud_sql_mysql_admin_tools": tools.ToolsetConfig{
 Name: "cloud_sql_mysql_admin_tools",
-ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance"},
+ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance", "create_backup", "restore_backup"},
 },
 },
 },
@@ -1513,7 +1513,7 @@ func TestPrebuiltTools(t *testing.T) {
 wantToolset: server.ToolsetConfigs{
 "cloud_sql_mssql_admin_tools": tools.ToolsetConfig{
 Name: "cloud_sql_mssql_admin_tools",
-ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance"},
+ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance", "create_backup", "restore_backup"},
 },
 },
 },
@@ -48,11 +48,13 @@ instance, database and users:
 * `roles/cloudsql.editor`: Provides permissions to manage existing resources.
     * All `viewer` tools
     * `create_database`
+    * `create_backup`
 * `roles/cloudsql.admin`: Provides full control over all resources.
     * All `editor` and `viewer` tools
     * `create_instance`
     * `create_user`
     * `clone_instance`
+    * `restore_backup`
 
 ## Install MCP Toolbox
 
@@ -299,6 +301,8 @@ instances and interacting with your database:
 * **create_user**: Creates a new user in a Cloud SQL instance.
 * **wait_for_operation**: Waits for a Cloud SQL operation to complete.
 * **clone_instance**: Creates a clone of an existing Cloud SQL for SQL Server instance.
+* **create_backup**: Creates a backup on a Cloud SQL instance.
+* **restore_backup**: Restores a backup of a Cloud SQL instance.
 
 {{< notice note >}}
 Prebuilt tools are pre-1.0, so expect some tool changes between versions. LLMs
@@ -48,11 +48,13 @@ database and users:
 * `roles/cloudsql.editor`: Provides permissions to manage existing resources.
     * All `viewer` tools
     * `create_database`
+    * `create_backup`
 * `roles/cloudsql.admin`: Provides full control over all resources.
     * All `editor` and `viewer` tools
     * `create_instance`
     * `create_user`
     * `clone_instance`
+    * `restore_backup`
 
 ## Install MCP Toolbox
 
@@ -299,6 +301,8 @@ instances and interacting with your database:
 * **create_user**: Creates a new user in a Cloud SQL instance.
 * **wait_for_operation**: Waits for a Cloud SQL operation to complete.
 * **clone_instance**: Creates a clone of an existing Cloud SQL for MySQL instance.
+* **create_backup**: Creates a backup on a Cloud SQL instance.
+* **restore_backup**: Restores a backup of a Cloud SQL instance.
 
 {{< notice note >}}
 Prebuilt tools are pre-1.0, so expect some tool changes between versions. LLMs
@@ -48,11 +48,13 @@ instance, database and users:
 * `roles/cloudsql.editor`: Provides permissions to manage existing resources.
     * All `viewer` tools
     * `create_database`
+    * `create_backup`
 * `roles/cloudsql.admin`: Provides full control over all resources.
     * All `editor` and `viewer` tools
     * `create_instance`
     * `create_user`
     * `clone_instance`
+    * `restore_backup`
 
 ## Install MCP Toolbox
 
@@ -299,6 +301,8 @@ instances and interacting with your database:
 * **create_user**: Creates a new user in a Cloud SQL instance.
 * **wait_for_operation**: Waits for a Cloud SQL operation to complete.
 * **clone_instance**: Creates a clone of an existing Cloud SQL for PostgreSQL instance.
+* **create_backup**: Creates a backup on a Cloud SQL instance.
+* **restore_backup**: Restores a backup of a Cloud SQL instance.
 
 {{< notice note >}}
 Prebuilt tools are pre-1.0, so expect some tool changes between versions. LLMs
@@ -20,6 +20,7 @@ The native SDKs can be combined with MCP clients in many cases.
 
 Toolbox currently supports the following versions of MCP specification:
 
+* [2025-11-25](https://modelcontextprotocol.io/specification/2025-11-25)
 * [2025-06-18](https://modelcontextprotocol.io/specification/2025-06-18)
 * [2025-03-26](https://modelcontextprotocol.io/specification/2025-03-26)
 * [2024-11-05](https://modelcontextprotocol.io/specification/2024-11-05)
@@ -187,12 +187,14 @@ See [Usage Examples](../reference/cli.md#examples).
 manage existing resources.
     * All `viewer` tools
     * `create_database`
+    * `create_backup`
 * **Cloud SQL Admin** (`roles/cloudsql.admin`): Provides full control over
 all resources.
     * All `editor` and `viewer` tools
     * `create_instance`
     * `create_user`
     * `clone_instance`
+    * `restore_backup`
 
 * **Tools:**
     * `create_instance`: Creates a new Cloud SQL for MySQL instance.
@@ -203,6 +205,8 @@ See [Usage Examples](../reference/cli.md#examples).
 * `create_user`: Creates a new user in a Cloud SQL instance.
 * `wait_for_operation`: Waits for a Cloud SQL operation to complete.
 * `clone_instance`: Creates a clone for an existing Cloud SQL for MySQL instance.
+* `create_backup`: Creates a backup on a Cloud SQL instance.
+* `restore_backup`: Restores a backup of a Cloud SQL instance.
 
 ## Cloud SQL for PostgreSQL
 
@@ -275,12 +279,14 @@ See [Usage Examples](../reference/cli.md#examples).
 manage existing resources.
     * All `viewer` tools
     * `create_database`
+    * `create_backup`
 * **Cloud SQL Admin** (`roles/cloudsql.admin`): Provides full control over
 all resources.
     * All `editor` and `viewer` tools
     * `create_instance`
     * `create_user`
     * `clone_instance`
+    * `restore_backup`
 * **Tools:**
     * `create_instance`: Creates a new Cloud SQL for PostgreSQL instance.
     * `get_instance`: Gets information about a Cloud SQL instance.
@@ -290,6 +296,8 @@ See [Usage Examples](../reference/cli.md#examples).
 * `create_user`: Creates a new user in a Cloud SQL instance.
 * `wait_for_operation`: Waits for a Cloud SQL operation to complete.
 * `clone_instance`: Creates a clone for an existing Cloud SQL for PostgreSQL instance.
+* `create_backup`: Creates a backup on a Cloud SQL instance.
+* `restore_backup`: Restores a backup of a Cloud SQL instance.
 
 ## Cloud SQL for SQL Server
 
@@ -336,12 +344,14 @@ See [Usage Examples](../reference/cli.md#examples).
 manage existing resources.
     * All `viewer` tools
     * `create_database`
+    * `create_backup`
 * **Cloud SQL Admin** (`roles/cloudsql.admin`): Provides full control over
 all resources.
     * All `editor` and `viewer` tools
     * `create_instance`
     * `create_user`
     * `clone_instance`
+    * `restore_backup`
 * **Tools:**
     * `create_instance`: Creates a new Cloud SQL for SQL Server instance.
     * `get_instance`: Gets information about a Cloud SQL instance.
@@ -351,6 +361,8 @@ See [Usage Examples](../reference/cli.md#examples).
 * `create_user`: Creates a new user in a Cloud SQL instance.
 * `wait_for_operation`: Waits for a Cloud SQL operation to complete.
 * `clone_instance`: Creates a clone for an existing Cloud SQL for SQL Server instance.
+* `create_backup`: Creates a backup on a Cloud SQL instance.
+* `restore_backup`: Restores a backup of a Cloud SQL instance.
 
 ## Dataplex
 
@@ -134,6 +134,7 @@ sources:
 # scopes: # Optional: List of OAuth scopes to request.
 # - "https://www.googleapis.com/auth/bigquery"
 # - "https://www.googleapis.com/auth/drive.readonly"
+# maxQueryResultRows: 50 # Optional: Limits the number of rows returned by queries. Defaults to 50.
 ```
 
 Initialize a BigQuery source that uses the client's access token:
@@ -153,6 +154,7 @@ sources:
 # scopes: # Optional: List of OAuth scopes to request.
 # - "https://www.googleapis.com/auth/bigquery"
 # - "https://www.googleapis.com/auth/drive.readonly"
+# maxQueryResultRows: 50 # Optional: Limits the number of rows returned by queries. Defaults to 50.
 ```
 
 ## Reference
@@ -167,3 +169,4 @@ sources:
 | useClientOAuth | bool | false | If true, forwards the client's OAuth access token from the "Authorization" header to downstream queries. **Note:** This cannot be used with `writeMode: protected`. |
 | scopes | []string | false | A list of OAuth 2.0 scopes to use for the credentials. If not provided, default scopes are used. |
 | impersonateServiceAccount | string | false | Service account email to impersonate when making BigQuery and Dataplex API calls. The authenticated principal must have the `roles/iam.serviceAccountTokenCreator` role on the target service account. [Learn More](https://cloud.google.com/iam/docs/service-account-impersonation) |
+| maxQueryResultRows | int | false | The maximum number of rows to return from a query. Defaults to 50. |
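For reference, a sketch of the newly documented field in uncommented form; the source name and project ID below are placeholders, not values from the diff, and the surrounding fields follow the BigQuery source reference above.

```yaml
sources:
  my-bigquery-source:
    kind: bigquery
    project: my-project-id
    maxQueryResultRows: 100  # raise the cap from the default of 50
```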
@@ -7,6 +7,17 @@ description: >
 
 ---
 
+{{< notice note >}}
+**⚠️ Best Effort Maintenance**
+
+This integration is maintained on a best-effort basis by the project
+team/community. While we strive to address issues and provide workarounds when
+resources are available, there are no guaranteed response times or code fixes.
+
+The automated integration tests for this module are currently non-functional or
+failing.
+{{< /notice >}}
+
 ## About
 
 [Dgraph][dgraph-docs] is an open-source graph database. It is designed for
@@ -91,8 +91,8 @@ visible to the LLM.
 https://cloud.google.com/alloydb/docs/parameterized-secure-views-overview
 
 {{< notice tip >}} Make sure to enable the `parameterized_views` extension
-before running this tool. You can do so by running this command in the AlloyDB
-studio:
+to utilize PSV feature (`nlConfigParameters`) with this tool. You can do so by
+running this command in the AlloyDB studio:
 
 ```sql
 CREATE EXTENSION IF NOT EXISTS parameterized_views;
@@ -41,13 +41,13 @@ tools:
 
 ### Usage Flow
 
-When using this tool, a `prompt` parameter containing a natural language query is provided to the tool (typically by an agent). The tool then interacts with the Gemini Data Analytics API using the context defined in your configuration.
+When using this tool, a `query` parameter containing a natural language query is provided to the tool (typically by an agent). The tool then interacts with the Gemini Data Analytics API using the context defined in your configuration.
 
 The structure of the response depends on the `generationOptions` configured in your tool definition (e.g., enabling `generateQueryResult` will include the SQL query results).
 
 See [Data Analytics API REST documentation](https://clouddocs.devsite.corp.google.com/gemini/docs/conversational-analytics-api/reference/rest/v1alpha/projects.locations/queryData?rep_location=global) for details.
 
-**Example Input Prompt:**
+**Example Input Query:**
 
 ```text
 How many accounts who have region in Prague are eligible for loans? A3 contains the data of region.
docs/en/resources/tools/cloudsql/cloudsqlcreatebackup.md (new file, 45 lines)
@@ -0,0 +1,45 @@
+---
+title: cloud-sql-create-backup
+type: docs
+weight: 10
+description: "Creates a backup on a Cloud SQL instance."
+---
+
+The `cloud-sql-create-backup` tool creates an on-demand backup on a Cloud SQL instance using the Cloud SQL Admin API.
+
+{{< notice info >}}
+This tool uses a `source` of kind `cloud-sql-admin`.
+{{< /notice >}}
+
+## Examples
+
+Basic backup creation (current state)
+
+```yaml
+tools:
+  backup-creation-basic:
+    kind: cloud-sql-create-backup
+    source: cloud-sql-admin-source
+    description: "Creates a backup on the given Cloud SQL instance."
+```
+## Reference
+### Tool Configuration
+| **field**   | **type** | **required** | **description**                                  |
+| ----------- | :------: | :----------: | ------------------------------------------------ |
+| kind        | string   | true         | Must be "cloud-sql-create-backup".               |
+| source      | string   | true         | The name of the `cloud-sql-admin` source to use. |
+| description | string   | false        | A description of the tool.                       |
+
+### Tool Inputs
+
+| **parameter**      | **type** | **required** | **description**                                                                |
+| ------------------ | :------: | :----------: | ------------------------------------------------------------------------------ |
+| project            | string   | true         | The project ID.                                                                |
+| instance           | string   | true         | The name of the instance to take a backup on. Does not include the project ID. |
+| location           | string   | false        | (Optional) Location of the backup run.                                         |
+| backup_description | string   | false        | (Optional) The description of this backup run.                                 |
+
+## See Also
+- [Cloud SQL Admin API documentation](https://cloud.google.com/sql/docs/mysql/admin-api)
+- [Toolbox Cloud SQL tools documentation](../cloudsql)
+- [Cloud SQL Backup API documentation](https://cloud.google.com/sql/docs/mysql/backup-recovery/backups)
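Pairing the new tool with its `cloud-sql-admin` source, a minimal end-to-end `tools.yaml` might look like the sketch below; the shape of the source entry is an assumption based on the tool reference above, not taken verbatim from the diff.

```yaml
sources:
  cloud-sql-admin-source:
    kind: cloud-sql-admin
tools:
  backup-creation-basic:
    kind: cloud-sql-create-backup
    source: cloud-sql-admin-source
    description: "Creates a backup on the given Cloud SQL instance."
```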
docs/en/resources/tools/cloudsql/cloudsqlrestorebackup.md (new file, 53 lines)
@@ -0,0 +1,53 @@
---
title: cloud-sql-restore-backup
type: docs
weight: 10
description: "Restores a backup of a Cloud SQL instance."
---

The `cloud-sql-restore-backup` tool restores a backup on a Cloud SQL instance using the Cloud SQL Admin API.

{{< notice info >}}
This tool uses a `source` of kind `cloud-sql-admin`.
{{< /notice >}}

## Examples

Basic backup restore:

```yaml
tools:
  backup-restore-basic:
    kind: cloud-sql-restore-backup
    source: cloud-sql-admin-source
    description: "Restores a backup onto the given Cloud SQL instance."
```

## Reference

### Tool Configuration

| **field**   | **type** | **required** | **description**                                  |
| ----------- | :------: | :----------: | ------------------------------------------------ |
| kind        | string   | true         | Must be "cloud-sql-restore-backup".              |
| source      | string   | true         | The name of the `cloud-sql-admin` source to use. |
| description | string   | false        | A description of the tool.                       |

### Tool Inputs

| **parameter**   | **type** | **required** | **description**                                                              |
| --------------- | :------: | :----------: | ---------------------------------------------------------------------------- |
| target_project  | string   | true         | The project ID of the instance to restore the backup onto.                   |
| target_instance | string   | true         | The instance to restore the backup onto. Does not include the project ID.    |
| backup_id       | string   | true         | The identifier of the backup being restored.                                 |
| source_project  | string   | false        | (Optional) The project ID of the instance that the backup belongs to.        |
| source_instance | string   | false        | (Optional) Cloud SQL instance ID of the instance that the backup belongs to. |
## Usage Notes

- The `backup_id` field can be a BackupRun ID (which will be an int64), a backup name, or a BackupDR backup name.
- If the `backup_id` field contains a BackupRun ID (i.e. an int64), the optional fields `source_project` and `source_instance` must also be provided.
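As a sketch of the BackupRun ID case described above (all values hypothetical), an invocation would then carry the source fields as well:

```json
{
  "target_project": "my-project",
  "target_instance": "restored-instance",
  "backup_id": "1234567890123456789",
  "source_project": "my-project",
  "source_instance": "original-instance"
}
```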
## See Also

- [Cloud SQL Admin API documentation](https://cloud.google.com/sql/docs/mysql/admin-api)
- [Toolbox Cloud SQL tools documentation](../cloudsql)
- [Cloud SQL Restore API documentation](https://cloud.google.com/sql/docs/mysql/backup-recovery/restoring)
@@ -9,6 +9,17 @@ aliases:
   - /resources/tools/dgraph-dql
 ---
 
+{{< notice note >}}
+**⚠️ Best Effort Maintenance**
+
+This integration is maintained on a best-effort basis by the project
+team/community. While we strive to address issues and provide workarounds when
+resources are available, there are no guaranteed response times or code fixes.
+
+The automated integration tests for this module are currently non-functional or
+failing.
+{{< /notice >}}
+
 ## About
 
 A `dgraph-dql` tool executes a pre-defined DQL statement against a Dgraph
7
docs/en/samples/neo4j/_index.md
Normal file
@@ -0,0 +1,7 @@
---
title: "Neo4j"
type: docs
weight: 1
description: >
  How to get started with Toolbox using Neo4j.
---
141
docs/en/samples/neo4j/mcp_quickstart.md
Normal file
@@ -0,0 +1,141 @@
---
title: "Quickstart (MCP with Neo4j)"
type: docs
weight: 1
description: >
  How to get started running Toolbox with MCP Inspector and Neo4j as the source.
---

## Overview

[Model Context Protocol](https://modelcontextprotocol.io) is an open protocol that standardizes how applications provide context to LLMs. Check out this page on how to [connect to Toolbox via MCP](../../how-to/connect_via_mcp.md).

## Step 1: Set up your Neo4j Database and Data

In this section, you'll set up a database and populate it with sample data for a movies-related agent. This guide assumes you have a running Neo4j instance, either locally or in the cloud.

1. **Populate the database with data.**
   To make this quickstart straightforward, we'll use the built-in Movies dataset available in Neo4j.

1. In your Neo4j Browser, run the following command to create and populate the database:

    ```cypher
    :play movies
    ```

1. Follow the instructions to load the data. This will create a graph with `Movie`, `Person`, and `Actor` nodes and their relationships.

## Step 2: Install and configure Toolbox

In this section, we will install the MCP Toolbox, configure our tools in a `tools.yaml` file, and then run the Toolbox server.

1. **Install the Toolbox binary.**
   The simplest way to get started is to download the latest binary for your operating system.

1. Download the latest version of Toolbox as a binary:

    ```bash
    export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
    curl -O https://storage.googleapis.com/genai-toolbox/v0.16.0/$OS/toolbox
    ```

1. Make the binary executable:

    ```bash
    chmod +x toolbox
    ```

1. **Create the `tools.yaml` file.**
   This file defines your Neo4j source and the specific tools that will be exposed to your AI agent.

    {{< notice tip >}}
    Authentication for the Neo4j source uses standard username and password fields. For production use, it is highly recommended to use environment variables for sensitive information like passwords.
    {{< /notice >}}

    Write the following into a `tools.yaml` file:

    ```yaml
    sources:
      my-neo4j-source:
        kind: neo4j
        uri: bolt://localhost:7687
        user: neo4j
        password: my-password # Replace with your actual password

    tools:
      search-movies-by-actor:
        kind: neo4j-cypher
        source: my-neo4j-source
        description: "Searches for movies an actor has appeared in based on their name. Useful for questions like 'What movies has Tom Hanks been in?'"
        parameters:
          - name: actor_name
            type: string
            description: The full name of the actor to search for.
        statement: |
          MATCH (p:Person {name: $actor_name})-[:ACTED_IN]->(m:Movie)
          RETURN m.title AS title, m.year AS year, m.genre AS genre

      get-actor-for-movie:
        kind: neo4j-cypher
        source: my-neo4j-source
        description: "Finds the actors who starred in a specific movie. Useful for questions like 'Who acted in Inception?'"
        parameters:
          - name: movie_title
            type: string
            description: The exact title of the movie.
        statement: |
          MATCH (p:Person)-[:ACTED_IN]->(m:Movie {title: $movie_title})
          RETURN p.name AS actor
    ```

1. **Start the Toolbox server.**
   Run the Toolbox server, pointing to the `tools.yaml` file you created earlier.

    ```bash
    ./toolbox --tools-file "tools.yaml"
    ```

## Step 3: Connect to MCP Inspector

1. **Run the MCP Inspector:**

    ```bash
    npx @modelcontextprotocol/inspector
    ```

1. Type `y` when it asks to install the inspector package.

1. It should show the following when the MCP Inspector is up and running (please take note of `<YOUR_SESSION_TOKEN>`):

    ```bash
    Starting MCP inspector...
    ⚙️ Proxy server listening on localhost:6277
    🔑 Session token: <YOUR_SESSION_TOKEN>
    Use this token to authenticate requests or set DANGEROUSLY_OMIT_AUTH=true to disable auth

    🚀 MCP Inspector is up and running at:
    http://localhost:6274/?MCP_PROXY_AUTH_TOKEN=<YOUR_SESSION_TOKEN>
    ```

1. Open the above link in your browser.

1. For `Transport Type`, select `Streamable HTTP`.

1. For `URL`, type in `http://127.0.0.1:5000/mcp`.

1. For `Configuration` -> `Proxy Session Token`, make sure `<YOUR_SESSION_TOKEN>` is present.

1. Click `Connect`.

1. Select `List Tools`, you will see a list of tools configured in `tools.yaml`.

1. Test out your tools here!
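Under the hood, the Inspector and Toolbox exchange JSON-RPC 2.0 messages over the Streamable HTTP transport. As a rough sketch (request shape per the MCP spec; the tool name and argument come from the `tools.yaml` in this guide), invoking a tool corresponds to a `tools/call` request like this:

```python
import json

# Build a JSON-RPC 2.0 "tools/call" request for the search-movies-by-actor tool.
request = {
    "jsonrpc": "2.0",
    "id": 1,
    "method": "tools/call",
    "params": {
        "name": "search-movies-by-actor",
        "arguments": {"actor_name": "Tom Hanks"},
    },
}

# This is the body an MCP client would POST to the /mcp endpoint.
body = json.dumps(request)
print(body)
```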
@@ -19,6 +19,7 @@ sources:
     location: ${BIGQUERY_LOCATION:}
     useClientOAuth: ${BIGQUERY_USE_CLIENT_OAUTH:false}
     scopes: ${BIGQUERY_SCOPES:}
+    maxQueryResultRows: ${BIGQUERY_MAX_QUERY_RESULT_ROWS:50}
 
 tools:
   analyze_contribution:
@@ -43,6 +43,12 @@ tools:
   clone_instance:
     kind: cloud-sql-clone-instance
     source: cloud-sql-admin-source
+  create_backup:
+    kind: cloud-sql-create-backup
+    source: cloud-sql-admin-source
+  restore_backup:
+    kind: cloud-sql-restore-backup
+    source: cloud-sql-admin-source
 
 toolsets:
   cloud_sql_mssql_admin_tools:
@@ -54,3 +60,5 @@ toolsets:
     - create_user
     - wait_for_operation
     - clone_instance
+    - create_backup
+    - restore_backup
@@ -43,6 +43,12 @@ tools:
   clone_instance:
     kind: cloud-sql-clone-instance
     source: cloud-sql-admin-source
+  create_backup:
+    kind: cloud-sql-create-backup
+    source: cloud-sql-admin-source
+  restore_backup:
+    kind: cloud-sql-restore-backup
+    source: cloud-sql-admin-source
 
 toolsets:
   cloud_sql_mysql_admin_tools:
@@ -54,3 +60,5 @@ toolsets:
     - create_user
     - wait_for_operation
     - clone_instance
+    - create_backup
+    - restore_backup
@@ -46,6 +46,12 @@ tools:
   postgres_upgrade_precheck:
     kind: postgres-upgrade-precheck
     source: cloud-sql-admin-source
+  create_backup:
+    kind: cloud-sql-create-backup
+    source: cloud-sql-admin-source
+  restore_backup:
+    kind: cloud-sql-restore-backup
+    source: cloud-sql-admin-source
 
 toolsets:
   cloud_sql_postgres_admin_tools:
@@ -58,3 +64,5 @@ toolsets:
     - wait_for_operation
     - postgres_upgrade_precheck
     - clone_instance
+    - create_backup
+    - restore_backup
@@ -27,19 +27,21 @@ import (
 	v20241105 "github.com/googleapis/genai-toolbox/internal/server/mcp/v20241105"
 	v20250326 "github.com/googleapis/genai-toolbox/internal/server/mcp/v20250326"
 	v20250618 "github.com/googleapis/genai-toolbox/internal/server/mcp/v20250618"
+	v20251125 "github.com/googleapis/genai-toolbox/internal/server/mcp/v20251125"
 	"github.com/googleapis/genai-toolbox/internal/server/resources"
 	"github.com/googleapis/genai-toolbox/internal/tools"
 )
 
 // LATEST_PROTOCOL_VERSION is the latest version of the MCP protocol supported.
 // Update the version used in InitializeResponse when this value is updated.
-const LATEST_PROTOCOL_VERSION = v20250618.PROTOCOL_VERSION
+const LATEST_PROTOCOL_VERSION = v20251125.PROTOCOL_VERSION
 
 // SUPPORTED_PROTOCOL_VERSIONS is the MCP protocol versions that are supported.
 var SUPPORTED_PROTOCOL_VERSIONS = []string{
 	v20241105.PROTOCOL_VERSION,
 	v20250326.PROTOCOL_VERSION,
 	v20250618.PROTOCOL_VERSION,
+	v20251125.PROTOCOL_VERSION,
 }
 
 // InitializeResponse runs capability negotiation and protocol version agreement.
@@ -102,6 +104,8 @@ func NotificationHandler(ctx context.Context, body []byte) error {
 // This is the Operation phase of the lifecycle for MCP client-server connections.
 func ProcessMethod(ctx context.Context, mcpVersion string, id jsonrpc.RequestId, method string, toolset tools.Toolset, promptset prompts.Promptset, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) {
 	switch mcpVersion {
+	case v20251125.PROTOCOL_VERSION:
+		return v20251125.ProcessMethod(ctx, id, method, toolset, promptset, resourceMgr, body, header)
 	case v20250618.PROTOCOL_VERSION:
 		return v20250618.ProcessMethod(ctx, id, method, toolset, promptset, resourceMgr, body, header)
 	case v20250326.PROTOCOL_VERSION:
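For context on how these constants are used: during the MCP initialize handshake the client proposes a protocol version, and the server echoes it if supported or answers with its latest supported version otherwise. A minimal sketch of that negotiation (version strings inferred from the package names here; the helper function is illustrative, not the actual Toolbox code):

```python
# Protocol versions mirroring SUPPORTED_PROTOCOL_VERSIONS above (assumed mapping
# from the v20241105/v20250326/v20250618/v20251125 package names).
SUPPORTED = ["2024-11-05", "2025-03-26", "2025-06-18", "2025-11-25"]
LATEST = SUPPORTED[-1]

def negotiate(client_version: str) -> str:
    """Echo the client's version if supported; otherwise fall back to the latest."""
    return client_version if client_version in SUPPORTED else LATEST

print(negotiate("2025-11-25"))  # → 2025-11-25
print(negotiate("2099-01-01"))  # → 2025-11-25 (unknown version falls back)
```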
326
internal/server/mcp/v20251125/method.go
Normal file
@@ -0,0 +1,326 @@
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package v20251125

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"strings"

	"github.com/googleapis/genai-toolbox/internal/prompts"
	"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
	"github.com/googleapis/genai-toolbox/internal/server/resources"
	"github.com/googleapis/genai-toolbox/internal/tools"
	"github.com/googleapis/genai-toolbox/internal/util"
)

// ProcessMethod returns a response for the request.
func ProcessMethod(ctx context.Context, id jsonrpc.RequestId, method string, toolset tools.Toolset, promptset prompts.Promptset, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) {
	switch method {
	case PING:
		return pingHandler(id)
	case TOOLS_LIST:
		return toolsListHandler(id, toolset, body)
	case TOOLS_CALL:
		return toolsCallHandler(ctx, id, resourceMgr, body, header)
	case PROMPTS_LIST:
		return promptsListHandler(ctx, id, promptset, body)
	case PROMPTS_GET:
		return promptsGetHandler(ctx, id, resourceMgr, body)
	default:
		err := fmt.Errorf("invalid method %s", method)
		return jsonrpc.NewError(id, jsonrpc.METHOD_NOT_FOUND, err.Error(), nil), err
	}
}

// pingHandler handles the "ping" method by returning an empty response.
func pingHandler(id jsonrpc.RequestId) (any, error) {
	return jsonrpc.JSONRPCResponse{
		Jsonrpc: jsonrpc.JSONRPC_VERSION,
		Id:      id,
		Result:  struct{}{},
	}, nil
}

func toolsListHandler(id jsonrpc.RequestId, toolset tools.Toolset, body []byte) (any, error) {
	var req ListToolsRequest
	if err := json.Unmarshal(body, &req); err != nil {
		err = fmt.Errorf("invalid mcp tools list request: %w", err)
		return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
	}

	result := ListToolsResult{
		Tools: toolset.McpManifest,
	}
	return jsonrpc.JSONRPCResponse{
		Jsonrpc: jsonrpc.JSONRPC_VERSION,
		Id:      id,
		Result:  result,
	}, nil
}

// toolsCallHandler generates a response for tools call.
func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *resources.ResourceManager, body []byte, header http.Header) (any, error) {
	authServices := resourceMgr.GetAuthServiceMap()

	// retrieve logger from context
	logger, err := util.LoggerFromContext(ctx)
	if err != nil {
		return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
	}

	var req CallToolRequest
	if err = json.Unmarshal(body, &req); err != nil {
		err = fmt.Errorf("invalid mcp tools call request: %w", err)
		return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
	}

	toolName := req.Params.Name
	toolArgument := req.Params.Arguments
	logger.DebugContext(ctx, fmt.Sprintf("tool name: %s", toolName))
	tool, ok := resourceMgr.GetTool(toolName)
	if !ok {
		err = fmt.Errorf("invalid tool name: tool with name %q does not exist", toolName)
		return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
	}

	// Get access token
	authTokenHeadername, err := tool.GetAuthTokenHeaderName(resourceMgr)
	if err != nil {
		errMsg := fmt.Errorf("error during invocation: %w", err)
		return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
	}
	accessToken := tools.AccessToken(header.Get(authTokenHeadername))

	// Check if this specific tool requires the standard authorization header
	clientAuth, err := tool.RequiresClientAuthorization(resourceMgr)
	if err != nil {
		errMsg := fmt.Errorf("error during invocation: %w", err)
		return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, errMsg.Error(), nil), errMsg
	}
	if clientAuth {
		if accessToken == "" {
			return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, "missing access token in the 'Authorization' header", nil), util.ErrUnauthorized
		}
	}

	// marshal arguments and decode it using decodeJSON instead to prevent loss between floats/int.
	aMarshal, err := json.Marshal(toolArgument)
	if err != nil {
		err = fmt.Errorf("unable to marshal tools argument: %w", err)
		return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
	}

	var data map[string]any
	if err = util.DecodeJSON(bytes.NewBuffer(aMarshal), &data); err != nil {
		err = fmt.Errorf("unable to decode tools argument: %w", err)
		return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
	}

	// Tool authentication
	// claimsFromAuth maps the name of the authservice to the claims retrieved from it.
	claimsFromAuth := make(map[string]map[string]any)

	// if using stdio, header will be nil and auth will not be supported
	if header != nil {
		for _, aS := range authServices {
			claims, err := aS.GetClaimsFromHeader(ctx, header)
			if err != nil {
				logger.DebugContext(ctx, err.Error())
				continue
			}
			if claims == nil {
				// authService not present in header
				continue
			}
			claimsFromAuth[aS.GetName()] = claims
		}
	}

	// Tool authorization check
	verifiedAuthServices := make([]string, len(claimsFromAuth))
	i := 0
	for k := range claimsFromAuth {
		verifiedAuthServices[i] = k
		i++
	}

	// Check if any of the specified auth services is verified
	isAuthorized := tool.Authorized(verifiedAuthServices)
	if !isAuthorized {
		err = fmt.Errorf("unauthorized Tool call: Please make sure you specify correct auth headers: %w", util.ErrUnauthorized)
		return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
	}
	logger.DebugContext(ctx, "tool invocation authorized")

	params, err := tool.ParseParams(data, claimsFromAuth)
	if err != nil {
		err = fmt.Errorf("provided parameters were invalid: %w", err)
		return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
	}
	logger.DebugContext(ctx, fmt.Sprintf("invocation params: %s", params))

	// run tool invocation and generate response.
	results, err := tool.Invoke(ctx, resourceMgr, params, accessToken)
	if err != nil {
		errStr := err.Error()
		// Missing authService tokens.
		if errors.Is(err, util.ErrUnauthorized) {
			return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
		}
		// Upstream auth error
		if strings.Contains(errStr, "Error 401") || strings.Contains(errStr, "Error 403") {
			if clientAuth {
				// Error with client credentials should pass down to the client
				return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
			}
			// Auth error with ADC should raise internal 500 error
			return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
		}
		text := TextContent{
			Type: "text",
			Text: err.Error(),
		}
		return jsonrpc.JSONRPCResponse{
			Jsonrpc: jsonrpc.JSONRPC_VERSION,
			Id:      id,
			Result:  CallToolResult{Content: []TextContent{text}, IsError: true},
		}, nil
	}

	content := make([]TextContent, 0)

	sliceRes, ok := results.([]any)
	if !ok {
		sliceRes = []any{results}
	}

	for _, d := range sliceRes {
		text := TextContent{Type: "text"}
		dM, err := json.Marshal(d)
		if err != nil {
			text.Text = fmt.Sprintf("failed to marshal: %s, result: %s", err, d)
		} else {
			text.Text = string(dM)
		}
		content = append(content, text)
	}

	return jsonrpc.JSONRPCResponse{
		Jsonrpc: jsonrpc.JSONRPC_VERSION,
		Id:      id,
		Result:  CallToolResult{Content: content},
	}, nil
}

// promptsListHandler handles the "prompts/list" method.
func promptsListHandler(ctx context.Context, id jsonrpc.RequestId, promptset prompts.Promptset, body []byte) (any, error) {
	// retrieve logger from context
	logger, err := util.LoggerFromContext(ctx)
	if err != nil {
		return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
	}
	logger.DebugContext(ctx, "handling prompts/list request")

	var req ListPromptsRequest
	if err := json.Unmarshal(body, &req); err != nil {
		err = fmt.Errorf("invalid mcp prompts list request: %w", err)
		return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
	}

	result := ListPromptsResult{
		Prompts: promptset.McpManifest,
	}
	logger.DebugContext(ctx, fmt.Sprintf("returning %d prompts", len(promptset.McpManifest)))
	return jsonrpc.JSONRPCResponse{
		Jsonrpc: jsonrpc.JSONRPC_VERSION,
		Id:      id,
		Result:  result,
	}, nil
}

// promptsGetHandler handles the "prompts/get" method.
func promptsGetHandler(ctx context.Context, id jsonrpc.RequestId, resourceMgr *resources.ResourceManager, body []byte) (any, error) {
	// retrieve logger from context
	logger, err := util.LoggerFromContext(ctx)
	if err != nil {
		return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
	}
	logger.DebugContext(ctx, "handling prompts/get request")

	var req GetPromptRequest
	if err := json.Unmarshal(body, &req); err != nil {
		err = fmt.Errorf("invalid mcp prompts/get request: %w", err)
		return jsonrpc.NewError(id, jsonrpc.INVALID_REQUEST, err.Error(), nil), err
	}

	promptName := req.Params.Name
	logger.DebugContext(ctx, fmt.Sprintf("prompt name: %s", promptName))
	prompt, ok := resourceMgr.GetPrompt(promptName)
	if !ok {
		err := fmt.Errorf("prompt with name %q does not exist", promptName)
		return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
	}

	// Parse the arguments provided in the request.
	argValues, err := prompt.ParseArgs(req.Params.Arguments, nil)
	if err != nil {
		err = fmt.Errorf("invalid arguments for prompt %q: %w", promptName, err)
		return jsonrpc.NewError(id, jsonrpc.INVALID_PARAMS, err.Error(), nil), err
	}
	logger.DebugContext(ctx, fmt.Sprintf("parsed args: %v", argValues))

	// Substitute the argument values into the prompt's messages.
	substituted, err := prompt.SubstituteParams(argValues)
	if err != nil {
		err = fmt.Errorf("error substituting params for prompt %q: %w", promptName, err)
		return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
	}

	// Cast the result to the expected []prompts.Message type.
	substitutedMessages, ok := substituted.([]prompts.Message)
	if !ok {
		err = fmt.Errorf("internal error: SubstituteParams returned unexpected type")
		return jsonrpc.NewError(id, jsonrpc.INTERNAL_ERROR, err.Error(), nil), err
	}
	logger.DebugContext(ctx, "substituted params successfully")

	// Format the response messages into the required structure.
	promptMessages := make([]PromptMessage, len(substitutedMessages))
	for i, msg := range substitutedMessages {
		promptMessages[i] = PromptMessage{
			Role: msg.Role,
			Content: TextContent{
				Type: "text",
				Text: msg.Content,
			},
		}
	}

	result := GetPromptResult{
		Description: prompt.Manifest().Description,
		Messages:    promptMessages,
	}

	return jsonrpc.JSONRPCResponse{
		Jsonrpc: jsonrpc.JSONRPC_VERSION,
		Id:      id,
		Result:  result,
	}, nil
}
219
internal/server/mcp/v20251125/types.go
Normal file
@@ -0,0 +1,219 @@
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package v20251125

import (
	"github.com/googleapis/genai-toolbox/internal/prompts"
	"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
	"github.com/googleapis/genai-toolbox/internal/tools"
)

// SERVER_NAME is the server name used in Implementation.
const SERVER_NAME = "Toolbox"

// PROTOCOL_VERSION is the version of the MCP protocol in this package.
const PROTOCOL_VERSION = "2025-11-25"

// methods that are supported.
const (
	PING         = "ping"
	TOOLS_LIST   = "tools/list"
	TOOLS_CALL   = "tools/call"
	PROMPTS_LIST = "prompts/list"
	PROMPTS_GET  = "prompts/get"
)

/* Empty result */

// EmptyResult represents a response that indicates success but carries no data.
type EmptyResult jsonrpc.Result

/* Pagination */

// Cursor is an opaque token used to represent a cursor for pagination.
type Cursor string

type PaginatedRequest struct {
	jsonrpc.Request
	Params struct {
		// An opaque token representing the current pagination position.
		// If provided, the server should return results starting after this cursor.
		Cursor Cursor `json:"cursor,omitempty"`
	} `json:"params,omitempty"`
}

type PaginatedResult struct {
	jsonrpc.Result
	// An opaque token representing the pagination position after the last returned result.
	// If present, there may be more results available.
	NextCursor Cursor `json:"nextCursor,omitempty"`
}

/* Tools */

// Sent from the client to request a list of tools the server has.
type ListToolsRequest struct {
	PaginatedRequest
}

// The server's response to a tools/list request from the client.
type ListToolsResult struct {
	PaginatedResult
	Tools []tools.McpManifest `json:"tools"`
}

// Used by the client to invoke a tool provided by the server.
type CallToolRequest struct {
	jsonrpc.Request
	Params struct {
		Name      string         `json:"name"`
		Arguments map[string]any `json:"arguments,omitempty"`
	} `json:"params,omitempty"`
}

// The sender or recipient of messages and data in a conversation.
type Role string

const (
	RoleUser      Role = "user"
	RoleAssistant Role = "assistant"
)

// Base for objects that include optional annotations for the client.
// The client can use annotations to inform how objects are used or displayed
type Annotated struct {
|
||||||
|
Annotations *struct {
|
||||||
|
// Describes who the intended customer of this object or data is.
|
||||||
|
// It can include multiple entries to indicate content useful for multiple
|
||||||
|
// audiences (e.g., `["user", "assistant"]`).
|
||||||
|
Audience []Role `json:"audience,omitempty"`
|
||||||
|
// Describes how important this data is for operating the server.
|
||||||
|
//
|
||||||
|
// A value of 1 means "most important," and indicates that the data is
|
||||||
|
// effectively required, while 0 means "least important," and indicates that
|
||||||
|
// the data is entirely optional.
|
||||||
|
//
|
||||||
|
// @TJS-type number
|
||||||
|
// @minimum 0
|
||||||
|
// @maximum 1
|
||||||
|
Priority float64 `json:"priority,omitempty"`
|
||||||
|
} `json:"annotations,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TextContent represents text provided to or from an LLM.
|
||||||
|
type TextContent struct {
|
||||||
|
Annotated
|
||||||
|
Type string `json:"type"`
|
||||||
|
// The text content of the message.
|
||||||
|
Text string `json:"text"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// The server's response to a tool call.
|
||||||
|
//
|
||||||
|
// Any errors that originate from the tool SHOULD be reported inside the result
|
||||||
|
// object, with `isError` set to true, _not_ as an MCP protocol-level error
|
||||||
|
// response. Otherwise, the LLM would not be able to see that an error occurred
|
||||||
|
// and self-correct.
|
||||||
|
//
|
||||||
|
// However, any errors in _finding_ the tool, an error indicating that the
|
||||||
|
// server does not support tool calls, or any other exceptional conditions,
|
||||||
|
// should be reported as an MCP error response.
|
||||||
|
type CallToolResult struct {
|
||||||
|
jsonrpc.Result
|
||||||
|
// Could be either a TextContent, ImageContent, or EmbeddedResources
|
||||||
|
// For Toolbox, we will only be sending TextContent
|
||||||
|
Content []TextContent `json:"content"`
|
||||||
|
// Whether the tool call ended in an error.
|
||||||
|
// If not set, this is assumed to be false (the call was successful).
|
||||||
|
//
|
||||||
|
// Any errors that originate from the tool SHOULD be reported inside the result
|
||||||
|
// object, with `isError` set to true, _not_ as an MCP protocol-level error
|
||||||
|
// response. Otherwise, the LLM would not be able to see that an error occurred
|
||||||
|
// and self-correct.
|
||||||
|
//
|
||||||
|
// However, any errors in _finding_ the tool, an error indicating that the
|
||||||
|
// server does not support tool calls, or any other exceptional conditions,
|
||||||
|
// should be reported as an MCP error response.
|
||||||
|
IsError bool `json:"isError,omitempty"`
|
||||||
|
// An optional JSON object that represents the structured result of the tool call.
|
||||||
|
StructuredContent map[string]any `json:"structuredContent,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Additional properties describing a Tool to clients.
|
||||||
|
//
|
||||||
|
// NOTE: all properties in ToolAnnotations are **hints**.
|
||||||
|
// They are not guaranteed to provide a faithful description of
|
||||||
|
// tool behavior (including descriptive properties like `title`).
|
||||||
|
//
|
||||||
|
// Clients should never make tool use decisions based on ToolAnnotations
|
||||||
|
// received from untrusted servers.
|
||||||
|
type ToolAnnotations struct {
|
||||||
|
// A human-readable title for the tool.
|
||||||
|
Title string `json:"title,omitempty"`
|
||||||
|
// If true, the tool does not modify its environment.
|
||||||
|
// Default: false
|
||||||
|
ReadOnlyHint bool `json:"readOnlyHint,omitempty"`
|
||||||
|
// If true, the tool may perform destructive updates to its environment.
|
||||||
|
// If false, the tool performs only additive updates.
|
||||||
|
// (This property is meaningful only when `readOnlyHint == false`)
|
||||||
|
// Default: true
|
||||||
|
DestructiveHint bool `json:"destructiveHint,omitempty"`
|
||||||
|
// If true, calling the tool repeatedly with the same arguments
|
||||||
|
// will have no additional effect on the its environment.
|
||||||
|
// (This property is meaningful only when `readOnlyHint == false`)
|
||||||
|
// Default: false
|
||||||
|
IdempotentHint bool `json:"idempotentHint,omitempty"`
|
||||||
|
// If true, this tool may interact with an "open world" of external
|
||||||
|
// entities. If false, the tool's domain of interaction is closed.
|
||||||
|
// For example, the world of a web search tool is open, whereas that
|
||||||
|
// of a memory tool is not.
|
||||||
|
// Default: true
|
||||||
|
OpenWorldHint bool `json:"openWorldHint,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Prompts */
|
||||||
|
|
||||||
|
// Sent from the client to request a list of prompts the server has.
|
||||||
|
type ListPromptsRequest struct {
|
||||||
|
PaginatedRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
// The server's response to a prompts/list request from the client.
|
||||||
|
type ListPromptsResult struct {
|
||||||
|
PaginatedResult
|
||||||
|
Prompts []prompts.McpManifest `json:"prompts"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Used by the client to get a prompt provided by the server.
|
||||||
|
type GetPromptRequest struct {
|
||||||
|
jsonrpc.Request
|
||||||
|
Params struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Arguments map[string]any `json:"arguments,omitempty"`
|
||||||
|
} `json:"params"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// The server's response to a prompts/get request from the client.
|
||||||
|
type GetPromptResult struct {
|
||||||
|
jsonrpc.Result
|
||||||
|
Description string `json:"description,omitempty"`
|
||||||
|
Messages []PromptMessage `json:"messages"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Describes a message returned as part of a prompt.
|
||||||
|
type PromptMessage struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content TextContent `json:"content"`
|
||||||
|
}
|
||||||
@@ -37,6 +37,7 @@ const jsonrpcVersion = "2.0"
 const protocolVersion20241105 = "2024-11-05"
 const protocolVersion20250326 = "2025-03-26"
 const protocolVersion20250618 = "2025-06-18"
+const protocolVersion20251125 = "2025-11-25"
 const serverName = "Toolbox"

 var basicInputSchema = map[string]any{
@@ -485,6 +486,23 @@ func TestMcpEndpoint(t *testing.T) {
 			},
 		},
 	},
+	{
+		name:     "version 2025-11-25",
+		protocol: protocolVersion20251125,
+		idHeader: false,
+		initWant: map[string]any{
+			"jsonrpc": "2.0",
+			"id":      "mcp-initialize",
+			"result": map[string]any{
+				"protocolVersion": "2025-11-25",
+				"capabilities": map[string]any{
+					"tools":   map[string]any{"listChanged": false},
+					"prompts": map[string]any{"listChanged": false},
+				},
+				"serverInfo": map[string]any{"name": serverName, "version": fakeVersionString},
+			},
+		},
+	},
 }
 for _, vtc := range versTestCases {
 	t.Run(vtc.name, func(t *testing.T) {
@@ -494,8 +512,7 @@ func TestMcpEndpoint(t *testing.T) {
 		if sessionId != "" {
 			header["Mcp-Session-Id"] = sessionId
 		}
-		if vtc.protocol != protocolVersion20241105 && vtc.protocol != protocolVersion20250326 {
+		if vtc.protocol == protocolVersion20250618 {
 			header["MCP-Protocol-Version"] = vtc.protocol
 		}

@@ -304,10 +304,14 @@ func hostCheck(allowedHosts map[string]struct{}) func(http.Handler) http.Handler
 	return func(next http.Handler) http.Handler {
 		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 			_, hasWildcard := allowedHosts["*"]
-			_, hostIsAllowed := allowedHosts[r.Host]
+			hostname := r.Host
+			if host, _, err := net.SplitHostPort(r.Host); err == nil {
+				hostname = host
+			}
+			_, hostIsAllowed := allowedHosts[hostname]
 			if !hasWildcard && !hostIsAllowed {
-				// Return 400 Bad Request or 403 Forbidden to block the attack
-				http.Error(w, "Invalid Host header", http.StatusBadRequest)
+				// Return 403 Forbidden to block the attack
+				http.Error(w, "Invalid Host header", http.StatusForbidden)
 				return
 			}
 			next.ServeHTTP(w, r)
@@ -406,7 +410,11 @@ func NewServer(ctx context.Context, cfg ServerConfig) (*Server, error) {
 	}
 	allowedHostsMap := make(map[string]struct{}, len(cfg.AllowedHosts))
 	for _, h := range cfg.AllowedHosts {
-		allowedHostsMap[h] = struct{}{}
+		hostname := h
+		if host, _, err := net.SplitHostPort(h); err == nil {
+			hostname = host
+		}
+		allowedHostsMap[hostname] = struct{}{}
 	}
 	r.Use(hostCheck(allowedHostsMap))

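The host-check change normalizes both sides of the comparison by stripping an optional `:port`, so the allow-list matches a request regardless of whether the `Host` header carries a port. A minimal sketch of that normalization (the `normalizeHost` helper name is hypothetical, not from the patch):

```go
package main

import (
	"fmt"
	"net"
)

// normalizeHost strips an optional :port from a Host header value so that
// "localhost:8080" and "localhost" compare equal against an allow-list.
func normalizeHost(hostHeader string) string {
	if host, _, err := net.SplitHostPort(hostHeader); err == nil {
		return host
	}
	// SplitHostPort errors when no port is present: use the value as-is.
	return hostHeader
}

func main() {
	for _, h := range []string{"localhost:8080", "localhost", "[::1]:5000", "example.com:443"} {
		fmt.Printf("%-18s -> %s\n", h, normalizeHost(h))
	}
}
```

Note that `net.SplitHostPort` also unwraps bracketed IPv6 literals (`[::1]:5000` becomes `::1`), which a naive `strings.Split` on `:` would mangle.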
@@ -89,6 +89,7 @@ type Config struct {
 	UseClientOAuth            bool                `yaml:"useClientOAuth"`
 	ImpersonateServiceAccount string              `yaml:"impersonateServiceAccount"`
 	Scopes                    StringOrStringSlice `yaml:"scopes"`
+	MaxQueryResultRows        int                 `yaml:"maxQueryResultRows"`
 }

 // StringOrStringSlice is a custom type that can unmarshal both a single string
@@ -127,6 +128,10 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
 		r.WriteMode = WriteModeAllowed
 	}
+
+	if r.MaxQueryResultRows == 0 {
+		r.MaxQueryResultRows = 50
+	}
+
 	if r.WriteMode == WriteModeProtected && r.UseClientOAuth {
 		// The protected mode only allows write operations to the session's temporary datasets.
 		// when using client OAuth, a new session is created every
@@ -150,7 +155,7 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
 		Client:             client,
 		RestService:        restService,
 		TokenSource:        tokenSource,
-		MaxQueryResultRows: 50,
+		MaxQueryResultRows: r.MaxQueryResultRows,
 		ClientCreator:      clientCreator,
 	}
@@ -567,7 +572,7 @@ func (s *Source) RunSQL(ctx context.Context, bqClient *bigqueryapi.Client, state
 	}

 	var out []any
-	for {
+	for s.MaxQueryResultRows <= 0 || len(out) < s.MaxQueryResultRows {
 		var val []bigqueryapi.Value
 		err = it.Next(&val)
 		if err == iterator.Done {

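The new loop condition makes the previously hard-coded cap of 50 configurable, with a non-positive value meaning "no limit". The behavior can be sketched without BigQuery by standing in a plain slice for the row iterator (`fetchRows` is an illustrative helper, not code from the patch):

```go
package main

import "fmt"

// fetchRows mimics the capped loop above: a cap <= 0 means unlimited,
// otherwise iteration stops once len(out) reaches maxRows.
func fetchRows(rows []string, maxRows int) []string {
	var out []string
	for _, r := range rows {
		if maxRows > 0 && len(out) >= maxRows {
			break
		}
		out = append(out, r)
	}
	return out
}

func main() {
	rows := []string{"r1", "r2", "r3", "r4", "r5"}
	fmt.Println(len(fetchRows(rows, 3)))  // capped at 3
	fmt.Println(len(fetchRows(rows, 0)))  // 0 = unlimited, all 5
	fmt.Println(len(fetchRows(rows, 50))) // cap above total, all 5
}
```

Checking the cap in the loop condition (rather than truncating afterwards) stops fetching pages from the iterator as soon as enough rows are collected.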
@@ -21,9 +21,12 @@ import (

 	yaml "github.com/goccy/go-yaml"
 	"github.com/google/go-cmp/cmp"
+	"go.opentelemetry.io/otel/trace/noop"

 	"github.com/googleapis/genai-toolbox/internal/server"
 	"github.com/googleapis/genai-toolbox/internal/sources/bigquery"
 	"github.com/googleapis/genai-toolbox/internal/testutils"
+	"github.com/googleapis/genai-toolbox/internal/util"
 )

 func TestParseFromYamlBigQuery(t *testing.T) {
@@ -154,6 +157,26 @@ func TestParseFromYamlBigQuery(t *testing.T) {
 			},
 		},
 	},
+	{
+		desc: "with max query result rows example",
+		in: `
+		sources:
+			my-instance:
+				kind: bigquery
+				project: my-project
+				location: us
+				maxQueryResultRows: 10
+		`,
+		want: server.SourceConfigs{
+			"my-instance": bigquery.Config{
+				Name:               "my-instance",
+				Kind:               bigquery.SourceKind,
+				Project:            "my-project",
+				Location:           "us",
+				MaxQueryResultRows: 10,
+			},
+		},
+	},
 }
 for _, tc := range tcs {
 	t.Run(tc.desc, func(t *testing.T) {
@@ -220,6 +243,59 @@ func TestFailParseFromYaml(t *testing.T) {
 	}
 }

+func TestInitialize_MaxQueryResultRows(t *testing.T) {
+	ctx, err := testutils.ContextWithNewLogger()
+	if err != nil {
+		t.Fatalf("unexpected error: %s", err)
+	}
+	ctx = util.WithUserAgent(ctx, "test-agent")
+	tracer := noop.NewTracerProvider().Tracer("")
+
+	tcs := []struct {
+		desc string
+		cfg  bigquery.Config
+		want int
+	}{
+		{
+			desc: "default value",
+			cfg: bigquery.Config{
+				Name:           "test-default",
+				Kind:           bigquery.SourceKind,
+				Project:        "test-project",
+				UseClientOAuth: true,
+			},
+			want: 50,
+		},
+		{
+			desc: "configured value",
+			cfg: bigquery.Config{
+				Name:               "test-configured",
+				Kind:               bigquery.SourceKind,
+				Project:            "test-project",
+				UseClientOAuth:     true,
+				MaxQueryResultRows: 100,
+			},
+			want: 100,
+		},
+	}
+
+	for _, tc := range tcs {
+		t.Run(tc.desc, func(t *testing.T) {
+			src, err := tc.cfg.Initialize(ctx, tracer)
+			if err != nil {
+				t.Fatalf("Initialize failed: %v", err)
+			}
+			bqSrc, ok := src.(*bigquery.Source)
+			if !ok {
+				t.Fatalf("Expected *bigquery.Source, got %T", src)
+			}
+			if bqSrc.MaxQueryResultRows != tc.want {
+				t.Errorf("MaxQueryResultRows = %d, want %d", bqSrc.MaxQueryResultRows, tc.want)
+			}
+		})
+	}
+}
+
 func TestNormalizeValue(t *testing.T) {
 	tests := []struct {
 		name string

@@ -16,8 +16,12 @@ package cloudhealthcare

 import (
 	"context"
+	"encoding/base64"
+	"encoding/json"
 	"fmt"
+	"io"
 	"net/http"
+	"strings"

 	"github.com/goccy/go-yaml"
 	"github.com/googleapis/genai-toolbox/internal/sources"
@@ -255,3 +259,299 @@ func (s *Source) IsDICOMStoreAllowed(storeID string) bool {
 func (s *Source) UseClientAuthorization() bool {
 	return s.UseClientOAuth
 }
+
+func parseResults(resp *http.Response) (any, error) {
+	respBytes, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, fmt.Errorf("could not read response: %w", err)
+	}
+	if resp.StatusCode > 299 {
+		return nil, fmt.Errorf("status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
+	}
+	var jsonMap map[string]interface{}
+	if err := json.Unmarshal(respBytes, &jsonMap); err != nil {
+		return nil, fmt.Errorf("could not unmarshal response as json: %w", err)
+	}
+	return jsonMap, nil
+}
+
+func (s *Source) getService(tokenStr string) (*healthcare.Service, error) {
+	svc := s.Service()
+	var err error
+	// Initialize new service if using user OAuth token
+	if s.UseClientAuthorization() {
+		svc, err = s.ServiceCreator()(tokenStr)
+		if err != nil {
+			return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
+		}
+	}
+	return svc, nil
+}
+
+func (s *Source) FHIRFetchPage(ctx context.Context, url, tokenStr string) (any, error) {
+	var httpClient *http.Client
+	if s.UseClientAuthorization() {
+		ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: tokenStr})
+		httpClient = oauth2.NewClient(ctx, ts)
+	} else {
+		// The source.Service() object holds a client with the default credentials.
+		// However, the client is not exported, so we have to create a new one.
+		var err error
+		httpClient, err = google.DefaultClient(ctx, healthcare.CloudHealthcareScope)
+		if err != nil {
+			return nil, fmt.Errorf("failed to create default http client: %w", err)
+		}
+	}
+
+	req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create http request: %w", err)
+	}
+	req.Header.Set("Accept", "application/fhir+json;charset=utf-8")
+
+	resp, err := httpClient.Do(req)
+	if err != nil {
+		return nil, fmt.Errorf("failed to get fhir page from %q: %w", url, err)
+	}
+	defer resp.Body.Close()
+	return parseResults(resp)
+}
+
+func (s *Source) FHIRPatientEverything(storeID, patientID, tokenStr string, opts []googleapi.CallOption) (any, error) {
+	svc, err := s.getService(tokenStr)
+	if err != nil {
+		return nil, err
+	}
+
+	name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/Patient/%s", s.Project(), s.Region(), s.DatasetID(), storeID, patientID)
+	resp, err := svc.Projects.Locations.Datasets.FhirStores.Fhir.PatientEverything(name).Do(opts...)
+	if err != nil {
+		return nil, fmt.Errorf("failed to call patient everything for %q: %w", name, err)
+	}
+	defer resp.Body.Close()
+	return parseResults(resp)
+}
+
+func (s *Source) FHIRPatientSearch(storeID, tokenStr string, opts []googleapi.CallOption) (any, error) {
+	svc, err := s.getService(tokenStr)
+	if err != nil {
+		return nil, err
+	}
+
+	name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
+	resp, err := svc.Projects.Locations.Datasets.FhirStores.Fhir.SearchType(name, "Patient", &healthcare.SearchResourcesRequest{ResourceType: "Patient"}).Do(opts...)
+	if err != nil {
+		return nil, fmt.Errorf("failed to search patient resources: %w", err)
+	}
+	defer resp.Body.Close()
+	return parseResults(resp)
+}
+
+func (s *Source) GetDataset(tokenStr string) (*healthcare.Dataset, error) {
+	svc, err := s.getService(tokenStr)
+	if err != nil {
+		return nil, err
+	}
+
+	datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", s.Project(), s.Region(), s.DatasetID())
+	dataset, err := svc.Projects.Locations.Datasets.Get(datasetName).Do()
+	if err != nil {
+		return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
+	}
+	return dataset, nil
+}
+
+func (s *Source) GetFHIRResource(storeID, resType, resID, tokenStr string) (any, error) {
+	svc, err := s.getService(tokenStr)
+	if err != nil {
+		return nil, err
+	}
+
+	name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/%s/%s", s.Project(), s.Region(), s.DatasetID(), storeID, resType, resID)
+	call := svc.Projects.Locations.Datasets.FhirStores.Fhir.Read(name)
+	call.Header().Set("Content-Type", "application/fhir+json;charset=utf-8")
+	resp, err := call.Do()
+	if err != nil {
+		return nil, fmt.Errorf("failed to get fhir resource %q: %w", name, err)
+	}
+	defer resp.Body.Close()
+	return parseResults(resp)
+}
+
+func (s *Source) GetDICOMStore(storeID, tokenStr string) (*healthcare.DicomStore, error) {
+	svc, err := s.getService(tokenStr)
+	if err != nil {
+		return nil, err
+	}
+
+	storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
+	store, err := svc.Projects.Locations.Datasets.DicomStores.Get(storeName).Do()
+	if err != nil {
+		return nil, fmt.Errorf("failed to get DICOM store %q: %w", storeName, err)
+	}
+	return store, nil
+}
+
+func (s *Source) GetFHIRStore(storeID, tokenStr string) (*healthcare.FhirStore, error) {
+	svc, err := s.getService(tokenStr)
+	if err != nil {
+		return nil, err
+	}
+
+	storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
+	store, err := svc.Projects.Locations.Datasets.FhirStores.Get(storeName).Do()
+	if err != nil {
+		return nil, fmt.Errorf("failed to get FHIR store %q: %w", storeName, err)
+	}
+	return store, nil
+}
+
+func (s *Source) GetDICOMStoreMetrics(storeID, tokenStr string) (*healthcare.DicomStoreMetrics, error) {
+	svc, err := s.getService(tokenStr)
+	if err != nil {
+		return nil, err
+	}
+
+	storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
+	store, err := svc.Projects.Locations.Datasets.DicomStores.GetDICOMStoreMetrics(storeName).Do()
+	if err != nil {
+		return nil, fmt.Errorf("failed to get metrics for DICOM store %q: %w", storeName, err)
+	}
+	return store, nil
+}
+
+func (s *Source) GetFHIRStoreMetrics(storeID, tokenStr string) (*healthcare.FhirStoreMetrics, error) {
+	svc, err := s.getService(tokenStr)
+	if err != nil {
+		return nil, err
+	}
+
+	storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
+	store, err := svc.Projects.Locations.Datasets.FhirStores.GetFHIRStoreMetrics(storeName).Do()
+	if err != nil {
+		return nil, fmt.Errorf("failed to get metrics for FHIR store %q: %w", storeName, err)
+	}
+	return store, nil
+}
+
+func (s *Source) ListDICOMStores(tokenStr string) ([]*healthcare.DicomStore, error) {
+	svc, err := s.getService(tokenStr)
+	if err != nil {
+		return nil, err
+	}
+
+	datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", s.Project(), s.Region(), s.DatasetID())
+	stores, err := svc.Projects.Locations.Datasets.DicomStores.List(datasetName).Do()
+	if err != nil {
+		return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
+	}
+	var filtered []*healthcare.DicomStore
+	for _, store := range stores.DicomStores {
+		if len(s.AllowedDICOMStores()) == 0 {
+			filtered = append(filtered, store)
+			continue
+		}
+		if len(store.Name) == 0 {
+			continue
+		}
+		parts := strings.Split(store.Name, "/")
+		if _, ok := s.AllowedDICOMStores()[parts[len(parts)-1]]; ok {
+			filtered = append(filtered, store)
+		}
+	}
+	return filtered, nil
+}
+
+func (s *Source) ListFHIRStores(tokenStr string) ([]*healthcare.FhirStore, error) {
+	svc, err := s.getService(tokenStr)
+	if err != nil {
+		return nil, err
+	}
+
+	datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", s.Project(), s.Region(), s.DatasetID())
+	stores, err := svc.Projects.Locations.Datasets.FhirStores.List(datasetName).Do()
+	if err != nil {
+		return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
+	}
+	var filtered []*healthcare.FhirStore
+	for _, store := range stores.FhirStores {
+		if len(s.AllowedFHIRStores()) == 0 {
+			filtered = append(filtered, store)
+			continue
+		}
+		if len(store.Name) == 0 {
+			continue
+		}
+		parts := strings.Split(store.Name, "/")
+		if _, ok := s.AllowedFHIRStores()[parts[len(parts)-1]]; ok {
+			filtered = append(filtered, store)
+		}
+	}
+	return filtered, nil
+}
+
+func (s *Source) RetrieveRenderedDICOMInstance(storeID, study, series, sop string, frame int, tokenStr string) (any, error) {
+	svc, err := s.getService(tokenStr)
+	if err != nil {
+		return nil, err
+	}
+
+	name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
+	dicomWebPath := fmt.Sprintf("studies/%s/series/%s/instances/%s/frames/%d/rendered", study, series, sop, frame)
+	call := svc.Projects.Locations.Datasets.DicomStores.Studies.Series.Instances.Frames.RetrieveRendered(name, dicomWebPath)
+	call.Header().Set("Accept", "image/jpeg")
+	resp, err := call.Do()
+	if err != nil {
+		return nil, fmt.Errorf("unable to retrieve dicom instance rendered image: %w", err)
+	}
+	defer resp.Body.Close()
+
+	respBytes, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, fmt.Errorf("could not read response: %w", err)
+	}
+	if resp.StatusCode > 299 {
+		return nil, fmt.Errorf("RetrieveRendered: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
+	}
+	base64String := base64.StdEncoding.EncodeToString(respBytes)
+	return base64String, nil
+}
+
+func (s *Source) SearchDICOM(toolKind, storeID, dicomWebPath, tokenStr string, opts []googleapi.CallOption) (any, error) {
+	svc, err := s.getService(tokenStr)
+	if err != nil {
+		return nil, err
+	}
+	name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
+	var resp *http.Response
+	switch toolKind {
+	case "cloud-healthcare-search-dicom-instances":
+		resp, err = svc.Projects.Locations.Datasets.DicomStores.SearchForInstances(name, dicomWebPath).Do(opts...)
+	case "cloud-healthcare-search-dicom-series":
+		resp, err = svc.Projects.Locations.Datasets.DicomStores.SearchForSeries(name, dicomWebPath).Do(opts...)
+	case "cloud-healthcare-search-dicom-studies":
+		resp, err = svc.Projects.Locations.Datasets.DicomStores.SearchForStudies(name, dicomWebPath).Do(opts...)
+	default:
+		return nil, fmt.Errorf("incompatible tool kind: %s", toolKind)
+	}
+	if err != nil {
+		return nil, fmt.Errorf("failed to search dicom series: %w", err)
+	}
+	defer resp.Body.Close()
+
+	respBytes, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, fmt.Errorf("could not read response: %w", err)
+	}
+	if resp.StatusCode > 299 {
+		return nil, fmt.Errorf("search: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
+	}
+	if len(respBytes) == 0 {
+		return []interface{}{}, nil
+	}
+	var result []interface{}
+	if err := json.Unmarshal(respBytes, &result); err != nil {
+		return nil, fmt.Errorf("could not unmarshal response as list: %w", err)
+	}
+	return result, nil
+}

|||||||
@@ -19,6 +19,7 @@ import (
 	"fmt"
 	"net/http"
 	"regexp"
+	"strconv"
 	"strings"
 	"text/template"
 	"time"
@@ -36,7 +37,10 @@ import (
 
 const SourceKind string = "cloud-sql-admin"
 
-var targetLinkRegex = regexp.MustCompile(`/projects/([^/]+)/instances/([^/]+)/databases/([^/]+)`)
+var (
+	targetLinkRegex = regexp.MustCompile(`/projects/([^/]+)/instances/([^/]+)/databases/([^/]+)`)
+	backupDRRegex   = regexp.MustCompile(`^projects/([^/]+)/locations/([^/]+)/backupVaults/([^/]+)/dataSources/([^/]+)/backups/([^/]+)$`)
+)
 
 // validate interface
 var _ sources.SourceConfig = Config{}
@@ -352,6 +356,70 @@ func (s *Source) GetWaitForOperations(ctx context.Context, service *sqladmin.Ser
 	return nil, nil
 }
 
+func (s *Source) InsertBackupRun(ctx context.Context, project, instance, location, backupDescription, accessToken string) (any, error) {
+	backupRun := &sqladmin.BackupRun{}
+	if location != "" {
+		backupRun.Location = location
+	}
+	if backupDescription != "" {
+		backupRun.Description = backupDescription
+	}
+
+	service, err := s.GetService(ctx, accessToken)
+	if err != nil {
+		return nil, err
+	}
+
+	resp, err := service.BackupRuns.Insert(project, instance, backupRun).Do()
+	if err != nil {
+		return nil, fmt.Errorf("error creating backup: %w", err)
+	}
+
+	return resp, nil
+}
+
+func (s *Source) RestoreBackup(ctx context.Context, targetProject, targetInstance, sourceProject, sourceInstance, backupID, accessToken string) (any, error) {
+	request := &sqladmin.InstancesRestoreBackupRequest{}
+
+	// There are 3 scenarios for the backup identifier:
+	// 1. The identifier is an int64 containing the timestamp of the BackupRun.
+	// This is used to restore standard backups, and the RestoreBackupContext
+	// field should be populated with the backup ID and source instance info.
+	// 2. The identifier is a string of the format
+	// 'projects/{project-id}/locations/{location}/backupVaults/{backupvault}/dataSources/{datasource}/backups/{backup-uid}'.
+	// This is used to restore BackupDR backups, and the BackupdrBackup field
+	// should be populated.
+	// 3. The identifier is a string of the format
+	// 'projects/{project-id}/backups/{backup-uid}'. In this case, the Backup
+	// field should be populated.
+	if backupRunID, err := strconv.ParseInt(backupID, 10, 64); err == nil {
+		if sourceProject == "" || sourceInstance == "" {
+			return nil, fmt.Errorf("source project and instance are required when restoring via backup ID")
+		}
+		request.RestoreBackupContext = &sqladmin.RestoreBackupContext{
+			Project:     sourceProject,
+			InstanceId:  sourceInstance,
+			BackupRunId: backupRunID,
+		}
+	} else if backupDRRegex.MatchString(backupID) {
+		request.BackupdrBackup = backupID
+	} else {
+		request.Backup = backupID
+	}
+
+	service, err := s.GetService(ctx, accessToken)
+	if err != nil {
+		return nil, err
+	}
+
+	resp, err := service.Instances.RestoreBackup(targetProject, targetInstance, request).Do()
+	if err != nil {
+		return nil, fmt.Errorf("error restoring backup: %w", err)
+	}
+
+	return resp, nil
+}
+
 func generateCloudSQLConnectionMessage(ctx context.Context, source *Source, logger log.Logger, opResponse map[string]any, connectionMessageTemplate string) (string, bool) {
 	operationType, ok := opResponse["operationType"].(string)
 	if !ok || operationType != "CREATE_DATABASE" {
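The RestoreBackup comment above distinguishes three backup-identifier formats and routes each to a different request field. A minimal standalone sketch of that classification logic (the function name `classifyBackupID` is illustrative, not part of the Cloud SQL Admin API):

```go
package main

import (
	"fmt"
	"regexp"
	"strconv"
)

// backupDRPattern mirrors the BackupDR resource-name regex from the diff above.
var backupDRPattern = regexp.MustCompile(`^projects/([^/]+)/locations/([^/]+)/backupVaults/([^/]+)/dataSources/([^/]+)/backups/([^/]+)$`)

// classifyBackupID reports which restore-request field a given identifier
// would populate: a numeric BackupRun ID, a BackupDR resource name, or a
// plain backup resource name.
func classifyBackupID(id string) string {
	if _, err := strconv.ParseInt(id, 10, 64); err == nil {
		return "RestoreBackupContext"
	}
	if backupDRPattern.MatchString(id) {
		return "BackupdrBackup"
	}
	return "Backup"
}

func main() {
	fmt.Println(classifyBackupID("1700000000000"))                                            // RestoreBackupContext
	fmt.Println(classifyBackupID("projects/p/locations/us/backupVaults/v/dataSources/d/backups/b")) // BackupdrBackup
	fmt.Println(classifyBackupID("projects/p/backups/b"))                                     // Backup
}
```

Note that the numeric check must run first: a BackupDR path can never parse as an int64, so the ordering is safe.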
@@ -19,6 +19,8 @@ import (
 	"fmt"
 
 	dataplexapi "cloud.google.com/go/dataplex/apiv1"
+	"cloud.google.com/go/dataplex/apiv1/dataplexpb"
+	"github.com/cenkalti/backoff/v5"
 	"github.com/goccy/go-yaml"
 	"github.com/googleapis/genai-toolbox/internal/sources"
 	"github.com/googleapis/genai-toolbox/internal/util"
@@ -121,3 +123,101 @@ func initDataplexConnection(
 	}
 	return client, nil
 }
+
+func (s *Source) LookupEntry(ctx context.Context, name string, view int, aspectTypes []string, entry string) (*dataplexpb.Entry, error) {
+	viewMap := map[int]dataplexpb.EntryView{
+		1: dataplexpb.EntryView_BASIC,
+		2: dataplexpb.EntryView_FULL,
+		3: dataplexpb.EntryView_CUSTOM,
+		4: dataplexpb.EntryView_ALL,
+	}
+	req := &dataplexpb.LookupEntryRequest{
+		Name:        name,
+		View:        viewMap[view],
+		AspectTypes: aspectTypes,
+		Entry:       entry,
+	}
+	result, err := s.CatalogClient().LookupEntry(ctx, req)
+	if err != nil {
+		return nil, err
+	}
+	return result, nil
+}
+
+func (s *Source) searchRequest(ctx context.Context, query string, pageSize int, orderBy string) (*dataplexapi.SearchEntriesResultIterator, error) {
+	// Create SearchEntriesRequest with the provided parameters
+	req := &dataplexpb.SearchEntriesRequest{
+		Query:          query,
+		Name:           fmt.Sprintf("projects/%s/locations/global", s.ProjectID()),
+		PageSize:       int32(pageSize),
+		OrderBy:        orderBy,
+		SemanticSearch: true,
+	}
+
+	// Perform the search using the CatalogClient - this will return an iterator
+	it := s.CatalogClient().SearchEntries(ctx, req)
+	if it == nil {
+		return nil, fmt.Errorf("failed to create search entries iterator for project %q", s.ProjectID())
+	}
+	return it, nil
+}
+
+func (s *Source) SearchAspectTypes(ctx context.Context, query string, pageSize int, orderBy string) ([]*dataplexpb.AspectType, error) {
+	q := query + " type=projects/dataplex-types/locations/global/entryTypes/aspecttype"
+	it, err := s.searchRequest(ctx, q, pageSize, orderBy)
+	if err != nil {
+		return nil, err
+	}
+
+	// Iterate through the search results and call GetAspectType for each result using the resource name
+	var results []*dataplexpb.AspectType
+	for {
+		entry, err := it.Next()
+		if err != nil {
+			break
+		}
+
+		// Create an instance of exponential backoff with default values for retrying GetAspectType calls
+		// InitialInterval, RandomizationFactor, Multiplier, MaxInterval = 500 ms, 0.5, 1.5, 60 s
+		getAspectBackOff := backoff.NewExponentialBackOff()
+
+		resourceName := entry.DataplexEntry.GetEntrySource().Resource
+		getAspectTypeReq := &dataplexpb.GetAspectTypeRequest{
+			Name: resourceName,
+		}
+
+		operation := func() (*dataplexpb.AspectType, error) {
+			aspectType, err := s.CatalogClient().GetAspectType(ctx, getAspectTypeReq)
+			if err != nil {
+				return nil, fmt.Errorf("failed to get aspect type for entry %q: %w", resourceName, err)
+			}
+			return aspectType, nil
+		}
+
+		// Retry the GetAspectType operation with exponential backoff
+		aspectType, err := backoff.Retry(ctx, operation, backoff.WithBackOff(getAspectBackOff))
+		if err != nil {
+			return nil, fmt.Errorf("failed to get aspect type after retries for entry %q: %w", resourceName, err)
+		}
+
+		results = append(results, aspectType)
+	}
+	return results, nil
+}
+
+func (s *Source) SearchEntries(ctx context.Context, query string, pageSize int, orderBy string) ([]*dataplexpb.SearchEntriesResult, error) {
+	it, err := s.searchRequest(ctx, query, pageSize, orderBy)
+	if err != nil {
+		return nil, err
+	}
+
+	var results []*dataplexpb.SearchEntriesResult
+	for {
+		entry, err := it.Next()
+		if err != nil {
+			break
+		}
+		results = append(results, entry)
+	}
+	return results, nil
+}
@@ -16,7 +16,10 @@ package firestore
 
 import (
 	"context"
+	"encoding/base64"
 	"fmt"
+	"strings"
+	"time"
 
 	"cloud.google.com/go/firestore"
 	"github.com/goccy/go-yaml"
@@ -25,6 +28,7 @@ import (
 	"go.opentelemetry.io/otel/trace"
 	"google.golang.org/api/firebaserules/v1"
 	"google.golang.org/api/option"
+	"google.golang.org/genproto/googleapis/type/latlng"
 )
 
 const SourceKind string = "firestore"
@@ -113,6 +117,476 @@ func (s *Source) GetDatabaseId() string {
 	return s.Database
 }
 
+// FirestoreValueToJSON converts a Firestore value to a simplified JSON representation
+// This removes type information and returns plain values
+func FirestoreValueToJSON(value any) any {
+	if value == nil {
+		return nil
+	}
+
+	switch v := value.(type) {
+	case time.Time:
+		return v.Format(time.RFC3339Nano)
+	case *latlng.LatLng:
+		return map[string]any{
+			"latitude":  v.Latitude,
+			"longitude": v.Longitude,
+		}
+	case []byte:
+		return base64.StdEncoding.EncodeToString(v)
+	case []any:
+		result := make([]any, len(v))
+		for i, item := range v {
+			result[i] = FirestoreValueToJSON(item)
+		}
+		return result
+	case map[string]any:
+		result := make(map[string]any)
+		for k, val := range v {
+			result[k] = FirestoreValueToJSON(val)
+		}
+		return result
+	case *firestore.DocumentRef:
+		return v.Path
+	default:
+		return value
+	}
+}
+
+// BuildQuery constructs the Firestore query from parameters
+func (s *Source) BuildQuery(collectionPath string, filter firestore.EntityFilter, selectFields []string, field string, direction firestore.Direction, limit int, analyzeQuery bool) (*firestore.Query, error) {
+	collection := s.FirestoreClient().Collection(collectionPath)
+	query := collection.Query
+
+	// Apply the filter if one is provided
+	if filter != nil {
+		query = query.WhereEntity(filter)
+	}
+	if len(selectFields) > 0 {
+		query = query.Select(selectFields...)
+	}
+	if field != "" {
+		query = query.OrderBy(field, direction)
+	}
+	query = query.Limit(limit)
+
+	// Apply analyze options if enabled
+	if analyzeQuery {
+		query = query.WithRunOptions(firestore.ExplainOptions{
+			Analyze: true,
+		})
+	}
+
+	return &query, nil
+}
+
+// QueryResult represents a document result from the query
+type QueryResult struct {
+	ID         string         `json:"id"`
+	Path       string         `json:"path"`
+	Data       map[string]any `json:"data"`
+	CreateTime any            `json:"createTime,omitempty"`
+	UpdateTime any            `json:"updateTime,omitempty"`
+	ReadTime   any            `json:"readTime,omitempty"`
+}
+
+// QueryResponse represents the full response including optional metrics
+type QueryResponse struct {
+	Documents      []QueryResult  `json:"documents"`
+	ExplainMetrics map[string]any `json:"explainMetrics,omitempty"`
+}
+
+// ExecuteQuery runs the query and formats the results
+func (s *Source) ExecuteQuery(ctx context.Context, query *firestore.Query, analyzeQuery bool) (any, error) {
+	docIterator := query.Documents(ctx)
+	docs, err := docIterator.GetAll()
+	if err != nil {
+		return nil, fmt.Errorf("failed to execute query: %w", err)
+	}
+	// Convert results to structured format
+	results := make([]QueryResult, len(docs))
+	for i, doc := range docs {
+		results[i] = QueryResult{
+			ID:         doc.Ref.ID,
+			Path:       doc.Ref.Path,
+			Data:       doc.Data(),
+			CreateTime: doc.CreateTime,
+			UpdateTime: doc.UpdateTime,
+			ReadTime:   doc.ReadTime,
+		}
+	}
+
+	// Return with explain metrics if requested
+	if analyzeQuery {
+		explainMetrics, err := getExplainMetrics(docIterator)
+		if err == nil && explainMetrics != nil {
+			response := QueryResponse{
+				Documents:      results,
+				ExplainMetrics: explainMetrics,
+			}
+			return response, nil
+		}
+	}
+	return results, nil
+}
+
+// getExplainMetrics extracts explain metrics from the query iterator
+func getExplainMetrics(docIterator *firestore.DocumentIterator) (map[string]any, error) {
+	explainMetrics, err := docIterator.ExplainMetrics()
+	if err != nil || explainMetrics == nil {
+		return nil, err
+	}
+
+	metricsData := make(map[string]any)
+
+	// Add plan summary if available
+	if explainMetrics.PlanSummary != nil {
+		planSummary := make(map[string]any)
+		planSummary["indexesUsed"] = explainMetrics.PlanSummary.IndexesUsed
+		metricsData["planSummary"] = planSummary
+	}
+
+	// Add execution stats if available
+	if explainMetrics.ExecutionStats != nil {
+		executionStats := make(map[string]any)
+		executionStats["resultsReturned"] = explainMetrics.ExecutionStats.ResultsReturned
+		executionStats["readOperations"] = explainMetrics.ExecutionStats.ReadOperations
+
+		if explainMetrics.ExecutionStats.ExecutionDuration != nil {
+			executionStats["executionDuration"] = explainMetrics.ExecutionStats.ExecutionDuration.String()
+		}
+
+		if explainMetrics.ExecutionStats.DebugStats != nil {
+			executionStats["debugStats"] = *explainMetrics.ExecutionStats.DebugStats
+		}
+
+		metricsData["executionStats"] = executionStats
+	}
+
+	return metricsData, nil
+}
+
+func (s *Source) GetDocuments(ctx context.Context, documentPaths []string) ([]any, error) {
+	// Create document references from paths
+	docRefs := make([]*firestore.DocumentRef, len(documentPaths))
+	for i, path := range documentPaths {
+		docRefs[i] = s.FirestoreClient().Doc(path)
+	}
+
+	// Get all documents
+	snapshots, err := s.FirestoreClient().GetAll(ctx, docRefs)
+	if err != nil {
+		return nil, fmt.Errorf("failed to get documents: %w", err)
+	}
+
+	// Convert snapshots to response data
+	results := make([]any, len(snapshots))
+	for i, snapshot := range snapshots {
+		docData := make(map[string]any)
+		docData["path"] = documentPaths[i]
+		docData["exists"] = snapshot.Exists()
+
+		if snapshot.Exists() {
+			docData["data"] = snapshot.Data()
+			docData["createTime"] = snapshot.CreateTime
+			docData["updateTime"] = snapshot.UpdateTime
+			docData["readTime"] = snapshot.ReadTime
+		}
+
+		results[i] = docData
+	}
+
+	return results, nil
+}
+
+func (s *Source) AddDocuments(ctx context.Context, collectionPath string, documentData any, returnData bool) (map[string]any, error) {
+	// Get the collection reference
+	collection := s.FirestoreClient().Collection(collectionPath)
+
+	// Add the document to the collection
+	docRef, writeResult, err := collection.Add(ctx, documentData)
+	if err != nil {
+		return nil, fmt.Errorf("failed to add document: %w", err)
+	}
+	// Build the response
+	response := map[string]any{
+		"documentPath": docRef.Path,
+		"createTime":   writeResult.UpdateTime.Format("2006-01-02T15:04:05.999999999Z"),
+	}
+	// Add document data if requested
+	if returnData {
+		// Fetch the updated document to return the current state
+		snapshot, err := docRef.Get(ctx)
+		if err != nil {
+			return nil, fmt.Errorf("failed to retrieve updated document: %w", err)
+		}
+		// Convert the document data back to simple JSON format
+		simplifiedData := FirestoreValueToJSON(snapshot.Data())
+		response["documentData"] = simplifiedData
+	}
+	return response, nil
+}
+
+func (s *Source) UpdateDocument(ctx context.Context, documentPath string, updates []firestore.Update, documentData any, returnData bool) (map[string]any, error) {
+	// Get the document reference
+	docRef := s.FirestoreClient().Doc(documentPath)
+
+	// Prepare update data
+	var writeResult *firestore.WriteResult
+	var writeErr error
+
+	if len(updates) > 0 {
+		writeResult, writeErr = docRef.Update(ctx, updates)
+	} else {
+		writeResult, writeErr = docRef.Set(ctx, documentData, firestore.MergeAll)
+	}
+
+	if writeErr != nil {
+		return nil, fmt.Errorf("failed to update document: %w", writeErr)
+	}
+
+	// Build the response
+	response := map[string]any{
+		"documentPath": docRef.Path,
+		"updateTime":   writeResult.UpdateTime.Format("2006-01-02T15:04:05.999999999Z"),
+	}
+
+	// Add document data if requested
+	if returnData {
+		// Fetch the updated document to return the current state
+		snapshot, err := docRef.Get(ctx)
+		if err != nil {
+			return nil, fmt.Errorf("failed to retrieve updated document: %w", err)
+		}
+		// Convert the document data to simple JSON format
+		simplifiedData := FirestoreValueToJSON(snapshot.Data())
+		response["documentData"] = simplifiedData
+	}
+
+	return response, nil
+}
+
+func (s *Source) DeleteDocuments(ctx context.Context, documentPaths []string) ([]any, error) {
+	// Create a BulkWriter to handle multiple deletions efficiently
+	bulkWriter := s.FirestoreClient().BulkWriter(ctx)
+
+	// Keep track of jobs for each document
+	jobs := make([]*firestore.BulkWriterJob, len(documentPaths))
+
+	// Add all delete operations to the BulkWriter
+	for i, path := range documentPaths {
+		docRef := s.FirestoreClient().Doc(path)
+		job, err := bulkWriter.Delete(docRef)
+		if err != nil {
+			return nil, fmt.Errorf("failed to add delete operation for document %q: %w", path, err)
+		}
+		jobs[i] = job
+	}
+
+	// End the BulkWriter to execute all operations
+	bulkWriter.End()
+
+	// Collect results
+	results := make([]any, len(documentPaths))
+	for i, job := range jobs {
+		docData := make(map[string]any)
+		docData["path"] = documentPaths[i]
+
+		// Wait for the job to complete and get the result
+		_, err := job.Results()
+		if err != nil {
+			docData["success"] = false
+			docData["error"] = err.Error()
+		} else {
+			docData["success"] = true
+		}
+
+		results[i] = docData
+	}
+	return results, nil
+}
+
+func (s *Source) ListCollections(ctx context.Context, parentPath string) ([]any, error) {
+	var collectionRefs []*firestore.CollectionRef
+	var err error
+	if parentPath != "" {
+		// List subcollections of the specified document
+		docRef := s.FirestoreClient().Doc(parentPath)
+		collectionRefs, err = docRef.Collections(ctx).GetAll()
+		if err != nil {
+			return nil, fmt.Errorf("failed to list subcollections of document %q: %w", parentPath, err)
+		}
+	} else {
+		// List root collections
+		collectionRefs, err = s.FirestoreClient().Collections(ctx).GetAll()
+		if err != nil {
+			return nil, fmt.Errorf("failed to list root collections: %w", err)
+		}
+	}
+
+	// Convert collection references to response data
+	results := make([]any, len(collectionRefs))
+	for i, collRef := range collectionRefs {
+		collData := make(map[string]any)
+		collData["id"] = collRef.ID
+		collData["path"] = collRef.Path
+
+		// If this is a subcollection, include parent information
+		if collRef.Parent != nil {
+			collData["parent"] = collRef.Parent.Path
+		}
+		results[i] = collData
+	}
+	return results, nil
+}
+
+func (s *Source) GetRules(ctx context.Context) (any, error) {
+	// Get the latest release for Firestore
+	releaseName := fmt.Sprintf("projects/%s/releases/cloud.firestore/%s", s.GetProjectId(), s.GetDatabaseId())
+	release, err := s.FirebaseRulesClient().Projects.Releases.Get(releaseName).Context(ctx).Do()
+	if err != nil {
+		return nil, fmt.Errorf("failed to get latest Firestore release: %w", err)
+	}
+
+	if release.RulesetName == "" {
+		return nil, fmt.Errorf("no active Firestore rules were found in project '%s' and database '%s'", s.GetProjectId(), s.GetDatabaseId())
+	}
+
+	// Get the ruleset content
+	ruleset, err := s.FirebaseRulesClient().Projects.Rulesets.Get(release.RulesetName).Context(ctx).Do()
+	if err != nil {
+		return nil, fmt.Errorf("failed to get ruleset content: %w", err)
+	}
+
+	if ruleset.Source == nil || len(ruleset.Source.Files) == 0 {
+		return nil, fmt.Errorf("no rules files found in ruleset")
+	}
+
+	return ruleset, nil
+}
+
+// SourcePosition represents the location of an issue in the source
+type SourcePosition struct {
+	FileName      string `json:"fileName,omitempty"`
+	Line          int64  `json:"line"`          // 1-based
+	Column        int64  `json:"column"`        // 1-based
+	CurrentOffset int64  `json:"currentOffset"` // 0-based, inclusive start
+	EndOffset     int64  `json:"endOffset"`     // 0-based, exclusive end
+}
+
+// Issue represents a validation issue in the rules
+type Issue struct {
+	SourcePosition SourcePosition `json:"sourcePosition"`
+	Description    string         `json:"description"`
+	Severity       string         `json:"severity"`
+}
+
+// ValidationResult represents the result of rules validation
+type ValidationResult struct {
+	Valid           bool    `json:"valid"`
+	IssueCount      int     `json:"issueCount"`
+	FormattedIssues string  `json:"formattedIssues,omitempty"`
+	RawIssues       []Issue `json:"rawIssues,omitempty"`
+}
+
+func (s *Source) ValidateRules(ctx context.Context, sourceParam string) (any, error) {
+	// Create test request
+	testRequest := &firebaserules.TestRulesetRequest{
+		Source: &firebaserules.Source{
+			Files: []*firebaserules.File{
+				{
+					Name:    "firestore.rules",
+					Content: sourceParam,
+				},
+			},
+		},
+		// We don't need test cases for validation only
+		TestSuite: &firebaserules.TestSuite{
+			TestCases: []*firebaserules.TestCase{},
+		},
+	}
+	// Call the test API
+	projectName := fmt.Sprintf("projects/%s", s.GetProjectId())
+	response, err := s.FirebaseRulesClient().Projects.Test(projectName, testRequest).Context(ctx).Do()
+	if err != nil {
+		return nil, fmt.Errorf("failed to validate rules: %w", err)
+	}
+
+	// Process the response
+	if len(response.Issues) == 0 {
+		return ValidationResult{
+			Valid:           true,
+			IssueCount:      0,
+			FormattedIssues: "✓ No errors detected. Rules are valid.",
+		}, nil
+	}
+
+	// Convert issues to our format
+	issues := make([]Issue, len(response.Issues))
+	for i, issue := range response.Issues {
+		issues[i] = Issue{
+			Description: issue.Description,
+			Severity:    issue.Severity,
+			SourcePosition: SourcePosition{
+				FileName:      issue.SourcePosition.FileName,
+				Line:          issue.SourcePosition.Line,
+				Column:        issue.SourcePosition.Column,
+				CurrentOffset: issue.SourcePosition.CurrentOffset,
+				EndOffset:     issue.SourcePosition.EndOffset,
+			},
+		}
+	}
+
+	// Format issues
+	sourceLines := strings.Split(sourceParam, "\n")
+	var formattedOutput []string
+
+	formattedOutput = append(formattedOutput, fmt.Sprintf("Found %d issue(s) in rules source:\n", len(issues)))
+
+	for _, issue := range issues {
+		issueString := fmt.Sprintf("%s: %s [Ln %d, Col %d]",
+			issue.Severity,
+			issue.Description,
+			issue.SourcePosition.Line,
+			issue.SourcePosition.Column)
+
+		if issue.SourcePosition.Line > 0 {
+			lineIndex := int(issue.SourcePosition.Line - 1) // 0-based index
+			if lineIndex >= 0 && lineIndex < len(sourceLines) {
+				errorLine := sourceLines[lineIndex]
+				issueString += fmt.Sprintf("\n```\n%s", errorLine)
+
+				// Add carets if we have column and offset information
+				if issue.SourcePosition.Column > 0 &&
+					issue.SourcePosition.CurrentOffset >= 0 &&
+					issue.SourcePosition.EndOffset > issue.SourcePosition.CurrentOffset {
+
+					startColumn := int(issue.SourcePosition.Column - 1) // 0-based
+					errorTokenLength := int(issue.SourcePosition.EndOffset - issue.SourcePosition.CurrentOffset)
+
+					if startColumn >= 0 && errorTokenLength > 0 && startColumn <= len(errorLine) {
+						padding := strings.Repeat(" ", startColumn)
+						carets := strings.Repeat("^", errorTokenLength)
+						issueString += fmt.Sprintf("\n%s%s", padding, carets)
+					}
+				}
+				issueString += "\n```"
+			}
+		}

+		formattedOutput = append(formattedOutput, issueString)
+	}
+
+	formattedIssues := strings.Join(formattedOutput, "\n\n")
+
+	return ValidationResult{
+		Valid:           false,
+		IssueCount:      len(issues),
+		FormattedIssues: formattedIssues,
+		RawIssues:       issues,
+	}, nil
+}
+
 func initFirestoreConnection(
 	ctx context.Context,
 	tracer trace.Tracer,
@@ -16,6 +16,7 @@ package firestore_test
 
 import (
 	"testing"
+	"time"
 
 	yaml "github.com/goccy/go-yaml"
 	"github.com/google/go-cmp/cmp"
@@ -128,3 +129,37 @@ func TestFailParseFromYamlFirestore(t *testing.T) {
 		})
 	}
 }
+
+func TestFirestoreValueToJSON_RoundTrip(t *testing.T) {
+	// Test round-trip conversion
+	original := map[string]any{
+		"name":   "Test",
+		"count":  int64(42),
+		"price":  19.99,
+		"active": true,
+		"tags":   []any{"tag1", "tag2"},
+		"metadata": map[string]any{
+			"created": time.Now(),
+		},
+		"nullField": nil,
+	}
+
+	// Convert to JSON representation
+	jsonRepresentation := firestore.FirestoreValueToJSON(original)
+
+	// Verify types are simplified
+	jsonMap, ok := jsonRepresentation.(map[string]any)
+	if !ok {
+		t.Fatalf("Expected map, got %T", jsonRepresentation)
+	}
+
+	// Time should be converted to string
+	metadata, ok := jsonMap["metadata"].(map[string]any)
+	if !ok {
+		t.Fatalf("metadata should be a map, got %T", jsonMap["metadata"])
+	}
+	_, ok = metadata["created"].(string)
+	if !ok {
+		t.Errorf("created should be a string, got %T", metadata["created"])
+	}
+}
@@ -16,7 +16,9 @@ package http
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"time"
|
"time"
|
||||||
@@ -143,3 +145,28 @@ func (s *Source) HttpQueryParams() map[string]string {
|
|||||||
func (s *Source) Client() *http.Client {
|
func (s *Source) Client() *http.Client {
|
||||||
return s.client
|
return s.client
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Source) RunRequest(req *http.Request) (any, error) {
|
||||||
|
// Make request and fetch response
|
||||||
|
resp, err := s.Client().Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error making HTTP request: %s", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
var body []byte
|
||||||
|
body, err = io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if resp.StatusCode < 200 || resp.StatusCode > 299 {
|
||||||
|
return nil, fmt.Errorf("unexpected status code: %d, response body: %s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
var data any
|
||||||
|
if err = json.Unmarshal(body, &data); err != nil {
|
||||||
|
// if unable to unmarshal data, return result as string.
|
||||||
|
return string(body), nil
|
||||||
|
}
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
@@ -15,7 +15,9 @@ package looker

 import (
 	"context"
+	"crypto/tls"
 	"fmt"
+	"net/http"
 	"strings"
 	"time"
@@ -208,6 +210,49 @@ func (s *Source) LookerSessionLength() int64 {
 	return s.SessionLength
 }
+
+// Make types for RoundTripper
+type transportWithAuthHeader struct {
+	Base      http.RoundTripper
+	AuthToken string
+}
+
+func (t *transportWithAuthHeader) RoundTrip(req *http.Request) (*http.Response, error) {
+	req.Header.Set("x-looker-appid", "go-sdk")
+	req.Header.Set("Authorization", t.AuthToken)
+	return t.Base.RoundTrip(req)
+}
+
+func (s *Source) GetLookerSDK(accessToken string) (*v4.LookerSDK, error) {
+	if s.UseClientAuthorization() {
+		if accessToken == "" {
+			return nil, fmt.Errorf("no access token supplied with request")
+		}
+
+		session := rtl.NewAuthSession(*s.LookerApiSettings())
+		// Configure base transport with TLS
+		transport := &http.Transport{
+			TLSClientConfig: &tls.Config{
+				InsecureSkipVerify: !s.LookerApiSettings().VerifySsl,
+			},
+		}
+
+		// Build transport for end user token
+		session.Client = http.Client{
+			Transport: &transportWithAuthHeader{
+				Base:      transport,
+				AuthToken: accessToken,
+			},
+		}
+		// return SDK with new Transport
+		return v4.NewLookerSDK(session), nil
+	}
+
+	if s.LookerClient() == nil {
+		return nil, fmt.Errorf("client id or client secret not valid")
+	}
+	return s.LookerClient(), nil
+}
+
 func initGoogleCloudConnection(ctx context.Context) (oauth2.TokenSource, error) {
 	cred, err := google.FindDefaultCredentials(ctx, geminidataanalytics.DefaultAuthScopes()...)
 	if err != nil {
@@ -16,11 +16,14 @@ package mongodb

 import (
 	"context"
+	"encoding/json"
+	"errors"
 	"fmt"

 	"github.com/goccy/go-yaml"
 	"github.com/googleapis/genai-toolbox/internal/sources"
 	"github.com/googleapis/genai-toolbox/internal/util"
+	"go.mongodb.org/mongo-driver/bson"
 	"go.mongodb.org/mongo-driver/mongo"
 	"go.mongodb.org/mongo-driver/mongo/options"
 	"go.opentelemetry.io/otel/trace"
@@ -93,6 +96,201 @@ func (s *Source) MongoClient() *mongo.Client {
 	return s.Client
 }
+
+func parseData(ctx context.Context, cur *mongo.Cursor) ([]any, error) {
+	var data = []any{}
+	err := cur.All(ctx, &data)
+	if err != nil {
+		return nil, err
+	}
+	var final []any
+	for _, item := range data {
+		tmp, _ := bson.MarshalExtJSON(item, false, false)
+		var tmp2 any
+		err = json.Unmarshal(tmp, &tmp2)
+		if err != nil {
+			return nil, err
+		}
+		final = append(final, tmp2)
+	}
+	return final, err
+}
+
+func (s *Source) Aggregate(ctx context.Context, pipelineString string, canonical, readOnly bool, database, collection string) ([]any, error) {
+	var pipeline = []bson.M{}
+	err := bson.UnmarshalExtJSON([]byte(pipelineString), canonical, &pipeline)
+	if err != nil {
+		return nil, err
+	}
+
+	if readOnly {
+		// fail if we do a merge or an out
+		for _, stage := range pipeline {
+			for key := range stage {
+				if key == "$merge" || key == "$out" {
+					return nil, fmt.Errorf("this is not a read-only pipeline: %+v", stage)
+				}
+			}
+		}
+	}
+
+	cur, err := s.MongoClient().Database(database).Collection(collection).Aggregate(ctx, pipeline)
+	if err != nil {
+		return nil, err
+	}
+	defer cur.Close(ctx)
+	res, err := parseData(ctx, cur)
+	if err != nil {
+		return nil, err
+	}
+	if res == nil {
+		return []any{}, nil
+	}
+	return res, err
+}
+
+func (s *Source) Find(ctx context.Context, filterString, database, collection string, opts *options.FindOptions) ([]any, error) {
+	var filter = bson.D{}
+	err := bson.UnmarshalExtJSON([]byte(filterString), false, &filter)
+	if err != nil {
+		return nil, err
+	}
+
+	cur, err := s.MongoClient().Database(database).Collection(collection).Find(ctx, filter, opts)
+	if err != nil {
+		return nil, err
+	}
+	defer cur.Close(ctx)
+	return parseData(ctx, cur)
+}
+
+func (s *Source) FindOne(ctx context.Context, filterString, database, collection string, opts *options.FindOneOptions) ([]any, error) {
+	var filter = bson.D{}
+	err := bson.UnmarshalExtJSON([]byte(filterString), false, &filter)
+	if err != nil {
+		return nil, err
+	}
+
+	res := s.MongoClient().Database(database).Collection(collection).FindOne(ctx, filter, opts)
+	if res.Err() != nil {
+		return nil, res.Err()
+	}
+
+	var data any
+	err = res.Decode(&data)
+	if err != nil {
+		return nil, err
+	}
+
+	var final []any
+	tmp, _ := bson.MarshalExtJSON(data, false, false)
+	var tmp2 any
+	err = json.Unmarshal(tmp, &tmp2)
+	if err != nil {
+		return nil, err
+	}
+	final = append(final, tmp2)
+
+	return final, err
+}
+
+func (s *Source) InsertMany(ctx context.Context, jsonData string, canonical bool, database, collection string) ([]any, error) {
+	var data = []any{}
+	err := bson.UnmarshalExtJSON([]byte(jsonData), canonical, &data)
+	if err != nil {
+		return nil, err
+	}
+
+	res, err := s.MongoClient().Database(database).Collection(collection).InsertMany(ctx, data, options.InsertMany())
+	if err != nil {
+		return nil, err
+	}
+	return res.InsertedIDs, nil
+}
+
+func (s *Source) InsertOne(ctx context.Context, jsonData string, canonical bool, database, collection string) (any, error) {
+	var data any
+	err := bson.UnmarshalExtJSON([]byte(jsonData), canonical, &data)
+	if err != nil {
+		return nil, err
+	}
+
+	res, err := s.MongoClient().Database(database).Collection(collection).InsertOne(ctx, data, options.InsertOne())
+	if err != nil {
+		return nil, err
+	}
+	return res.InsertedID, nil
+}
+
+func (s *Source) UpdateMany(ctx context.Context, filterString string, canonical bool, updateString, database, collection string, upsert bool) ([]any, error) {
+	var filter = bson.D{}
+	err := bson.UnmarshalExtJSON([]byte(filterString), canonical, &filter)
+	if err != nil {
+		return nil, fmt.Errorf("unable to unmarshal filter string: %w", err)
+	}
+	var update = bson.D{}
+	err = bson.UnmarshalExtJSON([]byte(updateString), false, &update)
+	if err != nil {
+		return nil, fmt.Errorf("unable to unmarshal update string: %w", err)
+	}
+
+	res, err := s.MongoClient().Database(database).Collection(collection).UpdateMany(ctx, filter, update, options.Update().SetUpsert(upsert))
+	if err != nil {
+		return nil, fmt.Errorf("error updating collection: %w", err)
+	}
+	return []any{res.ModifiedCount, res.UpsertedCount, res.MatchedCount}, nil
+}
+
+func (s *Source) UpdateOne(ctx context.Context, filterString string, canonical bool, updateString, database, collection string, upsert bool) (any, error) {
+	var filter = bson.D{}
+	err := bson.UnmarshalExtJSON([]byte(filterString), false, &filter)
+	if err != nil {
+		return nil, fmt.Errorf("unable to unmarshal filter string: %w", err)
+	}
+	var update = bson.D{}
+	err = bson.UnmarshalExtJSON([]byte(updateString), canonical, &update)
+	if err != nil {
+		return nil, fmt.Errorf("unable to unmarshal update string: %w", err)
+	}
+
+	res, err := s.MongoClient().Database(database).Collection(collection).UpdateOne(ctx, filter, update, options.Update().SetUpsert(upsert))
+	if err != nil {
+		return nil, fmt.Errorf("error updating collection: %w", err)
+	}
+	return res.ModifiedCount, nil
+}
+
+func (s *Source) DeleteMany(ctx context.Context, filterString, database, collection string) (any, error) {
+	var filter = bson.D{}
+	err := bson.UnmarshalExtJSON([]byte(filterString), false, &filter)
+	if err != nil {
+		return nil, err
+	}
+
+	res, err := s.MongoClient().Database(database).Collection(collection).DeleteMany(ctx, filter, options.Delete())
+	if err != nil {
+		return nil, err
+	}
+
+	if res.DeletedCount == 0 {
+		return nil, errors.New("no document found")
+	}
+	return res.DeletedCount, nil
+}
+
+func (s *Source) DeleteOne(ctx context.Context, filterString, database, collection string) (any, error) {
+	var filter = bson.D{}
+	err := bson.UnmarshalExtJSON([]byte(filterString), false, &filter)
+	if err != nil {
+		return nil, err
+	}
+
+	res, err := s.MongoClient().Database(database).Collection(collection).DeleteOne(ctx, filter, options.Delete())
+	if err != nil {
+		return nil, err
+	}
+	return res.DeletedCount, nil
+}
+
 func initMongoDBClient(ctx context.Context, tracer trace.Tracer, name, uri string) (*mongo.Client, error) {
 	// Start a tracing span
 	ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
@@ -16,15 +16,21 @@ package serverlessspark

 import (
 	"context"
+	"encoding/json"
 	"fmt"
+	"time"

 	dataproc "cloud.google.com/go/dataproc/v2/apiv1"
+	"cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
 	longrunning "cloud.google.com/go/longrunning/autogen"
+	"cloud.google.com/go/longrunning/autogen/longrunningpb"
 	"github.com/goccy/go-yaml"
 	"github.com/googleapis/genai-toolbox/internal/sources"
 	"github.com/googleapis/genai-toolbox/internal/util"
 	"go.opentelemetry.io/otel/trace"
+	"google.golang.org/api/iterator"
 	"google.golang.org/api/option"
+	"google.golang.org/protobuf/encoding/protojson"
 )

 const SourceKind string = "serverless-spark"
@@ -121,3 +127,168 @@ func (s *Source) Close() error {
 	}
 	return nil
 }
+
+func (s *Source) CancelOperation(ctx context.Context, operation string) (any, error) {
+	req := &longrunningpb.CancelOperationRequest{
+		Name: fmt.Sprintf("projects/%s/locations/%s/operations/%s", s.GetProject(), s.GetLocation(), operation),
+	}
+	client, err := s.GetOperationsClient(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("failed to get operations client: %w", err)
+	}
+	err = client.CancelOperation(ctx, req)
+	if err != nil {
+		return nil, fmt.Errorf("failed to cancel operation: %w", err)
+	}
+	return fmt.Sprintf("Cancelled [%s].", operation), nil
+}
+
+func (s *Source) CreateBatch(ctx context.Context, batch *dataprocpb.Batch) (map[string]any, error) {
+	req := &dataprocpb.CreateBatchRequest{
+		Parent: fmt.Sprintf("projects/%s/locations/%s", s.GetProject(), s.GetLocation()),
+		Batch:  batch,
+	}
+
+	client := s.GetBatchControllerClient()
+	op, err := client.CreateBatch(ctx, req)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create batch: %w", err)
+	}
+	meta, err := op.Metadata()
+	if err != nil {
+		return nil, fmt.Errorf("failed to get create batch op metadata: %w", err)
+	}
+
+	projectID, location, batchID, err := ExtractBatchDetails(meta.GetBatch())
+	if err != nil {
+		return nil, fmt.Errorf("error extracting batch details from name %q: %v", meta.GetBatch(), err)
+	}
+	consoleUrl := BatchConsoleURL(projectID, location, batchID)
+	logsUrl := BatchLogsURL(projectID, location, batchID, meta.GetCreateTime().AsTime(), time.Time{})
+
+	wrappedResult := map[string]any{
+		"opMetadata": meta,
+		"consoleUrl": consoleUrl,
+		"logsUrl":    logsUrl,
+	}
+	return wrappedResult, nil
+}
+
+// ListBatchesResponse is the response from the list batches API.
+type ListBatchesResponse struct {
+	Batches       []Batch `json:"batches"`
+	NextPageToken string  `json:"nextPageToken"`
+}
+
+// Batch represents a single batch job.
+type Batch struct {
+	Name       string `json:"name"`
+	UUID       string `json:"uuid"`
+	State      string `json:"state"`
+	Creator    string `json:"creator"`
+	CreateTime string `json:"createTime"`
+	Operation  string `json:"operation"`
+	ConsoleURL string `json:"consoleUrl"`
+	LogsURL    string `json:"logsUrl"`
+}
+
+func (s *Source) ListBatches(ctx context.Context, ps *int, pt, filter string) (any, error) {
+	client := s.GetBatchControllerClient()
+	parent := fmt.Sprintf("projects/%s/locations/%s", s.GetProject(), s.GetLocation())
+	req := &dataprocpb.ListBatchesRequest{
+		Parent:  parent,
+		OrderBy: "create_time desc",
+	}
+
+	if ps != nil {
+		req.PageSize = int32(*ps)
+	}
+	if pt != "" {
+		req.PageToken = pt
+	}
+	if filter != "" {
+		req.Filter = filter
+	}
+
+	it := client.ListBatches(ctx, req)
+	pager := iterator.NewPager(it, int(req.PageSize), req.PageToken)
+
+	var batchPbs []*dataprocpb.Batch
+	nextPageToken, err := pager.NextPage(&batchPbs)
+	if err != nil {
+		return nil, fmt.Errorf("failed to list batches: %w", err)
+	}
+
+	batches, err := ToBatches(batchPbs)
+	if err != nil {
+		return nil, err
+	}
+
+	return ListBatchesResponse{Batches: batches, NextPageToken: nextPageToken}, nil
+}
+
+// ToBatches converts a slice of protobuf Batch messages to a slice of Batch structs.
+func ToBatches(batchPbs []*dataprocpb.Batch) ([]Batch, error) {
+	batches := make([]Batch, 0, len(batchPbs))
+	for _, batchPb := range batchPbs {
+		consoleUrl, err := BatchConsoleURLFromProto(batchPb)
+		if err != nil {
+			return nil, fmt.Errorf("error generating console url: %v", err)
+		}
+		logsUrl, err := BatchLogsURLFromProto(batchPb)
+		if err != nil {
+			return nil, fmt.Errorf("error generating logs url: %v", err)
+		}
+		batch := Batch{
+			Name:       batchPb.Name,
+			UUID:       batchPb.Uuid,
+			State:      batchPb.State.Enum().String(),
+			Creator:    batchPb.Creator,
+			CreateTime: batchPb.CreateTime.AsTime().Format(time.RFC3339),
+			Operation:  batchPb.Operation,
+			ConsoleURL: consoleUrl,
+			LogsURL:    logsUrl,
+		}
+		batches = append(batches, batch)
+	}
+	return batches, nil
+}
+
+func (s *Source) GetBatch(ctx context.Context, name string) (map[string]any, error) {
+	client := s.GetBatchControllerClient()
+	req := &dataprocpb.GetBatchRequest{
+		Name: fmt.Sprintf("projects/%s/locations/%s/batches/%s", s.GetProject(), s.GetLocation(), name),
+	}
+
+	batchPb, err := client.GetBatch(ctx, req)
+	if err != nil {
+		return nil, fmt.Errorf("failed to get batch: %w", err)
+	}
+
+	jsonBytes, err := protojson.Marshal(batchPb)
+	if err != nil {
+		return nil, fmt.Errorf("failed to marshal batch to JSON: %w", err)
+	}
+
+	var result map[string]any
+	if err := json.Unmarshal(jsonBytes, &result); err != nil {
+		return nil, fmt.Errorf("failed to unmarshal batch JSON: %w", err)
+	}
+
+	consoleUrl, err := BatchConsoleURLFromProto(batchPb)
+	if err != nil {
+		return nil, fmt.Errorf("error generating console url: %v", err)
+	}
+	logsUrl, err := BatchLogsURLFromProto(batchPb)
+	if err != nil {
+		return nil, fmt.Errorf("error generating logs url: %v", err)
+	}
+
+	wrappedResult := map[string]any{
+		"consoleUrl": consoleUrl,
+		"logsUrl":    logsUrl,
+		"batch":      result,
+	}

+	return wrappedResult, nil
+}
@@ -1,10 +1,10 @@
-// Copyright 2025 Google LLC
+// Copyright 2026 Google LLC
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
 // You may obtain a copy of the License at
 //
 //     http://www.apache.org/licenses/LICENSE-2.0
 //
 // Unless required by applicable law or agreed to in writing, software
 // distributed under the License is distributed on an "AS IS" BASIS,
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-package common
+package serverlessspark

 import (
 	"fmt"
@@ -23,13 +23,13 @@ import (
 	"cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
 )

+var batchFullNameRegex = regexp.MustCompile(`projects/(?P<project>[^/]+)/locations/(?P<location>[^/]+)/batches/(?P<batch_id>[^/]+)`)
+
 const (
 	logTimeBufferBefore = 1 * time.Minute
 	logTimeBufferAfter  = 10 * time.Minute
 )

-var batchFullNameRegex = regexp.MustCompile(`projects/(?P<project>[^/]+)/locations/(?P<location>[^/]+)/batches/(?P<batch_id>[^/]+)`)
-
 // ExtractBatchDetails extracts the project ID, location, and batch ID from a fully qualified batch name.
 func ExtractBatchDetails(batchName string) (projectID, location, batchID string, err error) {
 	matches := batchFullNameRegex.FindStringSubmatch(batchName)
@@ -39,26 +39,6 @@ func ExtractBatchDetails(batchName string) (projectID, location, batchID string,
 	return matches[1], matches[2], matches[3], nil
 }
-
-// BatchConsoleURLFromProto builds a URL to the Google Cloud Console linking to the batch summary page.
-func BatchConsoleURLFromProto(batchPb *dataprocpb.Batch) (string, error) {
-	projectID, location, batchID, err := ExtractBatchDetails(batchPb.GetName())
-	if err != nil {
-		return "", err
-	}
-	return BatchConsoleURL(projectID, location, batchID), nil
-}
-
-// BatchLogsURLFromProto builds a URL to the Google Cloud Console showing Cloud Logging for the given batch and time range.
-func BatchLogsURLFromProto(batchPb *dataprocpb.Batch) (string, error) {
-	projectID, location, batchID, err := ExtractBatchDetails(batchPb.GetName())
-	if err != nil {
-		return "", err
-	}
-	createTime := batchPb.GetCreateTime().AsTime()
-	stateTime := batchPb.GetStateTime().AsTime()
-	return BatchLogsURL(projectID, location, batchID, createTime, stateTime), nil
-}
-
 // BatchConsoleURL builds a URL to the Google Cloud Console linking to the batch summary page.
 func BatchConsoleURL(projectID, location, batchID string) string {
 	return fmt.Sprintf("https://console.cloud.google.com/dataproc/batches/%s/%s/summary?project=%s", location, batchID, projectID)
@@ -89,3 +69,23 @@ resource.labels.batch_id="%s"`
 	return "https://console.cloud.google.com/logs/viewer?" + v.Encode()
 }
+
+// BatchConsoleURLFromProto builds a URL to the Google Cloud Console linking to the batch summary page.
+func BatchConsoleURLFromProto(batchPb *dataprocpb.Batch) (string, error) {
+	projectID, location, batchID, err := ExtractBatchDetails(batchPb.GetName())
+	if err != nil {
+		return "", err
+	}
+	return BatchConsoleURL(projectID, location, batchID), nil
+}
+
+// BatchLogsURLFromProto builds a URL to the Google Cloud Console showing Cloud Logging for the given batch and time range.
+func BatchLogsURLFromProto(batchPb *dataprocpb.Batch) (string, error) {
+	projectID, location, batchID, err := ExtractBatchDetails(batchPb.GetName())
+	if err != nil {
+		return "", err
+	}
+	createTime := batchPb.GetCreateTime().AsTime()
+	stateTime := batchPb.GetStateTime().AsTime()
+	return BatchLogsURL(projectID, location, batchID, createTime, stateTime), nil
+}
@@ -1,10 +1,10 @@
-// Copyright 2025 Google LLC
+// Copyright 2026 Google LLC
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
 // You may obtain a copy of the License at
 //
 //     http://www.apache.org/licenses/LICENSE-2.0
 //
 // Unless required by applicable law or agreed to in writing, software
 // distributed under the License is distributed on an "AS IS" BASIS,
@@ -12,19 +12,20 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-package common
+package serverlessspark_test

 import (
 	"testing"
 	"time"

 	"cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
+	"github.com/googleapis/genai-toolbox/internal/sources/serverlessspark"
 	"google.golang.org/protobuf/types/known/timestamppb"
 )

 func TestExtractBatchDetails_Success(t *testing.T) {
 	batchName := "projects/my-project/locations/us-central1/batches/my-batch"
-	projectID, location, batchID, err := ExtractBatchDetails(batchName)
+	projectID, location, batchID, err := serverlessspark.ExtractBatchDetails(batchName)
 	if err != nil {
 		t.Errorf("ExtractBatchDetails() error = %v, want no error", err)
 		return
@@ -45,7 +46,7 @@ func TestExtractBatchDetails_Success(t *testing.T) {

 func TestExtractBatchDetails_Failure(t *testing.T) {
 	batchName := "invalid-name"
-	_, _, _, err := ExtractBatchDetails(batchName)
+	_, _, _, err := serverlessspark.ExtractBatchDetails(batchName)
 	wantErr := "failed to parse batch name: invalid-name"
 	if err == nil || err.Error() != wantErr {
 		t.Errorf("ExtractBatchDetails() error = %v, want %v", err, wantErr)
@@ -53,7 +54,7 @@ func TestExtractBatchDetails_Failure(t *testing.T) {
 }

 func TestBatchConsoleURL(t *testing.T) {
-	got := BatchConsoleURL("my-project", "us-central1", "my-batch")
+	got := serverlessspark.BatchConsoleURL("my-project", "us-central1", "my-batch")
 	want := "https://console.cloud.google.com/dataproc/batches/us-central1/my-batch/summary?project=my-project"
 	if got != want {
 		t.Errorf("BatchConsoleURL() = %v, want %v", got, want)
@@ -63,7 +64,7 @@ func TestBatchConsoleURL(t *testing.T) {
 func TestBatchLogsURL(t *testing.T) {
 	startTime := time.Date(2025, 10, 1, 5, 0, 0, 0, time.UTC)
 	endTime := time.Date(2025, 10, 1, 6, 0, 0, 0, time.UTC)
-	got := BatchLogsURL("my-project", "us-central1", "my-batch", startTime, endTime)
+	got := serverlessspark.BatchLogsURL("my-project", "us-central1", "my-batch", startTime, endTime)
 	want := "https://console.cloud.google.com/logs/viewer?advancedFilter=" +
 		"resource.type%3D%22cloud_dataproc_batch%22" +
 		"%0Aresource.labels.project_id%3D%22my-project%22" +
@@ -82,7 +83,7 @@ func TestBatchConsoleURLFromProto(t *testing.T) {
 	batchPb := &dataprocpb.Batch{
 		Name: "projects/my-project/locations/us-central1/batches/my-batch",
 	}
-	got, err := BatchConsoleURLFromProto(batchPb)
+	got, err := serverlessspark.BatchConsoleURLFromProto(batchPb)
 	if err != nil {
 		t.Fatalf("BatchConsoleURLFromProto() error = %v", err)
 	}
@@ -100,7 +101,7 @@ func TestBatchLogsURLFromProto(t *testing.T) {
 		CreateTime: timestamppb.New(createTime),
 		StateTime:  timestamppb.New(stateTime),
 	}
-	got, err := BatchLogsURLFromProto(batchPb)
+	got, err := serverlessspark.BatchLogsURLFromProto(batchPb)
 	if err != nil {
 		t.Fatalf("BatchLogsURLFromProto() error = %v", err)
 	}
@@ -28,6 +28,21 @@ import (

 const kind string = "cloud-gemini-data-analytics-query"

+// Guidance is the tool guidance string.
+const Guidance = `Tool guidance:
+Inputs:
+1. query: A natural language formulation of a database query.
+Outputs: (all optional)
+1. disambiguation_question: Clarification questions or comments where the tool needs the users' input.
+2. generated_query: The generated query for the user query.
+3. intent_explanation: An explanation for why the tool produced ` + "`generated_query`" + `.
+4. query_result: The result of executing ` + "`generated_query`" + `.
+5. natural_language_answer: The natural language answer that summarizes the ` + "`query`" + ` and ` + "`query_result`" + `.
+
+Usage guidance:
+1. If ` + "`disambiguation_question`" + ` is produced, then solicit the needed inputs from the user and try the tool with a new ` + "`query`" + ` that has the needed clarification.
+2. If ` + "`natural_language_answer`" + ` is produced, use ` + "`intent_explanation`" + ` and ` + "`generated_query`" + ` to see if you need to clarify any assumptions for the user.`
+
 func init() {
 	if !tools.Register(kind, newConfig) {
 		panic(fmt.Sprintf("tool kind %q already registered", kind))
@@ -68,11 +83,18 @@ func (cfg Config) ToolConfigKind() string {

 func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
 	// Define the parameters for the Gemini Data Analytics Query API
-	// The prompt is the only input parameter.
+	// The query is the only input parameter.
 	allParameters := parameters.Parameters{
-		parameters.NewStringParameterWithRequired("prompt", "The natural language question to ask.", true),
+		parameters.NewStringParameterWithRequired("query", "A natural language formulation of a database query.", true),
 	}
+	// The input and outputs are for tool guidance, usage guidance is for multi-turn interaction.
+	guidance := Guidance
+
+	if cfg.Description != "" {
+		cfg.Description += "\n\n" + guidance
+	} else {
+		cfg.Description = guidance
+	}
 	mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil)

 	return Tool{
@@ -105,9 +127,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 	}

 	paramsMap := params.AsMap()
-	prompt, ok := paramsMap["prompt"].(string)
+	query, ok := paramsMap["query"].(string)
 	if !ok {
-		return nil, fmt.Errorf("prompt parameter not found or not a string")
+		return nil, fmt.Errorf("query parameter not found or not a string")
 	}

 	// Parse the access token if provided
@@ -125,7 +147,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para

 	payload := &QueryDataRequest{
 		Parent: payloadParent,
-		Prompt: prompt,
+		Prompt: query,
 		Context: t.Context,
 		GenerationOptions: t.GenerationOptions,
 	}
@@ -328,9 +328,9 @@ func TestInvoke(t *testing.T) {
 		t.Fatalf("failed to initialize tool: %v", err)
 	}

-	// Prepare parameters for invocation - ONLY prompt
+	// Prepare parameters for invocation - ONLY query
 	params := parameters.ParamValues{
-		{Name: "prompt", Value: "How many accounts who have region in Prague are eligible for loans?"},
+		{Name: "query", Value: "How many accounts who have region in Prague are eligible for loans?"},
 	}

 	resourceMgr := resources.NewResourceManager(srcs, nil, nil, nil, nil, nil, nil)
@@ -16,22 +16,13 @@ package fhirfetchpage

 import (
 	"context"
-	"encoding/json"
 	"fmt"
-	"io"

 	"github.com/goccy/go-yaml"
 	"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
 	"github.com/googleapis/genai-toolbox/internal/sources"
-	healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
 	"github.com/googleapis/genai-toolbox/internal/tools"
 	"github.com/googleapis/genai-toolbox/internal/util/parameters"
-	"google.golang.org/api/healthcare/v1"
-
-	"net/http"
-
-	"golang.org/x/oauth2"
-	"golang.org/x/oauth2/google"
 )

 const kind string = "cloud-healthcare-fhir-fetch-page"
@@ -54,13 +45,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
 }

 type compatibleSource interface {
-	Project() string
-	Region() string
-	DatasetID() string
-	AllowedFHIRStores() map[string]struct{}
-	Service() *healthcare.Service
-	ServiceCreator() healthcareds.HealthcareServiceCreator
 	UseClientAuthorization() bool
+	FHIRFetchPage(context.Context, string, string) (any, error)
 }

 type Config struct {
@@ -118,48 +104,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 		return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", pageURLKey)
 	}

-	var httpClient *http.Client
+	tokenStr, err := accessToken.ParseBearerToken()
-	if source.UseClientAuthorization() {
-		tokenStr, err := accessToken.ParseBearerToken()
-		if err != nil {
-			return nil, fmt.Errorf("error parsing access token: %w", err)
-		}
-		ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: tokenStr})
-		httpClient = oauth2.NewClient(ctx, ts)
-	} else {
-		// The source.Service() object holds a client with the default credentials.
-		// However, the client is not exported, so we have to create a new one.
-		var err error
-		httpClient, err = google.DefaultClient(ctx, healthcare.CloudHealthcareScope)
-		if err != nil {
-			return nil, fmt.Errorf("failed to create default http client: %w", err)
-		}
-	}
-
-	req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
 	if err != nil {
-		return nil, fmt.Errorf("failed to create http request: %w", err)
+		return nil, fmt.Errorf("error parsing access token: %w", err)
 	}
-	req.Header.Set("Accept", "application/fhir+json;charset=utf-8")
+	return source.FHIRFetchPage(ctx, url, tokenStr)
-
-	resp, err := httpClient.Do(req)
-	if err != nil {
-		return nil, fmt.Errorf("failed to get fhir page from %q: %w", url, err)
-	}
-	defer resp.Body.Close()
-
-	respBytes, err := io.ReadAll(resp.Body)
-	if err != nil {
-		return nil, fmt.Errorf("could not read response: %w", err)
-	}
-	if resp.StatusCode > 299 {
-		return nil, fmt.Errorf("read: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
-	}
-	var jsonMap map[string]interface{}
-	if err := json.Unmarshal([]byte(string(respBytes)), &jsonMap); err != nil {
-		return nil, fmt.Errorf("could not unmarshal response as json: %w", err)
-	}
-	return jsonMap, nil
 }

 func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
@@ -16,20 +16,16 @@ package fhirpatienteverything

 import (
 	"context"
-	"encoding/json"
 	"fmt"
-	"io"
 	"strings"

 	"github.com/goccy/go-yaml"
 	"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
 	"github.com/googleapis/genai-toolbox/internal/sources"
-	healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
 	"github.com/googleapis/genai-toolbox/internal/tools"
 	"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
 	"github.com/googleapis/genai-toolbox/internal/util/parameters"
 	"google.golang.org/api/googleapi"
-	"google.golang.org/api/healthcare/v1"
 )

 const kind string = "cloud-healthcare-fhir-patient-everything"
@@ -54,13 +50,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
 }

 type compatibleSource interface {
-	Project() string
-	Region() string
-	DatasetID() string
 	AllowedFHIRStores() map[string]struct{}
-	Service() *healthcare.Service
-	ServiceCreator() healthcareds.HealthcareServiceCreator
 	UseClientAuthorization() bool
+	FHIRPatientEverything(string, string, string, []googleapi.CallOption) (any, error)
 }

 type Config struct {
@@ -139,20 +131,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 		return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", patientIDKey)
 	}

-	svc := source.Service()
-	// Initialize new service if using user OAuth token
-	if source.UseClientAuthorization() {
-		tokenStr, err := accessToken.ParseBearerToken()
-		if err != nil {
-			return nil, fmt.Errorf("error parsing access token: %w", err)
-		}
-		svc, err = source.ServiceCreator()(tokenStr)
-		if err != nil {
-			return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
-		}
+	tokenStr, err := accessToken.ParseBearerToken()
+	if err != nil {
+		return nil, fmt.Errorf("error parsing access token: %w", err)
 	}

-	name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/Patient/%s", source.Project(), source.Region(), source.DatasetID(), storeID, patientID)
 	var opts []googleapi.CallOption
 	if val, ok := params.AsMap()[typeFilterKey]; ok {
 		types, ok := val.([]any)
@@ -176,25 +159,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 			opts = append(opts, googleapi.QueryParameter("_since", sinceStr))
 		}
 	}
-	resp, err := svc.Projects.Locations.Datasets.FhirStores.Fhir.PatientEverything(name).Do(opts...)
-	if err != nil {
-		return nil, fmt.Errorf("failed to call patient everything for %q: %w", name, err)
-	}
-	defer resp.Body.Close()
-
-	respBytes, err := io.ReadAll(resp.Body)
-	if err != nil {
-		return nil, fmt.Errorf("could not read response: %w", err)
-	}
-	if resp.StatusCode > 299 {
-		return nil, fmt.Errorf("patient-everything: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
-	}
-	var jsonMap map[string]interface{}
-	if err := json.Unmarshal([]byte(string(respBytes)), &jsonMap); err != nil {
-		return nil, fmt.Errorf("could not unmarshal response as json: %w", err)
-	}
-	return jsonMap, nil
+	return source.FHIRPatientEverything(storeID, patientID, tokenStr, opts)
 }

 func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
@@ -16,20 +16,16 @@ package fhirpatientsearch

 import (
 	"context"
-	"encoding/json"
 	"fmt"
-	"io"
 	"strings"

 	"github.com/goccy/go-yaml"
 	"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
 	"github.com/googleapis/genai-toolbox/internal/sources"
-	healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
 	"github.com/googleapis/genai-toolbox/internal/tools"
 	"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
 	"github.com/googleapis/genai-toolbox/internal/util/parameters"
 	"google.golang.org/api/googleapi"
-	"google.golang.org/api/healthcare/v1"
 )

 const kind string = "cloud-healthcare-fhir-patient-search"
@@ -70,13 +66,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
 }

 type compatibleSource interface {
-	Project() string
-	Region() string
-	DatasetID() string
 	AllowedFHIRStores() map[string]struct{}
-	Service() *healthcare.Service
-	ServiceCreator() healthcareds.HealthcareServiceCreator
 	UseClientAuthorization() bool
+	FHIRPatientSearch(string, string, []googleapi.CallOption) (any, error)
 }

 type Config struct {
@@ -169,17 +161,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 		return nil, err
 	}

-	svc := source.Service()
-	// Initialize new service if using user OAuth token
-	if source.UseClientAuthorization() {
-		tokenStr, err := accessToken.ParseBearerToken()
-		if err != nil {
-			return nil, fmt.Errorf("error parsing access token: %w", err)
-		}
-		svc, err = source.ServiceCreator()(tokenStr)
-		if err != nil {
-			return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
-		}
+	tokenStr, err := accessToken.ParseBearerToken()
+	if err != nil {
+		return nil, fmt.Errorf("error parsing access token: %w", err)
 	}

 	var summary bool
@@ -248,26 +232,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 	if summary {
 		opts = append(opts, googleapi.QueryParameter("_summary", "text"))
 	}
-	name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
-	resp, err := svc.Projects.Locations.Datasets.FhirStores.Fhir.SearchType(name, "Patient", &healthcare.SearchResourcesRequest{ResourceType: "Patient"}).Do(opts...)
-	if err != nil {
-		return nil, fmt.Errorf("failed to search patient resources: %w", err)
-	}
-	defer resp.Body.Close()
-
-	respBytes, err := io.ReadAll(resp.Body)
-	if err != nil {
-		return nil, fmt.Errorf("could not read response: %w", err)
-	}
-	if resp.StatusCode > 299 {
-		return nil, fmt.Errorf("search: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
-	}
-	var jsonMap map[string]interface{}
-	if err := json.Unmarshal([]byte(string(respBytes)), &jsonMap); err != nil {
-		return nil, fmt.Errorf("could not unmarshal response as json: %w", err)
-	}
-	return jsonMap, nil
+	return source.FHIRPatientSearch(storeID, tokenStr, opts)
 }

 func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
@@ -21,7 +21,6 @@ import (
 	"github.com/goccy/go-yaml"
 	"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
 	"github.com/googleapis/genai-toolbox/internal/sources"
-	healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
 	"github.com/googleapis/genai-toolbox/internal/tools"
 	"github.com/googleapis/genai-toolbox/internal/util/parameters"
 	"google.golang.org/api/healthcare/v1"
@@ -44,12 +43,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
 }

 type compatibleSource interface {
-	Project() string
-	Region() string
-	DatasetID() string
-	Service() *healthcare.Service
-	ServiceCreator() healthcareds.HealthcareServiceCreator
 	UseClientAuthorization() bool
+	GetDataset(string) (*healthcare.Dataset, error)
 }

 type Config struct {
@@ -100,27 +95,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 	if err != nil {
 		return nil, err
 	}
-	svc := source.Service()
-
-	// Initialize new service if using user OAuth token
-	if source.UseClientAuthorization() {
-		tokenStr, err := accessToken.ParseBearerToken()
-		if err != nil {
-			return nil, fmt.Errorf("error parsing access token: %w", err)
-		}
-		svc, err = source.ServiceCreator()(tokenStr)
-		if err != nil {
-			return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
-		}
-	}
-
-	datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", source.Project(), source.Region(), source.DatasetID())
-	dataset, err := svc.Projects.Locations.Datasets.Get(datasetName).Do()
+	tokenStr, err := accessToken.ParseBearerToken()
 	if err != nil {
-		return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
+		return nil, fmt.Errorf("error parsing access token: %w", err)
 	}
-	return dataset, nil
+	return source.GetDataset(tokenStr)
 }

 func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
@@ -21,7 +21,6 @@ import (
 	"github.com/goccy/go-yaml"
 	"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
 	"github.com/googleapis/genai-toolbox/internal/sources"
-	healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
 	"github.com/googleapis/genai-toolbox/internal/tools"
 	"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
 	"github.com/googleapis/genai-toolbox/internal/util/parameters"
@@ -45,13 +44,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
 }

 type compatibleSource interface {
-	Project() string
-	Region() string
-	DatasetID() string
 	AllowedDICOMStores() map[string]struct{}
-	Service() *healthcare.Service
-	ServiceCreator() healthcareds.HealthcareServiceCreator
 	UseClientAuthorization() bool
+	GetDICOMStore(string, string) (*healthcare.DicomStore, error)
 }

 type Config struct {
@@ -117,31 +112,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 	if err != nil {
 		return nil, err
 	}

 	storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
 	if err != nil {
 		return nil, err
 	}
-	svc := source.Service()
-	// Initialize new service if using user OAuth token
-	if source.UseClientAuthorization() {
-		tokenStr, err := accessToken.ParseBearerToken()
-		if err != nil {
-			return nil, fmt.Errorf("error parsing access token: %w", err)
-		}
-		svc, err = source.ServiceCreator()(tokenStr)
-		if err != nil {
-			return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
-		}
-	}
-
-	storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
-	store, err := svc.Projects.Locations.Datasets.DicomStores.Get(storeName).Do()
+	tokenStr, err := accessToken.ParseBearerToken()
 	if err != nil {
-		return nil, fmt.Errorf("failed to get DICOM store %q: %w", storeName, err)
+		return nil, fmt.Errorf("error parsing access token: %w", err)
 	}
-	return store, nil
+	return source.GetDICOMStore(storeID, tokenStr)
 }

 func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
@@ -21,7 +21,6 @@ import (
 	"github.com/goccy/go-yaml"
 	"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
 	"github.com/googleapis/genai-toolbox/internal/sources"
-	healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
 	"github.com/googleapis/genai-toolbox/internal/tools"
 	"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
 	"github.com/googleapis/genai-toolbox/internal/util/parameters"
@@ -45,13 +44,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
 }

 type compatibleSource interface {
-	Project() string
-	Region() string
-	DatasetID() string
 	AllowedDICOMStores() map[string]struct{}
-	Service() *healthcare.Service
-	ServiceCreator() healthcareds.HealthcareServiceCreator
 	UseClientAuthorization() bool
+	GetDICOMStoreMetrics(string, string) (*healthcare.DicomStoreMetrics, error)
 }

 type Config struct {
@@ -117,31 +112,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 	if err != nil {
 		return nil, err
 	}

 	storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
 	if err != nil {
 		return nil, err
 	}
-	svc := source.Service()
-	// Initialize new service if using user OAuth token
-	if source.UseClientAuthorization() {
-		tokenStr, err := accessToken.ParseBearerToken()
-		if err != nil {
-			return nil, fmt.Errorf("error parsing access token: %w", err)
-		}
-		svc, err = source.ServiceCreator()(tokenStr)
-		if err != nil {
-			return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
-		}
-	}
-
-	storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
-	store, err := svc.Projects.Locations.Datasets.DicomStores.GetDICOMStoreMetrics(storeName).Do()
+	tokenStr, err := accessToken.ParseBearerToken()
 	if err != nil {
-		return nil, fmt.Errorf("failed to get metrics for DICOM store %q: %w", storeName, err)
+		return nil, fmt.Errorf("error parsing access token: %w", err)
 	}
-	return store, nil
+	return source.GetDICOMStoreMetrics(storeID, tokenStr)
 }

 func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
@@ -16,18 +16,14 @@ package getfhirresource

 import (
 	"context"
-	"encoding/json"
 	"fmt"
-	"io"

 	"github.com/goccy/go-yaml"
 	"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
 	"github.com/googleapis/genai-toolbox/internal/sources"
-	healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
 	"github.com/googleapis/genai-toolbox/internal/tools"
 	"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
 	"github.com/googleapis/genai-toolbox/internal/util/parameters"
-	"google.golang.org/api/healthcare/v1"
 )

 const kind string = "cloud-healthcare-get-fhir-resource"
@@ -51,13 +47,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
 }

 type compatibleSource interface {
-	Project() string
-	Region() string
-	DatasetID() string
 	AllowedFHIRStores() map[string]struct{}
-	Service() *healthcare.Service
-	ServiceCreator() healthcareds.HealthcareServiceCreator
 	UseClientAuthorization() bool
+	GetFHIRResource(string, string, string, string) (any, error)
 }

 type Config struct {
@@ -134,46 +126,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 	if !ok {
 		return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", typeKey)
 	}

 	resID, ok := params.AsMap()[idKey].(string)
 	if !ok {
 		return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", idKey)
 	}
-	svc := source.Service()
-	// Initialize new service if using user OAuth token
-	if source.UseClientAuthorization() {
-		tokenStr, err := accessToken.ParseBearerToken()
-		if err != nil {
-			return nil, fmt.Errorf("error parsing access token: %w", err)
-		}
-		svc, err = source.ServiceCreator()(tokenStr)
-		if err != nil {
-			return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
-		}
-	}
-
-	name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/%s/%s", source.Project(), source.Region(), source.DatasetID(), storeID, resType, resID)
-	call := svc.Projects.Locations.Datasets.FhirStores.Fhir.Read(name)
-	call.Header().Set("Content-Type", "application/fhir+json;charset=utf-8")
-	resp, err := call.Do()
+	tokenStr, err := accessToken.ParseBearerToken()
 	if err != nil {
-		return nil, fmt.Errorf("failed to get fhir resource %q: %w", name, err)
+		return nil, fmt.Errorf("error parsing access token: %w", err)
 	}
-	defer resp.Body.Close()
+	return source.GetFHIRResource(storeID, resType, resID, tokenStr)
-
-	respBytes, err := io.ReadAll(resp.Body)
-	if err != nil {
-		return nil, fmt.Errorf("could not read response: %w", err)
-	}
-	if resp.StatusCode > 299 {
-		return nil, fmt.Errorf("read: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
-	}
-	var jsonMap map[string]interface{}
-	if err := json.Unmarshal([]byte(string(respBytes)), &jsonMap); err != nil {
-		return nil, fmt.Errorf("could not unmarshal response as json: %w", err)
-	}
-	return jsonMap, nil
 }

 func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
@@ -21,7 +21,6 @@ import (
 	"github.com/goccy/go-yaml"
 	"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
 	"github.com/googleapis/genai-toolbox/internal/sources"
-	healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
 	"github.com/googleapis/genai-toolbox/internal/tools"
 	"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
 	"github.com/googleapis/genai-toolbox/internal/util/parameters"
@@ -45,13 +44,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
 }
 
 type compatibleSource interface {
-	Project() string
-	Region() string
-	DatasetID() string
 	AllowedFHIRStores() map[string]struct{}
-	Service() *healthcare.Service
-	ServiceCreator() healthcareds.HealthcareServiceCreator
 	UseClientAuthorization() bool
+	GetFHIRStore(string, string) (*healthcare.FhirStore, error)
 }
 
 type Config struct {
@@ -117,31 +112,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 	if err != nil {
 		return nil, err
 	}
 
 	storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores())
 	if err != nil {
 		return nil, err
 	}
-	svc := source.Service()
-	// Initialize new service if using user OAuth token
-	if source.UseClientAuthorization() {
-		tokenStr, err := accessToken.ParseBearerToken()
-		if err != nil {
-			return nil, fmt.Errorf("error parsing access token: %w", err)
-		}
-		svc, err = source.ServiceCreator()(tokenStr)
-		if err != nil {
-			return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
-		}
-	}
-
-	storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
-	store, err := svc.Projects.Locations.Datasets.FhirStores.Get(storeName).Do()
+	tokenStr, err := accessToken.ParseBearerToken()
 	if err != nil {
-		return nil, fmt.Errorf("failed to get FHIR store %q: %w", storeName, err)
+		return nil, fmt.Errorf("error parsing access token: %w", err)
 	}
-	return store, nil
+	return source.GetFHIRStore(storeID, tokenStr)
 }
 
 func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
@@ -21,7 +21,6 @@ import (
 	"github.com/goccy/go-yaml"
 	"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
 	"github.com/googleapis/genai-toolbox/internal/sources"
-	healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
 	"github.com/googleapis/genai-toolbox/internal/tools"
 	"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
 	"github.com/googleapis/genai-toolbox/internal/util/parameters"
@@ -45,13 +44,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
 }
 
 type compatibleSource interface {
-	Project() string
-	Region() string
-	DatasetID() string
 	AllowedFHIRStores() map[string]struct{}
-	Service() *healthcare.Service
-	ServiceCreator() healthcareds.HealthcareServiceCreator
 	UseClientAuthorization() bool
+	GetFHIRStoreMetrics(string, string) (*healthcare.FhirStoreMetrics, error)
 }
 
 type Config struct {
@@ -117,31 +112,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 	if err != nil {
 		return nil, err
 	}
 
 	storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores())
 	if err != nil {
 		return nil, err
 	}
-	svc := source.Service()
-	// Initialize new service if using user OAuth token
-	if source.UseClientAuthorization() {
-		tokenStr, err := accessToken.ParseBearerToken()
-		if err != nil {
-			return nil, fmt.Errorf("error parsing access token: %w", err)
-		}
-		svc, err = source.ServiceCreator()(tokenStr)
-		if err != nil {
-			return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
-		}
-	}
-
-	storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
-	store, err := svc.Projects.Locations.Datasets.FhirStores.GetFHIRStoreMetrics(storeName).Do()
+	tokenStr, err := accessToken.ParseBearerToken()
 	if err != nil {
-		return nil, fmt.Errorf("failed to get metrics for FHIR store %q: %w", storeName, err)
+		return nil, fmt.Errorf("error parsing access token: %w", err)
 	}
-	return store, nil
+	return source.GetFHIRStoreMetrics(storeID, tokenStr)
 }
 
 func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
@@ -17,12 +17,10 @@ package listdicomstores
 import (
 	"context"
 	"fmt"
-	"strings"
 
 	"github.com/goccy/go-yaml"
 	"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
 	"github.com/googleapis/genai-toolbox/internal/sources"
-	healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
 	"github.com/googleapis/genai-toolbox/internal/tools"
 	"github.com/googleapis/genai-toolbox/internal/util/parameters"
 	"google.golang.org/api/healthcare/v1"
@@ -45,13 +43,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
 }
 
 type compatibleSource interface {
-	Project() string
-	Region() string
-	DatasetID() string
-	AllowedDICOMStores() map[string]struct{}
-	Service() *healthcare.Service
-	ServiceCreator() healthcareds.HealthcareServiceCreator
 	UseClientAuthorization() bool
+	ListDICOMStores(tokenStr string) ([]*healthcare.DicomStore, error)
 }
 
 type Config struct {
@@ -102,41 +95,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 	if err != nil {
 		return nil, err
 	}
-	svc := source.Service()
-
-	// Initialize new service if using user OAuth token
-	if source.UseClientAuthorization() {
-		tokenStr, err := accessToken.ParseBearerToken()
-		if err != nil {
-			return nil, fmt.Errorf("error parsing access token: %w", err)
-		}
-		svc, err = source.ServiceCreator()(tokenStr)
-		if err != nil {
-			return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
-		}
-	}
-
-	datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", source.Project(), source.Region(), source.DatasetID())
-	stores, err := svc.Projects.Locations.Datasets.DicomStores.List(datasetName).Do()
+	tokenStr, err := accessToken.ParseBearerToken()
 	if err != nil {
-		return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
+		return nil, fmt.Errorf("error parsing access token: %w", err)
 	}
-	var filtered []*healthcare.DicomStore
-	for _, store := range stores.DicomStores {
-		if len(source.AllowedDICOMStores()) == 0 {
-			filtered = append(filtered, store)
-			continue
-		}
-		if len(store.Name) == 0 {
-			continue
-		}
-		parts := strings.Split(store.Name, "/")
-		if _, ok := source.AllowedDICOMStores()[parts[len(parts)-1]]; ok {
-			filtered = append(filtered, store)
-		}
-	}
-	return filtered, nil
+	return source.ListDICOMStores(tokenStr)
 }
 
 func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
@@ -17,12 +17,10 @@ package listfhirstores
 import (
 	"context"
 	"fmt"
-	"strings"
 
 	"github.com/goccy/go-yaml"
 	"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
 	"github.com/googleapis/genai-toolbox/internal/sources"
-	healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
 	"github.com/googleapis/genai-toolbox/internal/tools"
 	"github.com/googleapis/genai-toolbox/internal/util/parameters"
 	"google.golang.org/api/healthcare/v1"
@@ -45,13 +43,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
 }
 
 type compatibleSource interface {
-	Project() string
-	Region() string
-	DatasetID() string
-	AllowedFHIRStores() map[string]struct{}
-	Service() *healthcare.Service
-	ServiceCreator() healthcareds.HealthcareServiceCreator
 	UseClientAuthorization() bool
+	ListFHIRStores(string) ([]*healthcare.FhirStore, error)
 }
 
 type Config struct {
@@ -102,41 +95,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 	if err != nil {
 		return nil, err
 	}
-	svc := source.Service()
-
-	// Initialize new service if using user OAuth token
-	if source.UseClientAuthorization() {
-		tokenStr, err := accessToken.ParseBearerToken()
-		if err != nil {
-			return nil, fmt.Errorf("error parsing access token: %w", err)
-		}
-		svc, err = source.ServiceCreator()(tokenStr)
-		if err != nil {
-			return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
-		}
-	}
-
-	datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", source.Project(), source.Region(), source.DatasetID())
-	stores, err := svc.Projects.Locations.Datasets.FhirStores.List(datasetName).Do()
+	tokenStr, err := accessToken.ParseBearerToken()
 	if err != nil {
-		return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
+		return nil, fmt.Errorf("error parsing access token: %w", err)
 	}
-	var filtered []*healthcare.FhirStore
-	for _, store := range stores.FhirStores {
-		if len(source.AllowedFHIRStores()) == 0 {
-			filtered = append(filtered, store)
-			continue
-		}
-		if len(store.Name) == 0 {
-			continue
-		}
-		parts := strings.Split(store.Name, "/")
-		if _, ok := source.AllowedFHIRStores()[parts[len(parts)-1]]; ok {
-			filtered = append(filtered, store)
-		}
-	}
-	return filtered, nil
+	return source.ListFHIRStores(tokenStr)
 }
 
 func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
@@ -16,18 +16,14 @@ package retrieverendereddicominstance
 
 import (
 	"context"
-	"encoding/base64"
 	"fmt"
-	"io"
 
 	"github.com/goccy/go-yaml"
 	"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
 	"github.com/googleapis/genai-toolbox/internal/sources"
-	healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
 	"github.com/googleapis/genai-toolbox/internal/tools"
 	"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
 	"github.com/googleapis/genai-toolbox/internal/util/parameters"
-	"google.golang.org/api/healthcare/v1"
 )
 
 const kind string = "cloud-healthcare-retrieve-rendered-dicom-instance"
@@ -53,13 +49,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
 }
 
 type compatibleSource interface {
-	Project() string
-	Region() string
-	DatasetID() string
 	AllowedDICOMStores() map[string]struct{}
-	Service() *healthcare.Service
-	ServiceCreator() healthcareds.HealthcareServiceCreator
 	UseClientAuthorization() bool
+	RetrieveRenderedDICOMInstance(string, string, string, string, int, string) (any, error)
 }
 
 type Config struct {
@@ -135,20 +127,10 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 	if err != nil {
 		return nil, err
 	}
-	svc := source.Service()
-	// Initialize new service if using user OAuth token
-	if source.UseClientAuthorization() {
-		tokenStr, err := accessToken.ParseBearerToken()
-		if err != nil {
-			return nil, fmt.Errorf("error parsing access token: %w", err)
-		}
-		svc, err = source.ServiceCreator()(tokenStr)
-		if err != nil {
-			return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
-		}
+	tokenStr, err := accessToken.ParseBearerToken()
+	if err != nil {
+		return nil, fmt.Errorf("error parsing access token: %w", err)
 	}
 
 	study, ok := params.AsMap()[studyInstanceUIDKey].(string)
 	if !ok {
 		return nil, fmt.Errorf("invalid '%s' parameter; expected a string", studyInstanceUIDKey)
@@ -165,25 +147,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 	if !ok {
 		return nil, fmt.Errorf("invalid '%s' parameter; expected an integer", frameNumberKey)
 	}
-	name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
-	dicomWebPath := fmt.Sprintf("studies/%s/series/%s/instances/%s/frames/%d/rendered", study, series, sop, frame)
-	call := svc.Projects.Locations.Datasets.DicomStores.Studies.Series.Instances.Frames.RetrieveRendered(name, dicomWebPath)
-	call.Header().Set("Accept", "image/jpeg")
-	resp, err := call.Do()
-	if err != nil {
-		return nil, fmt.Errorf("unable to retrieve dicom instance rendered image: %w", err)
-	}
-	defer resp.Body.Close()
-
-	respBytes, err := io.ReadAll(resp.Body)
-	if err != nil {
-		return nil, fmt.Errorf("could not read response: %w", err)
-	}
-	if resp.StatusCode > 299 {
-		return nil, fmt.Errorf("RetrieveRendered: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
-	}
-	base64String := base64.StdEncoding.EncodeToString(respBytes)
-	return base64String, nil
+	return source.RetrieveRenderedDICOMInstance(storeID, study, series, sop, frame, tokenStr)
 }
 
 func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
@@ -16,20 +16,16 @@ package searchdicominstances
 
 import (
 	"context"
-	"encoding/json"
 	"fmt"
-	"io"
 	"strings"
 
 	"github.com/goccy/go-yaml"
 	"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
 	"github.com/googleapis/genai-toolbox/internal/sources"
-	healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
 	"github.com/googleapis/genai-toolbox/internal/tools"
 	"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
 	"github.com/googleapis/genai-toolbox/internal/util/parameters"
 	"google.golang.org/api/googleapi"
-	"google.golang.org/api/healthcare/v1"
 )
 
 const kind string = "cloud-healthcare-search-dicom-instances"
@@ -60,13 +56,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
 }
 
 type compatibleSource interface {
-	Project() string
-	Region() string
-	DatasetID() string
 	AllowedDICOMStores() map[string]struct{}
-	Service() *healthcare.Service
-	ServiceCreator() healthcareds.HealthcareServiceCreator
 	UseClientAuthorization() bool
+	SearchDICOM(string, string, string, string, []googleapi.CallOption) (any, error)
 }
 
 type Config struct {
@@ -144,23 +136,13 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 	if err != nil {
 		return nil, err
 	}
 
 	storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
 	if err != nil {
 		return nil, err
 	}
-	svc := source.Service()
-	// Initialize new service if using user OAuth token
-	if source.UseClientAuthorization() {
-		tokenStr, err := accessToken.ParseBearerToken()
-		if err != nil {
-			return nil, fmt.Errorf("error parsing access token: %w", err)
-		}
-		svc, err = source.ServiceCreator()(tokenStr)
-		if err != nil {
-			return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
-		}
+	tokenStr, err := accessToken.ParseBearerToken()
+	if err != nil {
+		return nil, fmt.Errorf("error parsing access token: %w", err)
 	}
 
 	opts, err := common.ParseDICOMSearchParameters(params, []string{sopInstanceUIDKey, patientNameKey, patientIDKey, accessionNumberKey, referringPhysicianNameKey, studyDateKey, modalityKey})
@@ -191,29 +173,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 			}
 		}
 	}
-
-	name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
-	resp, err := svc.Projects.Locations.Datasets.DicomStores.SearchForInstances(name, dicomWebPath).Do(opts...)
-	if err != nil {
-		return nil, fmt.Errorf("failed to search dicom instances: %w", err)
-	}
-	defer resp.Body.Close()
-
-	respBytes, err := io.ReadAll(resp.Body)
-	if err != nil {
-		return nil, fmt.Errorf("could not read response: %w", err)
-	}
-	if resp.StatusCode > 299 {
-		return nil, fmt.Errorf("search: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
-	}
-	if len(respBytes) == 0 {
-		return []interface{}{}, nil
-	}
-	var result []interface{}
-	if err := json.Unmarshal([]byte(string(respBytes)), &result); err != nil {
-		return nil, fmt.Errorf("could not unmarshal response as list: %w", err)
-	}
-	return result, nil
+	return source.SearchDICOM(t.Kind, storeID, dicomWebPath, tokenStr, opts)
 }
 
 func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
@@ -16,18 +16,15 @@ package searchdicomseries
 
 import (
 	"context"
-	"encoding/json"
 	"fmt"
-	"io"
 
 	"github.com/goccy/go-yaml"
 	"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
 	"github.com/googleapis/genai-toolbox/internal/sources"
-	healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
 	"github.com/googleapis/genai-toolbox/internal/tools"
 	"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
 	"github.com/googleapis/genai-toolbox/internal/util/parameters"
-	"google.golang.org/api/healthcare/v1"
+	"google.golang.org/api/googleapi"
 )
 
 const kind string = "cloud-healthcare-search-dicom-series"
@@ -57,13 +54,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
 }
 
 type compatibleSource interface {
-	Project() string
-	Region() string
-	DatasetID() string
 	AllowedDICOMStores() map[string]struct{}
-	Service() *healthcare.Service
-	ServiceCreator() healthcareds.HealthcareServiceCreator
 	UseClientAuthorization() bool
+	SearchDICOM(string, string, string, string, []googleapi.CallOption) (any, error)
 }
 
 type Config struct {
@@ -145,18 +138,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 	if err != nil {
 		return nil, err
 	}
-	svc := source.Service()
-	// Initialize new service if using user OAuth token
-	if source.UseClientAuthorization() {
-		tokenStr, err := accessToken.ParseBearerToken()
-		if err != nil {
-			return nil, fmt.Errorf("error parsing access token: %w", err)
-		}
-		svc, err = source.ServiceCreator()(tokenStr)
-		if err != nil {
-			return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
-		}
+	tokenStr, err := accessToken.ParseBearerToken()
+	if err != nil {
+		return nil, fmt.Errorf("error parsing access token: %w", err)
 	}
 
 	opts, err := common.ParseDICOMSearchParameters(params, []string{seriesInstanceUIDKey, patientNameKey, patientIDKey, accessionNumberKey, referringPhysicianNameKey, studyDateKey, modalityKey})
@@ -174,29 +158,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 			dicomWebPath = fmt.Sprintf("studies/%s/series", id)
 		}
 	}
-
-	name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
-	resp, err := svc.Projects.Locations.Datasets.DicomStores.SearchForSeries(name, dicomWebPath).Do(opts...)
-	if err != nil {
-		return nil, fmt.Errorf("failed to search dicom series: %w", err)
-	}
-	defer resp.Body.Close()
-
-	respBytes, err := io.ReadAll(resp.Body)
-	if err != nil {
-		return nil, fmt.Errorf("could not read response: %w", err)
-	}
-	if resp.StatusCode > 299 {
-		return nil, fmt.Errorf("search: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
-	}
-	if len(respBytes) == 0 {
-		return []interface{}{}, nil
-	}
-	var result []interface{}
-	if err := json.Unmarshal([]byte(string(respBytes)), &result); err != nil {
-		return nil, fmt.Errorf("could not unmarshal response as list: %w", err)
-	}
-	return result, nil
+	return source.SearchDICOM(t.Kind, storeID, dicomWebPath, tokenStr, opts)
 }
 
 func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
@@ -16,18 +16,15 @@ package searchdicomstudies
 
 import (
 	"context"
-	"encoding/json"
 	"fmt"
-	"io"
 
 	"github.com/goccy/go-yaml"
 	"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
 	"github.com/googleapis/genai-toolbox/internal/sources"
-	healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
 	"github.com/googleapis/genai-toolbox/internal/tools"
 	"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
 	"github.com/googleapis/genai-toolbox/internal/util/parameters"
-	"google.golang.org/api/healthcare/v1"
+	"google.golang.org/api/googleapi"
 )
 
 const kind string = "cloud-healthcare-search-dicom-studies"
@@ -55,13 +52,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
 }
 
 type compatibleSource interface {
-	Project() string
-	Region() string
-	DatasetID() string
 	AllowedDICOMStores() map[string]struct{}
-	Service() *healthcare.Service
-	ServiceCreator() healthcareds.HealthcareServiceCreator
 	UseClientAuthorization() bool
+	SearchDICOM(string, string, string, string, []googleapi.CallOption) (any, error)
 }
 
 type Config struct {
@@ -136,51 +129,20 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 	if err != nil {
 		return nil, err
 	}
 
 	storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
 	if err != nil {
 		return nil, err
 	}
 
-	svc := source.Service()
-	// Initialize new service if using user OAuth token
-	if source.UseClientAuthorization() {
-		tokenStr, err := accessToken.ParseBearerToken()
-		if err != nil {
-			return nil, fmt.Errorf("error parsing access token: %w", err)
-		}
-		svc, err = source.ServiceCreator()(tokenStr)
-		if err != nil {
-			return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
-		}
+	tokenStr, err := accessToken.ParseBearerToken()
+	if err != nil {
+		return nil, fmt.Errorf("error parsing access token: %w", err)
 	}
 
 	opts, err := common.ParseDICOMSearchParameters(params, []string{studyInstanceUIDKey, patientNameKey, patientIDKey, accessionNumberKey, referringPhysicianNameKey, studyDateKey})
 	if err != nil {
 		return nil, err
 	}
-	name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
-	resp, err := svc.Projects.Locations.Datasets.DicomStores.SearchForStudies(name, "studies").Do(opts...)
-	if err != nil {
-		return nil, fmt.Errorf("failed to search dicom studies: %w", err)
-	}
-	defer resp.Body.Close()
-
-	respBytes, err := io.ReadAll(resp.Body)
-	if err != nil {
-		return nil, fmt.Errorf("could not read response: %w", err)
-	}
-	if resp.StatusCode > 299 {
-		return nil, fmt.Errorf("search: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
-	}
-	if len(respBytes) == 0 {
-		return []interface{}{}, nil
-	}
-	var result []interface{}
-	if err := json.Unmarshal([]byte(string(respBytes)), &result); err != nil {
-		return nil, fmt.Errorf("could not unmarshal response as list: %w", err)
-	}
-	return result, nil
+	dicomWebPath := "studies"
+	return source.SearchDICOM(t.Kind, storeID, dicomWebPath, tokenStr, opts)
 }
 
 func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
@@ -0,0 +1,180 @@
+// Copyright 2026 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cloudsqlcreatebackup
+
+import (
+	"context"
+	"fmt"
+
+	"github.com/goccy/go-yaml"
+	"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
+	"github.com/googleapis/genai-toolbox/internal/sources"
+	"github.com/googleapis/genai-toolbox/internal/tools"
+	"github.com/googleapis/genai-toolbox/internal/util/parameters"
+	"google.golang.org/api/sqladmin/v1"
+)
+
+const kind string = "cloud-sql-create-backup"
+
+var _ tools.ToolConfig = Config{}
+
+type compatibleSource interface {
+	GetDefaultProject() string
+	GetService(context.Context, string) (*sqladmin.Service, error)
+	UseClientAuthorization() bool
+	InsertBackupRun(ctx context.Context, project, instance, location, backupDescription, accessToken string) (any, error)
+}
+
+// Config defines the configuration for the create-backup tool.
+type Config struct {
+	Name         string   `yaml:"name" validate:"required"`
+	Kind         string   `yaml:"kind" validate:"required"`
+	Description  string   `yaml:"description"`
+	Source       string   `yaml:"source" validate:"required"`
+	AuthRequired []string `yaml:"authRequired"`
+}
+
+func init() {
+	if !tools.Register(kind, newConfig) {
+		panic(fmt.Sprintf("tool kind %q already registered", kind))
+	}
+}
+
+func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
+	actual := Config{Name: name}
+	if err := decoder.DecodeContext(ctx, &actual); err != nil {
+		return nil, err
+	}
+	return actual, nil
+}
+
+// ToolConfigKind returns the kind of the tool.
+func (cfg Config) ToolConfigKind() string {
+	return kind
+}
+
+// Initialize initializes the tool from the configuration.
+func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
+	rawS, ok := srcs[cfg.Source]
+	if !ok {
+		return nil, fmt.Errorf("no source named %q configured", cfg.Source)
+	}
+	s, ok := rawS.(compatibleSource)
+	if !ok {
+		return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
+	}
+
+	project := s.GetDefaultProject()
+	var projectParam parameters.Parameter
+	if project != "" {
+		projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
+	} else {
+		projectParam = parameters.NewStringParameter("project", "The project ID")
+	}
+
+	allParameters := parameters.Parameters{
+		projectParam,
+		parameters.NewStringParameter("instance", "Cloud SQL instance ID. This does not include the project ID."),
+		// Location and backup_description are optional.
+		parameters.NewStringParameterWithRequired("location", "Location of the backup run.", false),
+		parameters.NewStringParameterWithRequired("backup_description", "The description of this backup run.", false),
+	}
+	paramManifest := allParameters.Manifest()
+
+	description := cfg.Description
+	if description == "" {
+		description = "Creates a backup on a Cloud SQL instance."
+	}
+
+	mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, allParameters, nil)
+
+	return Tool{
+		Config:      cfg,
+		AllParams:   allParameters,
+		manifest:    tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
+		mcpManifest: mcpManifest,
+	}, nil
+}
+
+// Tool represents the create-backup tool.
+type Tool struct {
+	Config
+	AllParams   parameters.Parameters `yaml:"allParams"`
+	manifest    tools.Manifest
+	mcpManifest tools.McpManifest
+}
+
+func (t Tool) ToConfig() tools.ToolConfig {
+	return t.Config
+}
+
+func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
+	source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
+	if err != nil {
+		return nil, err
+	}
+	paramsMap := params.AsMap()
+
+	project, ok := paramsMap["project"].(string)
+	if !ok {
+		return nil, fmt.Errorf("error casting 'project' parameter: %v", paramsMap["project"])
+	}
+	instance, ok := paramsMap["instance"].(string)
+	if !ok {
+		return nil, fmt.Errorf("error casting 'instance' parameter: %v", paramsMap["instance"])
+	}
+
+	location, _ := paramsMap["location"].(string)
+	description, _ := paramsMap["backup_description"].(string)
+
+	return source.InsertBackupRun(ctx, project, instance, location, description, string(accessToken))
+}
+
+// ParseParams parses the parameters for the tool.
+func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
+	return parameters.ParseParams(t.AllParams, data, claims)
+}
+
+func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
+	return parameters.EmbedParams(ctx, t.AllParams, paramValues, embeddingModelsMap, nil)
+}
+
+// Manifest returns the tool's manifest.
+func (t Tool) Manifest() tools.Manifest {
+	return t.manifest
+}
+
+// McpManifest returns the tool's MCP manifest.
+func (t Tool) McpManifest() tools.McpManifest {
+	return t.mcpManifest
+}
+
+// Authorized checks if the tool is authorized.
+func (t Tool) Authorized(verifiedAuthServices []string) bool {
+	return true
+}
+
+func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
+	source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
+	if err != nil {
+		return false, err
+	}
+
+	return source.UseClientAuthorization(), nil
+}
+
+func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
+	return "Authorization", nil
+}
@@ -0,0 +1,72 @@
+// Copyright 2026 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cloudsqlcreatebackup_test
+
+import (
+	"testing"
+
+	yaml "github.com/goccy/go-yaml"
+	"github.com/google/go-cmp/cmp"
+	"github.com/googleapis/genai-toolbox/internal/server"
+	"github.com/googleapis/genai-toolbox/internal/testutils"
+	"github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcreatebackup"
+)
+
+func TestParseFromYaml(t *testing.T) {
+	ctx, err := testutils.ContextWithNewLogger()
+	if err != nil {
+		t.Fatalf("unexpected error: %s", err)
+	}
+	tcs := []struct {
+		desc string
+		in   string
+		want server.ToolConfigs
+	}{
+		{
+			desc: "basic example",
+			in: `
+			tools:
+				create-backup-tool:
+					kind: cloud-sql-create-backup
+					description: a test description
+					source: a-source
+			`,
+			want: server.ToolConfigs{
+				"create-backup-tool": cloudsqlcreatebackup.Config{
+					Name:         "create-backup-tool",
+					Kind:         "cloud-sql-create-backup",
+					Description:  "a test description",
+					Source:       "a-source",
+					AuthRequired: []string{},
+				},
+			},
+		},
+	}
+	for _, tc := range tcs {
+		t.Run(tc.desc, func(t *testing.T) {
+			got := struct {
+				Tools server.ToolConfigs `yaml:"tools"`
+			}{}
+			// Parse contents
+			err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
+			if err != nil {
+				t.Fatalf("unable to unmarshal: %s", err)
+			}
+			if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
+				t.Fatalf("incorrect parse: diff %v", diff)
+			}
+		})
+	}
+}
@@ -0,0 +1,183 @@
+// Copyright 2026 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cloudsqlrestorebackup
+
+import (
+	"context"
+	"fmt"
+
+	"github.com/goccy/go-yaml"
+	"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
+	"github.com/googleapis/genai-toolbox/internal/sources"
+	"github.com/googleapis/genai-toolbox/internal/tools"
+	"github.com/googleapis/genai-toolbox/internal/util/parameters"
+	"google.golang.org/api/sqladmin/v1"
+)
+
+const kind string = "cloud-sql-restore-backup"
+
+var _ tools.ToolConfig = Config{}
+
+type compatibleSource interface {
+	GetDefaultProject() string
+	GetService(context.Context, string) (*sqladmin.Service, error)
+	UseClientAuthorization() bool
+	RestoreBackup(ctx context.Context, targetProject, targetInstance, sourceProject, sourceInstance, backupID, accessToken string) (any, error)
+}
+
+// Config defines the configuration for the restore-backup tool.
+type Config struct {
+	Name         string   `yaml:"name" validate:"required"`
+	Kind         string   `yaml:"kind" validate:"required"`
+	Description  string   `yaml:"description"`
+	Source       string   `yaml:"source" validate:"required"`
+	AuthRequired []string `yaml:"authRequired"`
+}
+
+func init() {
+	if !tools.Register(kind, newConfig) {
+		panic(fmt.Sprintf("tool kind %q already registered", kind))
+	}
+}
+
+func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
+	actual := Config{Name: name}
+	if err := decoder.DecodeContext(ctx, &actual); err != nil {
+		return nil, err
+	}
+	return actual, nil
+}
+
+// ToolConfigKind returns the kind of the tool.
+func (cfg Config) ToolConfigKind() string {
+	return kind
+}
+
+// Initialize initializes the tool from the configuration.
+func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
+	rawS, ok := srcs[cfg.Source]
+	if !ok {
+		return nil, fmt.Errorf("no source named %q configured", cfg.Source)
+	}
+	s, ok := rawS.(compatibleSource)
+	if !ok {
+		return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
+	}
+
+	project := s.GetDefaultProject()
+	var targetProjectParam parameters.Parameter
+	if project != "" {
+		targetProjectParam = parameters.NewStringParameterWithDefault("target_project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
+	} else {
+		targetProjectParam = parameters.NewStringParameter("target_project", "The project ID")
+	}
+
+	allParameters := parameters.Parameters{
+		targetProjectParam,
+		parameters.NewStringParameter("target_instance", "Cloud SQL instance ID of the target instance. This does not include the project ID."),
+		parameters.NewStringParameter("backup_id", "Identifier of the backup being restored. Can be a BackupRun ID, backup name, or BackupDR backup name. Use the full backup ID as provided, do not try to parse it"),
+		parameters.NewStringParameterWithRequired("source_project", "GCP project ID of the instance that the backup belongs to. Only required if the backup_id is a BackupRun ID.", false),
+		parameters.NewStringParameterWithRequired("source_instance", "Cloud SQL instance ID of the instance that the backup belongs to. Only required if the backup_id is a BackupRun ID.", false),
+	}
+	paramManifest := allParameters.Manifest()
+
+	description := cfg.Description
+	if description == "" {
+		description = "Restores a backup on a Cloud SQL instance."
+	}
+
+	mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, allParameters, nil)
+
+	return Tool{
+		Config:      cfg,
+		AllParams:   allParameters,
+		manifest:    tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
+		mcpManifest: mcpManifest,
+	}, nil
+}
+
+// Tool represents the restore-backup tool.
+type Tool struct {
+	Config
+	AllParams   parameters.Parameters `yaml:"allParams"`
+	manifest    tools.Manifest
+	mcpManifest tools.McpManifest
+}
+
+func (t Tool) ToConfig() tools.ToolConfig {
+	return t.Config
+}
+
+func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
+	source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
+	if err != nil {
+		return nil, err
+	}
+	paramsMap := params.AsMap()
+
+	targetProject, ok := paramsMap["target_project"].(string)
+	if !ok {
+		return nil, fmt.Errorf("error casting 'target_project' parameter: %v", paramsMap["target_project"])
+	}
+	targetInstance, ok := paramsMap["target_instance"].(string)
+	if !ok {
+		return nil, fmt.Errorf("error casting 'target_instance' parameter: %v", paramsMap["target_instance"])
+	}
+	backupID, ok := paramsMap["backup_id"].(string)
+	if !ok {
+		return nil, fmt.Errorf("error casting 'backup_id' parameter: %v", paramsMap["backup_id"])
+	}
+	sourceProject, _ := paramsMap["source_project"].(string)
+	sourceInstance, _ := paramsMap["source_instance"].(string)
+
+	return source.RestoreBackup(ctx, targetProject, targetInstance, sourceProject, sourceInstance, backupID, string(accessToken))
+}
+
+// ParseParams parses the parameters for the tool.
+func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
+	return parameters.ParseParams(t.AllParams, data, claims)
+}
+
+func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
+	return parameters.EmbedParams(ctx, t.AllParams, paramValues, embeddingModelsMap, nil)
+}
+
+// Manifest returns the tool's manifest.
+func (t Tool) Manifest() tools.Manifest {
+	return t.manifest
+}
+
+// McpManifest returns the tool's MCP manifest.
+func (t Tool) McpManifest() tools.McpManifest {
+	return t.mcpManifest
+}
+
+// Authorized checks if the tool is authorized.
+func (t Tool) Authorized(verifiedAuthServices []string) bool {
+	return true
+}
+
+func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
+	source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
+	if err != nil {
+		return false, err
+	}
+
+	return source.UseClientAuthorization(), nil
+}
+
+func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
+	return "Authorization", nil
+}
@@ -0,0 +1,71 @@
+// Copyright 2026 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cloudsqlrestorebackup_test
+
+import (
+	"testing"
+
+	yaml "github.com/goccy/go-yaml"
+	"github.com/google/go-cmp/cmp"
+	"github.com/googleapis/genai-toolbox/internal/server"
+	"github.com/googleapis/genai-toolbox/internal/testutils"
+	"github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlrestorebackup"
+)
+
+func TestParseFromYaml(t *testing.T) {
+	ctx, err := testutils.ContextWithNewLogger()
+	if err != nil {
+		t.Fatalf("unexpected error: %s", err)
+	}
+	tcs := []struct {
+		desc string
+		in   string
+		want server.ToolConfigs
+	}{
+		{
+			desc: "basic example",
+			in: `
+			tools:
+				restore-backup-tool:
+					kind: cloud-sql-restore-backup
+					description: a test description
+					source: a-source
+			`,
+			want: server.ToolConfigs{
+				"restore-backup-tool": cloudsqlrestorebackup.Config{
+					Name:         "restore-backup-tool",
+					Kind:         "cloud-sql-restore-backup",
+					Description:  "a test description",
+					Source:       "a-source",
+					AuthRequired: []string{},
+				},
+			},
+		},
+	}
+	for _, tc := range tcs {
+		t.Run(tc.desc, func(t *testing.T) {
+			got := struct {
+				Tools server.ToolConfigs `yaml:"tools"`
+			}{}
+			err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
+			if err != nil {
+				t.Fatalf("unable to unmarshal: %s", err)
+			}
+			if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
+				t.Fatalf("incorrect parse: diff %v", diff)
+			}
+		})
+	}
+}
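Taken together, the two new Cloud SQL tools could be declared in a tools file along these lines (a sketch based on the YAML shapes exercised in the tests above; the tool and source names are hypothetical):

```yaml
tools:
  create-backup-tool:
    kind: cloud-sql-create-backup
    description: Creates a backup on a Cloud SQL instance.
    source: my-cloud-sql-source   # hypothetical source name
  restore-backup-tool:
    kind: cloud-sql-restore-backup
    description: Restores a backup on a Cloud SQL instance.
    source: my-cloud-sql-source
```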
@@ -18,7 +18,6 @@ import (
 	"context"
 	"fmt"
 
-	dataplexapi "cloud.google.com/go/dataplex/apiv1"
 	dataplexpb "cloud.google.com/go/dataplex/apiv1/dataplexpb"
 	"github.com/goccy/go-yaml"
 	"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
@@ -44,7 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
 }
 
 type compatibleSource interface {
-	CatalogClient() *dataplexapi.CatalogClient
+	LookupEntry(context.Context, string, int, []string, string) (*dataplexpb.Entry, error)
 }
 
 type Config struct {
@@ -118,12 +117,6 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 	}
 
 	paramsMap := params.AsMap()
-	viewMap := map[int]dataplexpb.EntryView{
-		1: dataplexpb.EntryView_BASIC,
-		2: dataplexpb.EntryView_FULL,
-		3: dataplexpb.EntryView_CUSTOM,
-		4: dataplexpb.EntryView_ALL,
-	}
 	name, _ := paramsMap["name"].(string)
 	entry, _ := paramsMap["entry"].(string)
 	view, _ := paramsMap["view"].(int)
@@ -132,19 +125,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 		return nil, fmt.Errorf("can't convert aspectTypes to array of strings: %s", err)
 	}
 	aspectTypes := aspectTypeSlice.([]string)
-	req := &dataplexpb.LookupEntryRequest{
-		Name:        name,
-		View:        viewMap[view],
-		AspectTypes: aspectTypes,
-		Entry:       entry,
-	}
-
-	result, err := source.CatalogClient().LookupEntry(ctx, req)
-	if err != nil {
-		return nil, err
-	}
-	return result, nil
+	return source.LookupEntry(ctx, name, view, aspectTypes, entry)
 }
 
 func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
@@ -18,9 +18,7 @@ import (
 	"context"
 	"fmt"
 
-	dataplexapi "cloud.google.com/go/dataplex/apiv1"
-	dataplexpb "cloud.google.com/go/dataplex/apiv1/dataplexpb"
-	"github.com/cenkalti/backoff/v5"
+	"cloud.google.com/go/dataplex/apiv1/dataplexpb"
 	"github.com/goccy/go-yaml"
 	"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
 	"github.com/googleapis/genai-toolbox/internal/sources"
@@ -45,8 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
 }
 
 type compatibleSource interface {
-	CatalogClient() *dataplexapi.CatalogClient
-	ProjectID() string
+	SearchAspectTypes(context.Context, string, int, string) ([]*dataplexpb.AspectType, error)
 }
 
 type Config struct {
@@ -101,61 +98,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 	if err != nil {
 		return nil, err
 	}
 
-	// Invoke the tool with the provided parameters
 	paramsMap := params.AsMap()
 	query, _ := paramsMap["query"].(string)
-	pageSize := int32(paramsMap["pageSize"].(int))
+	pageSize, _ := paramsMap["pageSize"].(int)
 	orderBy, _ := paramsMap["orderBy"].(string)
-
-	// Create SearchEntriesRequest with the provided parameters
-	req := &dataplexpb.SearchEntriesRequest{
-		Query:          query + " type=projects/dataplex-types/locations/global/entryTypes/aspecttype",
-		Name:           fmt.Sprintf("projects/%s/locations/global", source.ProjectID()),
-		PageSize:       pageSize,
-		OrderBy:        orderBy,
-		SemanticSearch: true,
-	}
-
-	// Perform the search using the CatalogClient - this will return an iterator
-	it := source.CatalogClient().SearchEntries(ctx, req)
-	if it == nil {
-		return nil, fmt.Errorf("failed to create search entries iterator for project %q", source.ProjectID())
-	}
-
-	// Create an instance of exponential backoff with default values for retrying GetAspectType calls
-	// InitialInterval, RandomizationFactor, Multiplier, MaxInterval = 500 ms, 0.5, 1.5, 60 s
-	getAspectBackOff := backoff.NewExponentialBackOff()
-
-	// Iterate through the search results and call GetAspectType for each result using the resource name
-	var results []*dataplexpb.AspectType
-	for {
-		entry, err := it.Next()
-		if err != nil {
-			break
-		}
-		resourceName := entry.DataplexEntry.GetEntrySource().Resource
-		getAspectTypeReq := &dataplexpb.GetAspectTypeRequest{
-			Name: resourceName,
-		}
-
-		operation := func() (*dataplexpb.AspectType, error) {
-			aspectType, err := source.CatalogClient().GetAspectType(ctx, getAspectTypeReq)
-			if err != nil {
-				return nil, fmt.Errorf("failed to get aspect type for entry %q: %w", resourceName, err)
-			}
-			return aspectType, nil
-		}
-
-		// Retry the GetAspectType operation with exponential backoff
-		aspectType, err := backoff.Retry(ctx, operation, backoff.WithBackOff(getAspectBackOff))
-		if err != nil {
-			return nil, fmt.Errorf("failed to get aspect type after retries for entry %q: %w", resourceName, err)
-		}
-
-		results = append(results, aspectType)
-	}
-	return results, nil
+	return source.SearchAspectTypes(ctx, query, pageSize, orderBy)
 }
 
 func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
@@ -18,8 +18,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
dataplexapi "cloud.google.com/go/dataplex/apiv1"
|
"cloud.google.com/go/dataplex/apiv1/dataplexpb"
|
||||||
dataplexpb "cloud.google.com/go/dataplex/apiv1/dataplexpb"
|
|
||||||
"github.com/goccy/go-yaml"
|
"github.com/goccy/go-yaml"
|
||||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
@@ -44,8 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
|||||||
}
|
}
|
||||||
|
|
||||||
type compatibleSource interface {
|
type compatibleSource interface {
|
||||||
CatalogClient() *dataplexapi.CatalogClient
|
SearchEntries(context.Context, string, int, string) ([]*dataplexpb.SearchEntriesResult, error)
|
||||||
ProjectID() string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
@@ -100,34 +98,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
paramsMap := params.AsMap()
|
paramsMap := params.AsMap()
|
||||||
query, _ := paramsMap["query"].(string)
|
query, _ := paramsMap["query"].(string)
|
||||||
pageSize := int32(paramsMap["pageSize"].(int))
|
pageSize, _ := paramsMap["pageSize"].(int)
|
||||||
orderBy, _ := paramsMap["orderBy"].(string)
|
orderBy, _ := paramsMap["orderBy"].(string)
|
||||||
|
return source.SearchEntries(ctx, query, pageSize, orderBy)
|
||||||
req := &dataplexpb.SearchEntriesRequest{
|
|
||||||
Query: query,
|
|
||||||
Name: fmt.Sprintf("projects/%s/locations/global", source.ProjectID()),
|
|
||||||
PageSize: pageSize,
|
|
||||||
OrderBy: orderBy,
|
|
||||||
SemanticSearch: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
it := source.CatalogClient().SearchEntries(ctx, req)
|
|
||||||
if it == nil {
|
|
||||||
return nil, fmt.Errorf("failed to create search entries iterator for project %q", source.ProjectID())
|
|
||||||
}
|
|
||||||
|
|
||||||
var results []*dataplexpb.SearchEntriesResult
|
|
||||||
for {
|
|
||||||
entry, err := it.Next()
|
|
||||||
if err != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
results = append(results, entry)
|
|
||||||
}
|
|
||||||
return results, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||||
|
|||||||
@@ -48,6 +48,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
|||||||
|
|
||||||
type compatibleSource interface {
|
type compatibleSource interface {
|
||||||
FirestoreClient() *firestoreapi.Client
|
FirestoreClient() *firestoreapi.Client
|
||||||
|
AddDocuments(context.Context, string, any, bool) (map[string]any, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
@@ -134,24 +135,20 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
}
|
}
|
||||||
|
|
||||||
mapParams := params.AsMap()
|
mapParams := params.AsMap()
|
||||||
|
|
||||||
// Get collection path
|
// Get collection path
|
||||||
collectionPath, ok := mapParams[collectionPathKey].(string)
|
collectionPath, ok := mapParams[collectionPathKey].(string)
|
||||||
if !ok || collectionPath == "" {
|
if !ok || collectionPath == "" {
|
||||||
return nil, fmt.Errorf("invalid or missing '%s' parameter", collectionPathKey)
|
return nil, fmt.Errorf("invalid or missing '%s' parameter", collectionPathKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate collection path
|
// Validate collection path
|
||||||
if err := util.ValidateCollectionPath(collectionPath); err != nil {
|
if err := util.ValidateCollectionPath(collectionPath); err != nil {
|
||||||
return nil, fmt.Errorf("invalid collection path: %w", err)
|
return nil, fmt.Errorf("invalid collection path: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get document data
|
// Get document data
|
||||||
documentDataRaw, ok := mapParams[documentDataKey]
|
documentDataRaw, ok := mapParams[documentDataKey]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("invalid or missing '%s' parameter", documentDataKey)
|
return nil, fmt.Errorf("invalid or missing '%s' parameter", documentDataKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert the document data from JSON format to Firestore format
|
// Convert the document data from JSON format to Firestore format
|
||||||
// The client is passed to handle referenceValue types
|
// The client is passed to handle referenceValue types
|
||||||
documentData, err := util.JSONToFirestoreValue(documentDataRaw, source.FirestoreClient())
|
documentData, err := util.JSONToFirestoreValue(documentDataRaw, source.FirestoreClient())
|
||||||
@@ -164,30 +161,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
if val, ok := mapParams[returnDocumentDataKey].(bool); ok {
|
if val, ok := mapParams[returnDocumentDataKey].(bool); ok {
|
||||||
returnData = val
|
returnData = val
|
||||||
}
|
}
|
||||||
|
return source.AddDocuments(ctx, collectionPath, documentData, returnData)
|
||||||
// Get the collection reference
|
|
||||||
collection := source.FirestoreClient().Collection(collectionPath)
|
|
||||||
|
|
||||||
// Add the document to the collection
|
|
||||||
docRef, writeResult, err := collection.Add(ctx, documentData)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to add document: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build the response
|
|
||||||
response := map[string]any{
|
|
||||||
"documentPath": docRef.Path,
|
|
||||||
"createTime": writeResult.UpdateTime.Format("2006-01-02T15:04:05.999999999Z"),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add document data if requested
|
|
||||||
if returnData {
|
|
||||||
// Convert the document data back to simple JSON format
|
|
||||||
simplifiedData := util.FirestoreValueToJSON(documentData)
|
|
||||||
response["documentData"] = simplifiedData
|
|
||||||
}
|
|
||||||
|
|
||||||
return response, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
|||||||
|
|
||||||
type compatibleSource interface {
|
type compatibleSource interface {
|
||||||
FirestoreClient() *firestoreapi.Client
|
FirestoreClient() *firestoreapi.Client
|
||||||
|
DeleteDocuments(context.Context, []string) ([]any, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
@@ -104,7 +105,6 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected an array", documentPathsKey)
|
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected an array", documentPathsKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(documentPathsRaw) == 0 {
|
if len(documentPathsRaw) == 0 {
|
||||||
return nil, fmt.Errorf("'%s' parameter cannot be empty", documentPathsKey)
|
return nil, fmt.Errorf("'%s' parameter cannot be empty", documentPathsKey)
|
||||||
}
|
}
|
||||||
@@ -126,45 +126,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
return nil, fmt.Errorf("invalid document path at index %d: %w", i, err)
|
return nil, fmt.Errorf("invalid document path at index %d: %w", i, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return source.DeleteDocuments(ctx, documentPaths)
|
||||||
// Create a BulkWriter to handle multiple deletions efficiently
|
|
||||||
bulkWriter := source.FirestoreClient().BulkWriter(ctx)
|
|
||||||
|
|
||||||
// Keep track of jobs for each document
|
|
||||||
jobs := make([]*firestoreapi.BulkWriterJob, len(documentPaths))
|
|
||||||
|
|
||||||
// Add all delete operations to the BulkWriter
|
|
||||||
for i, path := range documentPaths {
|
|
||||||
docRef := source.FirestoreClient().Doc(path)
|
|
||||||
job, err := bulkWriter.Delete(docRef)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to add delete operation for document %q: %w", path, err)
|
|
||||||
}
|
|
||||||
jobs[i] = job
|
|
||||||
}
|
|
||||||
|
|
||||||
// End the BulkWriter to execute all operations
|
|
||||||
bulkWriter.End()
|
|
||||||
|
|
||||||
// Collect results
|
|
||||||
results := make([]any, len(documentPaths))
|
|
||||||
for i, job := range jobs {
|
|
||||||
docData := make(map[string]any)
|
|
||||||
docData["path"] = documentPaths[i]
|
|
||||||
|
|
||||||
// Wait for the job to complete and get the result
|
|
||||||
_, err := job.Results()
|
|
||||||
if err != nil {
|
|
||||||
docData["success"] = false
|
|
||||||
docData["error"] = err.Error()
|
|
||||||
} else {
|
|
||||||
docData["success"] = true
|
|
||||||
}
|
|
||||||
|
|
||||||
results[i] = docData
|
|
||||||
}
|
|
||||||
|
|
||||||
return results, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
|||||||
|
|
||||||
type compatibleSource interface {
|
type compatibleSource interface {
|
||||||
FirestoreClient() *firestoreapi.Client
|
FirestoreClient() *firestoreapi.Client
|
||||||
|
GetDocuments(context.Context, []string) ([]any, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
@@ -126,37 +127,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
return nil, fmt.Errorf("invalid document path at index %d: %w", i, err)
|
return nil, fmt.Errorf("invalid document path at index %d: %w", i, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return source.GetDocuments(ctx, documentPaths)
|
||||||
// Create document references from paths
|
|
||||||
docRefs := make([]*firestoreapi.DocumentRef, len(documentPaths))
|
|
||||||
for i, path := range documentPaths {
|
|
||||||
docRefs[i] = source.FirestoreClient().Doc(path)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get all documents
|
|
||||||
snapshots, err := source.FirestoreClient().GetAll(ctx, docRefs)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to get documents: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert snapshots to response data
|
|
||||||
results := make([]any, len(snapshots))
|
|
||||||
for i, snapshot := range snapshots {
|
|
||||||
docData := make(map[string]any)
|
|
||||||
docData["path"] = documentPaths[i]
|
|
||||||
docData["exists"] = snapshot.Exists()
|
|
||||||
|
|
||||||
if snapshot.Exists() {
|
|
||||||
docData["data"] = snapshot.Data()
|
|
||||||
docData["createTime"] = snapshot.CreateTime
|
|
||||||
docData["updateTime"] = snapshot.UpdateTime
|
|
||||||
docData["readTime"] = snapshot.ReadTime
|
|
||||||
}
|
|
||||||
|
|
||||||
results[i] = docData
|
|
||||||
}
|
|
||||||
|
|
||||||
return results, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||||
|
|||||||
@@ -44,8 +44,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
|||||||
|
|
||||||
type compatibleSource interface {
|
type compatibleSource interface {
|
||||||
FirebaseRulesClient() *firebaserules.Service
|
FirebaseRulesClient() *firebaserules.Service
|
||||||
GetProjectId() string
|
GetRules(context.Context) (any, error)
|
||||||
GetDatabaseId() string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
@@ -98,29 +97,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
return source.GetRules(ctx)
|
||||||
// Get the latest release for Firestore
|
|
||||||
releaseName := fmt.Sprintf("projects/%s/releases/cloud.firestore/%s", source.GetProjectId(), source.GetDatabaseId())
|
|
||||||
release, err := source.FirebaseRulesClient().Projects.Releases.Get(releaseName).Context(ctx).Do()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to get latest Firestore release: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if release.RulesetName == "" {
|
|
||||||
return nil, fmt.Errorf("no active Firestore rules were found in project '%s' and database '%s'", source.GetProjectId(), source.GetDatabaseId())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the ruleset content
|
|
||||||
ruleset, err := source.FirebaseRulesClient().Projects.Rulesets.Get(release.RulesetName).Context(ctx).Do()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to get ruleset content: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if ruleset.Source == nil || len(ruleset.Source.Files) == 0 {
|
|
||||||
return nil, fmt.Errorf("no rules files found in ruleset")
|
|
||||||
}
|
|
||||||
|
|
||||||
return ruleset, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
|||||||
|
|
||||||
type compatibleSource interface {
|
type compatibleSource interface {
|
||||||
FirestoreClient() *firestoreapi.Client
|
FirestoreClient() *firestoreapi.Client
|
||||||
|
ListCollections(context.Context, string) ([]any, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
@@ -102,47 +103,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
|
|
||||||
mapParams := params.AsMap()
|
mapParams := params.AsMap()
|
||||||
|
|
||||||
var collectionRefs []*firestoreapi.CollectionRef
|
|
||||||
|
|
||||||
// Check if parentPath is provided
|
// Check if parentPath is provided
|
||||||
parentPath, hasParent := mapParams[parentPathKey].(string)
|
parentPath, _ := mapParams[parentPathKey].(string)
|
||||||
|
if parentPath != "" {
|
||||||
if hasParent && parentPath != "" {
|
|
||||||
// Validate parent document path
|
// Validate parent document path
|
||||||
if err := util.ValidateDocumentPath(parentPath); err != nil {
|
if err := util.ValidateDocumentPath(parentPath); err != nil {
|
||||||
return nil, fmt.Errorf("invalid parent document path: %w", err)
|
return nil, fmt.Errorf("invalid parent document path: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// List subcollections of the specified document
|
|
||||||
docRef := source.FirestoreClient().Doc(parentPath)
|
|
||||||
collectionRefs, err = docRef.Collections(ctx).GetAll()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to list subcollections of document %q: %w", parentPath, err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// List root collections
|
|
||||||
collectionRefs, err = source.FirestoreClient().Collections(ctx).GetAll()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to list root collections: %w", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
return source.ListCollections(ctx, parentPath)
|
||||||
// Convert collection references to response data
|
|
||||||
results := make([]any, len(collectionRefs))
|
|
||||||
for i, collRef := range collectionRefs {
|
|
||||||
collData := make(map[string]any)
|
|
||||||
collData["id"] = collRef.ID
|
|
||||||
collData["path"] = collRef.Path
|
|
||||||
|
|
||||||
// If this is a subcollection, include parent information
|
|
||||||
if collRef.Parent != nil {
|
|
||||||
collData["parent"] = collRef.Parent.Path
|
|
||||||
}
|
|
||||||
|
|
||||||
results[i] = collData
|
|
||||||
}
|
|
||||||
|
|
||||||
return results, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||||
|
|||||||
@@ -36,27 +36,6 @@ const (
|
|||||||
defaultLimit = 100
|
defaultLimit = 100
|
||||||
)
|
)
|
||||||
|
|
||||||
// Firestore operators
|
|
||||||
var validOperators = map[string]bool{
|
|
||||||
"<": true,
|
|
||||||
"<=": true,
|
|
||||||
">": true,
|
|
||||||
">=": true,
|
|
||||||
"==": true,
|
|
||||||
"!=": true,
|
|
||||||
"array-contains": true,
|
|
||||||
"array-contains-any": true,
|
|
||||||
"in": true,
|
|
||||||
"not-in": true,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Error messages
|
|
||||||
const (
|
|
||||||
errFilterParseFailed = "failed to parse filters: %w"
|
|
||||||
errQueryExecutionFailed = "failed to execute query: %w"
|
|
||||||
errLimitParseFailed = "failed to parse limit value '%s': %w"
|
|
||||||
)
|
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
if !tools.Register(kind, newConfig) {
|
if !tools.Register(kind, newConfig) {
|
||||||
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
||||||
@@ -74,6 +53,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
|||||||
// compatibleSource defines the interface for sources that can provide a Firestore client
|
// compatibleSource defines the interface for sources that can provide a Firestore client
|
||||||
type compatibleSource interface {
|
type compatibleSource interface {
|
||||||
FirestoreClient() *firestoreapi.Client
|
FirestoreClient() *firestoreapi.Client
|
||||||
|
BuildQuery(string, firestoreapi.EntityFilter, []string, string, firestoreapi.Direction, int, bool) (*firestoreapi.Query, error)
|
||||||
|
ExecuteQuery(context.Context, *firestoreapi.Query, bool) (any, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Config represents the configuration for the Firestore query tool
|
// Config represents the configuration for the Firestore query tool
|
||||||
@@ -139,15 +120,6 @@ func (t Tool) ToConfig() tools.ToolConfig {
|
|||||||
return t.Config
|
return t.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
// SimplifiedFilter represents the simplified filter format
|
|
||||||
type SimplifiedFilter struct {
|
|
||||||
And []SimplifiedFilter `json:"and,omitempty"`
|
|
||||||
Or []SimplifiedFilter `json:"or,omitempty"`
|
|
||||||
Field string `json:"field,omitempty"`
|
|
||||||
Op string `json:"op,omitempty"`
|
|
||||||
Value interface{} `json:"value,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// OrderByConfig represents ordering configuration
|
// OrderByConfig represents ordering configuration
|
||||||
type OrderByConfig struct {
|
type OrderByConfig struct {
|
||||||
Field string `json:"field"`
|
Field string `json:"field"`
|
||||||
@@ -162,20 +134,27 @@ func (o *OrderByConfig) GetDirection() firestoreapi.Direction {
|
|||||||
return firestoreapi.Asc
|
return firestoreapi.Asc
|
||||||
}
|
}
|
||||||
|
|
||||||
// QueryResult represents a document result from the query
|
// SimplifiedFilter represents the simplified filter format
|
||||||
type QueryResult struct {
|
type SimplifiedFilter struct {
|
||||||
ID string `json:"id"`
|
And []SimplifiedFilter `json:"and,omitempty"`
|
||||||
Path string `json:"path"`
|
Or []SimplifiedFilter `json:"or,omitempty"`
|
||||||
Data map[string]any `json:"data"`
|
Field string `json:"field,omitempty"`
|
||||||
CreateTime interface{} `json:"createTime,omitempty"`
|
Op string `json:"op,omitempty"`
|
||||||
UpdateTime interface{} `json:"updateTime,omitempty"`
|
Value interface{} `json:"value,omitempty"`
|
||||||
ReadTime interface{} `json:"readTime,omitempty"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// QueryResponse represents the full response including optional metrics
|
// Firestore operators
|
||||||
type QueryResponse struct {
|
var validOperators = map[string]bool{
|
||||||
Documents []QueryResult `json:"documents"`
|
"<": true,
|
||||||
ExplainMetrics map[string]any `json:"explainMetrics,omitempty"`
|
"<=": true,
|
||||||
|
">": true,
|
||||||
|
">=": true,
|
||||||
|
"==": true,
|
||||||
|
"!=": true,
|
||||||
|
"array-contains": true,
|
||||||
|
"array-contains-any": true,
|
||||||
|
"in": true,
|
||||||
|
"not-in": true,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Invoke executes the Firestore query based on the provided parameters
|
// Invoke executes the Firestore query based on the provided parameters
|
||||||
@@ -184,34 +163,18 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
paramsMap := params.AsMap()
|
paramsMap := params.AsMap()
|
||||||
|
|
||||||
// Process collection path with template substitution
|
// Process collection path with template substitution
|
||||||
collectionPath, err := parameters.PopulateTemplate("collectionPath", t.CollectionPath, paramsMap)
|
collectionPath, err := parameters.PopulateTemplate("collectionPath", t.CollectionPath, paramsMap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to process collection path: %w", err)
|
return nil, fmt.Errorf("failed to process collection path: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build the query
|
var filter firestoreapi.EntityFilter
|
||||||
query, err := t.buildQuery(source, collectionPath, paramsMap)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Execute the query and return results
|
|
||||||
return t.executeQuery(ctx, query)
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildQuery constructs the Firestore query from parameters
|
|
||||||
func (t Tool) buildQuery(source compatibleSource, collectionPath string, params map[string]any) (*firestoreapi.Query, error) {
|
|
||||||
collection := source.FirestoreClient().Collection(collectionPath)
|
|
||||||
query := collection.Query
|
|
||||||
|
|
||||||
// Process and apply filters if template is provided
|
// Process and apply filters if template is provided
|
||||||
if t.Filters != "" {
|
if t.Filters != "" {
|
||||||
// Apply template substitution to filters
|
// Apply template substitution to filters
|
||||||
filtersJSON, err := parameters.PopulateTemplateWithJSON("filters", t.Filters, params)
|
filtersJSON, err := parameters.PopulateTemplateWithJSON("filters", t.Filters, paramsMap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to process filters template: %w", err)
|
return nil, fmt.Errorf("failed to process filters template: %w", err)
|
||||||
}
|
}
|
||||||
@@ -219,48 +182,43 @@ func (t Tool) buildQuery(source compatibleSource, collectionPath string, params
|
|||||||
// Parse the simplified filter format
|
// Parse the simplified filter format
|
||||||
var simplifiedFilter SimplifiedFilter
|
var simplifiedFilter SimplifiedFilter
|
||||||
if err := json.Unmarshal([]byte(filtersJSON), &simplifiedFilter); err != nil {
|
if err := json.Unmarshal([]byte(filtersJSON), &simplifiedFilter); err != nil {
|
||||||
return nil, fmt.Errorf(errFilterParseFailed, err)
|
return nil, fmt.Errorf("failed to parse filters: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert simplified filter to Firestore filter
|
// Convert simplified filter to Firestore filter
|
||||||
if filter := t.convertToFirestoreFilter(source, simplifiedFilter); filter != nil {
|
filter = t.convertToFirestoreFilter(source, simplifiedFilter)
|
||||||
query = query.WhereEntity(filter)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process select fields
|
|
||||||
selectFields, err := t.processSelectFields(params)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if len(selectFields) > 0 {
|
|
||||||
query = query.Select(selectFields...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process and apply ordering
|
// Process and apply ordering
|
||||||
orderBy, err := t.getOrderBy(params)
|
orderBy, err := t.getOrderBy(paramsMap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if orderBy != nil {
|
// Process select fields
|
||||||
query = query.OrderBy(orderBy.Field, orderBy.GetDirection())
|
selectFields, err := t.processSelectFields(paramsMap)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process and apply limit
|
// Process and apply limit
|
||||||
limit, err := t.getLimit(params)
|
limit, err := t.getLimit(paramsMap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
query = query.Limit(limit)
|
|
||||||
|
|
||||||
// Apply analyze options if enabled
|
// prevent panic when accessing orderBy incase it is nil
|
||||||
if t.AnalyzeQuery {
|
var orderByField string
|
||||||
query = query.WithRunOptions(firestoreapi.ExplainOptions{
|
var orderByDirection firestoreapi.Direction
|
||||||
Analyze: true,
|
if orderBy != nil {
|
||||||
})
|
orderByField = orderBy.Field
|
||||||
|
orderByDirection = orderBy.GetDirection()
|
||||||
}
|
}
|
||||||
|
|
||||||
return &query, nil
|
// Build the query
|
||||||
|
query, err := source.BuildQuery(collectionPath, filter, selectFields, orderByField, orderByDirection, limit, t.AnalyzeQuery)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// Execute the query and return results
|
||||||
|
return source.ExecuteQuery(ctx, query, t.AnalyzeQuery)
|
||||||
}
|
}
|
||||||
|
|
||||||
// convertToFirestoreFilter converts simplified filter format to Firestore EntityFilter
|
// convertToFirestoreFilter converts simplified filter format to Firestore EntityFilter
|
||||||
@@ -409,7 +367,7 @@ func (t Tool) getLimit(params map[string]any) (int, error) {
|
|||||||
if processedValue != "" {
|
if processedValue != "" {
|
||||||
parsedLimit, err := strconv.Atoi(processedValue)
|
parsedLimit, err := strconv.Atoi(processedValue)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf(errLimitParseFailed, processedValue, err)
|
return 0, fmt.Errorf("failed to parse limit value '%s': %w", processedValue, err)
|
||||||
}
|
}
|
||||||
limit = parsedLimit
|
limit = parsedLimit
|
||||||
}
|
}
|
||||||
@@ -417,78 +375,6 @@ func (t Tool) getLimit(params map[string]any) (int, error) {
|
|||||||
return limit, nil
|
return limit, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// executeQuery runs the query and formats the results
|
|
||||||
func (t Tool) executeQuery(ctx context.Context, query *firestoreapi.Query) (any, error) {
|
|
||||||
docIterator := query.Documents(ctx)
|
|
||||||
docs, err := docIterator.GetAll()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf(errQueryExecutionFailed, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert results to structured format
|
|
||||||
results := make([]QueryResult, len(docs))
|
|
||||||
for i, doc := range docs {
|
|
||||||
results[i] = QueryResult{
|
|
||||||
ID: doc.Ref.ID,
|
|
||||||
Path: doc.Ref.Path,
|
|
||||||
Data: doc.Data(),
|
|
||||||
CreateTime: doc.CreateTime,
|
|
||||||
UpdateTime: doc.UpdateTime,
|
|
||||||
ReadTime: doc.ReadTime,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return with explain metrics if requested
|
|
||||||
if t.AnalyzeQuery {
|
|
||||||
explainMetrics, err := t.getExplainMetrics(docIterator)
|
|
||||||
if err == nil && explainMetrics != nil {
|
|
||||||
response := QueryResponse{
|
|
||||||
Documents: results,
|
|
||||||
ExplainMetrics: explainMetrics,
|
|
||||||
}
|
|
||||||
return response, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return results, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// getExplainMetrics extracts explain metrics from the query iterator
|
|
||||||
func (t Tool) getExplainMetrics(docIterator *firestoreapi.DocumentIterator) (map[string]any, error) {
|
|
||||||
explainMetrics, err := docIterator.ExplainMetrics()
|
|
||||||
if err != nil || explainMetrics == nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
metricsData := make(map[string]any)
|
|
||||||
|
|
||||||
// Add plan summary if available
|
|
||||||
if explainMetrics.PlanSummary != nil {
|
|
||||||
planSummary := make(map[string]any)
|
|
||||||
planSummary["indexesUsed"] = explainMetrics.PlanSummary.IndexesUsed
|
|
||||||
metricsData["planSummary"] = planSummary
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add execution stats if available
|
|
||||||
if explainMetrics.ExecutionStats != nil {
|
|
||||||
executionStats := make(map[string]any)
|
|
||||||
executionStats["resultsReturned"] = explainMetrics.ExecutionStats.ResultsReturned
|
|
||||||
executionStats["readOperations"] = explainMetrics.ExecutionStats.ReadOperations
|
|
||||||
|
|
||||||
if explainMetrics.ExecutionStats.ExecutionDuration != nil {
|
|
||||||
executionStats["executionDuration"] = explainMetrics.ExecutionStats.ExecutionDuration.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
if explainMetrics.ExecutionStats.DebugStats != nil {
|
|
||||||
executionStats["debugStats"] = *explainMetrics.ExecutionStats.DebugStats
|
|
||||||
}
|
|
||||||
|
|
||||||
metricsData["executionStats"] = executionStats
|
|
||||||
}
|
|
||||||
|
|
||||||
return metricsData, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ParseParams parses and validates input parameters
|
// ParseParams parses and validates input parameters
|
||||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||||
return parameters.ParseParams(t.Parameters, data, claims)
|
return parameters.ParseParams(t.Parameters, data, claims)
|
||||||
|
|||||||
@@ -69,7 +69,6 @@ const (
 	errInvalidOperator = "unsupported operator: %s. Valid operators are: %v"
 	errMissingFilterValue = "no value specified for filter on field '%s'"
 	errOrderByParseFailed = "failed to parse orderBy: %w"
-	errQueryExecutionFailed = "failed to execute query: %w"
 	errTooManyFilters = "too many filters provided: %d (maximum: %d)"
 )
 
@@ -90,6 +89,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
 // compatibleSource defines the interface for sources that can provide a Firestore client
 type compatibleSource interface {
 	FirestoreClient() *firestoreapi.Client
+	BuildQuery(string, firestoreapi.EntityFilter, []string, string, firestoreapi.Direction, int, bool) (*firestoreapi.Query, error)
+	ExecuteQuery(context.Context, *firestoreapi.Query, bool) (any, error)
 }
 
 // Config represents the configuration for the Firestore query collection tool
@@ -228,22 +229,6 @@ func (o *OrderByConfig) GetDirection() firestoreapi.Direction {
 	return firestoreapi.Asc
 }
 
-// QueryResult represents a document result from the query
-type QueryResult struct {
-	ID string `json:"id"`
-	Path string `json:"path"`
-	Data map[string]any `json:"data"`
-	CreateTime interface{} `json:"createTime,omitempty"`
-	UpdateTime interface{} `json:"updateTime,omitempty"`
-	ReadTime interface{} `json:"readTime,omitempty"`
-}
-
-// QueryResponse represents the full response including optional metrics
-type QueryResponse struct {
-	Documents []QueryResult `json:"documents"`
-	ExplainMetrics map[string]any `json:"explainMetrics,omitempty"`
-}
-
 // Invoke executes the Firestore query based on the provided parameters
 func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
 	source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
@@ -257,14 +242,37 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 		return nil, err
 	}
 
+	var filter firestoreapi.EntityFilter
+	// Apply filters
+	if len(queryParams.Filters) > 0 {
+		filterConditions := make([]firestoreapi.EntityFilter, 0, len(queryParams.Filters))
+		for _, filter := range queryParams.Filters {
+			filterConditions = append(filterConditions, firestoreapi.PropertyFilter{
+				Path: filter.Field,
+				Operator: filter.Op,
+				Value: filter.Value,
+			})
+		}
+
+		filter = firestoreapi.AndFilter{
+			Filters: filterConditions,
+		}
+	}
+
+	// prevent panic incase queryParams.OrderBy is nil
+	var orderByField string
+	var orderByDirection firestoreapi.Direction
+	if queryParams.OrderBy != nil {
+		orderByField = queryParams.OrderBy.Field
+		orderByDirection = queryParams.OrderBy.GetDirection()
+	}
+
 	// Build the query
-	query, err := t.buildQuery(source, queryParams)
+	query, err := source.BuildQuery(queryParams.CollectionPath, filter, nil, orderByField, orderByDirection, queryParams.Limit, queryParams.AnalyzeQuery)
 	if err != nil {
 		return nil, err
 	}
-	// Execute the query and return results
-	return t.executeQuery(ctx, query, queryParams.AnalyzeQuery)
+	return source.ExecuteQuery(ctx, query, queryParams.AnalyzeQuery)
 }
 
 // queryParameters holds all parsed query parameters
@@ -380,122 +388,6 @@ func (t Tool) parseOrderBy(orderByRaw interface{}) (*OrderByConfig, error) {
 	return &orderBy, nil
 }
 
-// buildQuery constructs the Firestore query from parameters
-func (t Tool) buildQuery(source compatibleSource, params *queryParameters) (*firestoreapi.Query, error) {
-	collection := source.FirestoreClient().Collection(params.CollectionPath)
-	query := collection.Query
-
-	// Apply filters
-	if len(params.Filters) > 0 {
-		filterConditions := make([]firestoreapi.EntityFilter, 0, len(params.Filters))
-		for _, filter := range params.Filters {
-			filterConditions = append(filterConditions, firestoreapi.PropertyFilter{
-				Path: filter.Field,
-				Operator: filter.Op,
-				Value: filter.Value,
-			})
-		}
-
-		query = query.WhereEntity(firestoreapi.AndFilter{
-			Filters: filterConditions,
-		})
-	}
-
-	// Apply ordering
-	if params.OrderBy != nil {
-		query = query.OrderBy(params.OrderBy.Field, params.OrderBy.GetDirection())
-	}
-
-	// Apply limit
-	query = query.Limit(params.Limit)
-
-	// Apply analyze options
-	if params.AnalyzeQuery {
-		query = query.WithRunOptions(firestoreapi.ExplainOptions{
-			Analyze: true,
-		})
-	}
-
-	return &query, nil
-}
-
-// executeQuery runs the query and formats the results
-func (t Tool) executeQuery(ctx context.Context, query *firestoreapi.Query, analyzeQuery bool) (any, error) {
-	docIterator := query.Documents(ctx)
-	docs, err := docIterator.GetAll()
-	if err != nil {
-		return nil, fmt.Errorf(errQueryExecutionFailed, err)
-	}
-
-	// Convert results to structured format
-	results := make([]QueryResult, len(docs))
-	for i, doc := range docs {
-		results[i] = QueryResult{
-			ID: doc.Ref.ID,
-			Path: doc.Ref.Path,
-			Data: doc.Data(),
-			CreateTime: doc.CreateTime,
-			UpdateTime: doc.UpdateTime,
-			ReadTime: doc.ReadTime,
-		}
-	}
-
-	// Return with explain metrics if requested
-	if analyzeQuery {
-		explainMetrics, err := t.getExplainMetrics(docIterator)
-		if err == nil && explainMetrics != nil {
-			response := QueryResponse{
-				Documents: results,
-				ExplainMetrics: explainMetrics,
-			}
-			return response, nil
-		}
-	}
-
-	// Return just the documents
-	resultsAny := make([]any, len(results))
-	for i, r := range results {
-		resultsAny[i] = r
-	}
-	return resultsAny, nil
-}
-
-// getExplainMetrics extracts explain metrics from the query iterator
-func (t Tool) getExplainMetrics(docIterator *firestoreapi.DocumentIterator) (map[string]any, error) {
-	explainMetrics, err := docIterator.ExplainMetrics()
-	if err != nil || explainMetrics == nil {
-		return nil, err
-	}
-
-	metricsData := make(map[string]any)
-
-	// Add plan summary if available
-	if explainMetrics.PlanSummary != nil {
-		planSummary := make(map[string]any)
-		planSummary["indexesUsed"] = explainMetrics.PlanSummary.IndexesUsed
-		metricsData["planSummary"] = planSummary
-	}
-
-	// Add execution stats if available
-	if explainMetrics.ExecutionStats != nil {
-		executionStats := make(map[string]any)
-		executionStats["resultsReturned"] = explainMetrics.ExecutionStats.ResultsReturned
-		executionStats["readOperations"] = explainMetrics.ExecutionStats.ReadOperations
-
-		if explainMetrics.ExecutionStats.ExecutionDuration != nil {
-			executionStats["executionDuration"] = explainMetrics.ExecutionStats.ExecutionDuration.String()
-		}
-
-		if explainMetrics.ExecutionStats.DebugStats != nil {
-			executionStats["debugStats"] = *explainMetrics.ExecutionStats.DebugStats
-		}
-
-		metricsData["executionStats"] = executionStats
-	}
-
-	return metricsData, nil
-}
-
 // ParseParams parses and validates input parameters
 func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
 	return parameters.ParseParams(t.Parameters, data, claims)
@@ -50,6 +50,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
 
 type compatibleSource interface {
 	FirestoreClient() *firestoreapi.Client
+	UpdateDocument(context.Context, string, []firestoreapi.Update, any, bool) (map[string]any, error)
 }
 
 type Config struct {
@@ -177,23 +178,10 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 			}
 		}
 	}
-	// Get return document data flag
-	returnData := false
-	if val, ok := mapParams[returnDocumentDataKey].(bool); ok {
-		returnData = val
-	}
-
-	// Get the document reference
-	docRef := source.FirestoreClient().Doc(documentPath)
-
-	// Prepare update data
-	var writeResult *firestoreapi.WriteResult
-	var writeErr error
-
+	// Use selective field update with update mask
+	updates := make([]firestoreapi.Update, 0, len(updatePaths))
+	var documentData any
 	if len(updatePaths) > 0 {
-		// Use selective field update with update mask
-		updates := make([]firestoreapi.Update, 0, len(updatePaths))
-
 		// Convert document data without delete markers
 		dataMap, err := util.JSONToFirestoreValue(documentDataRaw, source.FirestoreClient())
@@ -220,41 +208,20 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 				Value: value,
 			})
 		}
 
-		writeResult, writeErr = docRef.Update(ctx, updates)
 	} else {
 		// Update all fields in the document data (merge)
-		documentData, err := util.JSONToFirestoreValue(documentDataRaw, source.FirestoreClient())
+		documentData, err = util.JSONToFirestoreValue(documentDataRaw, source.FirestoreClient())
 		if err != nil {
 			return nil, fmt.Errorf("failed to convert document data: %w", err)
 		}
-		writeResult, writeErr = docRef.Set(ctx, documentData, firestoreapi.MergeAll)
 	}
 
-	if writeErr != nil {
-		return nil, fmt.Errorf("failed to update document: %w", writeErr)
+	// Get return document data flag
+	returnData := false
+	if val, ok := mapParams[returnDocumentDataKey].(bool); ok {
+		returnData = val
 	}
-	// Build the response
-	response := map[string]any{
-		"documentPath": docRef.Path,
-		"updateTime": writeResult.UpdateTime.Format("2006-01-02T15:04:05.999999999Z"),
-	}
-
-	// Add document data if requested
-	if returnData {
-		// Fetch the updated document to return the current state
-		snapshot, err := docRef.Get(ctx)
-		if err != nil {
-			return nil, fmt.Errorf("failed to retrieve updated document: %w", err)
-		}
-
-		// Convert the document data to simple JSON format
-		simplifiedData := util.FirestoreValueToJSON(snapshot.Data())
-		response["documentData"] = simplifiedData
-	}
-
-	return response, nil
+	return source.UpdateDocument(ctx, documentPath, updates, documentData, returnData)
 }
 
 // getFieldValue retrieves a value from a nested map using a dot-separated path
@@ -17,7 +17,6 @@ package firestorevalidaterules
 import (
 	"context"
 	"fmt"
-	"strings"
 
 	yaml "github.com/goccy/go-yaml"
 	"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
@@ -50,7 +49,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
 
 type compatibleSource interface {
 	FirebaseRulesClient() *firebaserules.Service
-	GetProjectId() string
+	ValidateRules(context.Context, string) (any, error)
 }
 
 type Config struct {
@@ -107,30 +106,6 @@ func (t Tool) ToConfig() tools.ToolConfig {
 	return t.Config
 }
 
-// Issue represents a validation issue in the rules
-type Issue struct {
-	SourcePosition SourcePosition `json:"sourcePosition"`
-	Description string `json:"description"`
-	Severity string `json:"severity"`
-}
-
-// SourcePosition represents the location of an issue in the source
-type SourcePosition struct {
-	FileName string `json:"fileName,omitempty"`
-	Line int64 `json:"line"` // 1-based
-	Column int64 `json:"column"` // 1-based
-	CurrentOffset int64 `json:"currentOffset"` // 0-based, inclusive start
-	EndOffset int64 `json:"endOffset"` // 0-based, exclusive end
-}
-
-// ValidationResult represents the result of rules validation
-type ValidationResult struct {
-	Valid bool `json:"valid"`
-	IssueCount int `json:"issueCount"`
-	FormattedIssues string `json:"formattedIssues,omitempty"`
-	RawIssues []Issue `json:"rawIssues,omitempty"`
-}
-
 func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
 	source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
 	if err != nil {
@@ -144,114 +119,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 	if !ok || sourceParam == "" {
 		return nil, fmt.Errorf("invalid or missing '%s' parameter", sourceKey)
 	}
-
-	// Create test request
-	testRequest := &firebaserules.TestRulesetRequest{
-		Source: &firebaserules.Source{
-			Files: []*firebaserules.File{
-				{
-					Name: "firestore.rules",
-					Content: sourceParam,
-				},
-			},
-		},
-		// We don't need test cases for validation only
-		TestSuite: &firebaserules.TestSuite{
-			TestCases: []*firebaserules.TestCase{},
-		},
-	}
-
-	// Call the test API
-	projectName := fmt.Sprintf("projects/%s", source.GetProjectId())
-	response, err := source.FirebaseRulesClient().Projects.Test(projectName, testRequest).Context(ctx).Do()
-	if err != nil {
-		return nil, fmt.Errorf("failed to validate rules: %w", err)
-	}
-
-	// Process the response
-	result := t.processValidationResponse(response, sourceParam)
-
-	return result, nil
-}
-
-func (t Tool) processValidationResponse(response *firebaserules.TestRulesetResponse, source string) ValidationResult {
-	if len(response.Issues) == 0 {
-		return ValidationResult{
-			Valid: true,
-			IssueCount: 0,
-			FormattedIssues: "✓ No errors detected. Rules are valid.",
-		}
-	}
-
-	// Convert issues to our format
-	issues := make([]Issue, len(response.Issues))
-	for i, issue := range response.Issues {
-		issues[i] = Issue{
-			Description: issue.Description,
-			Severity: issue.Severity,
-			SourcePosition: SourcePosition{
-				FileName: issue.SourcePosition.FileName,
-				Line: issue.SourcePosition.Line,
-				Column: issue.SourcePosition.Column,
-				CurrentOffset: issue.SourcePosition.CurrentOffset,
-				EndOffset: issue.SourcePosition.EndOffset,
-			},
-		}
-	}
-
-	// Format issues
-	formattedIssues := t.formatRulesetIssues(issues, source)
-
-	return ValidationResult{
-		Valid: false,
-		IssueCount: len(issues),
-		FormattedIssues: formattedIssues,
-		RawIssues: issues,
-	}
-}
-
-// formatRulesetIssues formats validation issues into a human-readable string with code snippets
-func (t Tool) formatRulesetIssues(issues []Issue, rulesSource string) string {
-	sourceLines := strings.Split(rulesSource, "\n")
-	var formattedOutput []string
-
-	formattedOutput = append(formattedOutput, fmt.Sprintf("Found %d issue(s) in rules source:\n", len(issues)))
-
-	for _, issue := range issues {
-		issueString := fmt.Sprintf("%s: %s [Ln %d, Col %d]",
-			issue.Severity,
-			issue.Description,
-			issue.SourcePosition.Line,
-			issue.SourcePosition.Column)
-
-		if issue.SourcePosition.Line > 0 {
-			lineIndex := int(issue.SourcePosition.Line - 1) // 0-based index
-			if lineIndex >= 0 && lineIndex < len(sourceLines) {
-				errorLine := sourceLines[lineIndex]
-				issueString += fmt.Sprintf("\n```\n%s", errorLine)
-
-				// Add carets if we have column and offset information
-				if issue.SourcePosition.Column > 0 &&
-					issue.SourcePosition.CurrentOffset >= 0 &&
-					issue.SourcePosition.EndOffset > issue.SourcePosition.CurrentOffset {
-
-					startColumn := int(issue.SourcePosition.Column - 1) // 0-based
-					errorTokenLength := int(issue.SourcePosition.EndOffset - issue.SourcePosition.CurrentOffset)
-
-					if startColumn >= 0 && errorTokenLength > 0 && startColumn <= len(errorLine) {
-						padding := strings.Repeat(" ", startColumn)
-						carets := strings.Repeat("^", errorTokenLength)
-						issueString += fmt.Sprintf("\n%s%s", padding, carets)
-					}
-				}
-				issueString += "\n```"
-			}
-		}
-
-		formattedOutput = append(formattedOutput, issueString)
-	}
-
-	return strings.Join(formattedOutput, "\n\n")
+	return source.ValidateRules(ctx, sourceParam)
 }
 
 func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
@@ -28,13 +28,13 @@ import (
 // JSONToFirestoreValue converts a JSON value with type information to a Firestore-compatible value
 // The input should be a map with a single key indicating the type (e.g., "stringValue", "integerValue")
 // If a client is provided, referenceValue types will be converted to *firestore.DocumentRef
-func JSONToFirestoreValue(value interface{}, client *firestore.Client) (interface{}, error) {
+func JSONToFirestoreValue(value any, client *firestore.Client) (any, error) {
 	if value == nil {
 		return nil, nil
 	}
 
 	switch v := value.(type) {
-	case map[string]interface{}:
+	case map[string]any:
 		// Check for typed values
 		if len(v) == 1 {
 			for key, val := range v {
@@ -92,7 +92,7 @@ func JSONToFirestoreValue(value interface{}, client *firestore.Client) (interfac
 			return nil, fmt.Errorf("timestamp value must be a string")
 		case "geoPointValue":
 			// Convert to LatLng
-			if geoMap, ok := val.(map[string]interface{}); ok {
+			if geoMap, ok := val.(map[string]any); ok {
 				lat, latOk := geoMap["latitude"].(float64)
 				lng, lngOk := geoMap["longitude"].(float64)
 				if latOk && lngOk {
@@ -105,9 +105,9 @@ func JSONToFirestoreValue(value interface{}, client *firestore.Client) (interfac
 			return nil, fmt.Errorf("invalid geopoint value format")
 		case "arrayValue":
 			// Convert array
-			if arrayMap, ok := val.(map[string]interface{}); ok {
-				if values, ok := arrayMap["values"].([]interface{}); ok {
-					result := make([]interface{}, len(values))
+			if arrayMap, ok := val.(map[string]any); ok {
+				if values, ok := arrayMap["values"].([]any); ok {
+					result := make([]any, len(values))
 					for i, item := range values {
 						converted, err := JSONToFirestoreValue(item, client)
 						if err != nil {
@@ -121,9 +121,9 @@ func JSONToFirestoreValue(value interface{}, client *firestore.Client) (interfac
 			return nil, fmt.Errorf("invalid array value format")
 		case "mapValue":
 			// Convert map
-			if mapMap, ok := val.(map[string]interface{}); ok {
-				if fields, ok := mapMap["fields"].(map[string]interface{}); ok {
-					result := make(map[string]interface{})
+			if mapMap, ok := val.(map[string]any); ok {
+				if fields, ok := mapMap["fields"].(map[string]any); ok {
+					result := make(map[string]any)
 					for k, v := range fields {
 						converted, err := JSONToFirestoreValue(v, client)
 						if err != nil {
@@ -160,8 +160,8 @@ func JSONToFirestoreValue(value interface{}, client *firestore.Client) (interfac
 }
 
 // convertPlainMap converts a plain map to Firestore format
-func convertPlainMap(m map[string]interface{}, client *firestore.Client) (map[string]interface{}, error) {
-	result := make(map[string]interface{})
+func convertPlainMap(m map[string]any, client *firestore.Client) (map[string]any, error) {
+	result := make(map[string]any)
 	for k, v := range m {
 		converted, err := JSONToFirestoreValue(v, client)
 		if err != nil {
@@ -172,42 +172,6 @@ func convertPlainMap(m map[string]interface{}, client *firestore.Client) (map[st
 	return result, nil
 }
 
-// FirestoreValueToJSON converts a Firestore value to a simplified JSON representation
-// This removes type information and returns plain values
-func FirestoreValueToJSON(value interface{}) interface{} {
-	if value == nil {
-		return nil
-	}
-
-	switch v := value.(type) {
-	case time.Time:
-		return v.Format(time.RFC3339Nano)
-	case *latlng.LatLng:
-		return map[string]interface{}{
-			"latitude": v.Latitude,
-			"longitude": v.Longitude,
-		}
-	case []byte:
-		return base64.StdEncoding.EncodeToString(v)
-	case []interface{}:
-		result := make([]interface{}, len(v))
-		for i, item := range v {
-			result[i] = FirestoreValueToJSON(item)
-		}
-		return result
-	case map[string]interface{}:
-		result := make(map[string]interface{})
-		for k, val := range v {
-			result[k] = FirestoreValueToJSON(val)
-		}
-		return result
-	case *firestore.DocumentRef:
-		return v.Path
-	default:
-		return value
-	}
-}
-
 // isValidDocumentPath checks if a string is a valid Firestore document path
 // Valid paths have an even number of segments (collection/doc/collection/doc...)
 func isValidDocumentPath(path string) bool {
@@ -312,40 +312,6 @@ func TestJSONToFirestoreValue_IntegerFromString(t *testing.T) {
 	}
 }
 
-func TestFirestoreValueToJSON_RoundTrip(t *testing.T) {
-	// Test round-trip conversion
-	original := map[string]interface{}{
-		"name": "Test",
-		"count": int64(42),
-		"price": 19.99,
-		"active": true,
-		"tags": []interface{}{"tag1", "tag2"},
-		"metadata": map[string]interface{}{
-			"created": time.Now(),
-		},
-		"nullField": nil,
-	}
-
-	// Convert to JSON representation
-	jsonRepresentation := FirestoreValueToJSON(original)
-
-	// Verify types are simplified
-	jsonMap, ok := jsonRepresentation.(map[string]interface{})
-	if !ok {
-		t.Fatalf("Expected map, got %T", jsonRepresentation)
-	}
-
-	// Time should be converted to string
-	metadata, ok := jsonMap["metadata"].(map[string]interface{})
-	if !ok {
-		t.Fatalf("metadata should be a map, got %T", jsonMap["metadata"])
-	}
-	_, ok = metadata["created"].(string)
-	if !ok {
-		t.Errorf("created should be a string, got %T", metadata["created"])
-	}
-}
-
 func TestJSONToFirestoreValue_InvalidFormats(t *testing.T) {
 	tests := []struct {
 		name string
@@ -16,9 +16,7 @@ package http
 import (
 	"bytes"
 	"context"
-	"encoding/json"
 	"fmt"
-	"io"
 	"net/http"
 	"net/url"
 	"slices"
@@ -54,7 +52,7 @@ type compatibleSource interface {
 	HttpDefaultHeaders() map[string]string
 	HttpBaseURL() string
 	HttpQueryParams() map[string]string
-	Client() *http.Client
+	RunRequest(*http.Request) (any, error)
 }
 
 type Config struct {
@@ -259,29 +257,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 	for k, v := range allHeaders {
 		req.Header.Set(k, v)
 	}
-
-	// Make request and fetch response
-	resp, err := source.Client().Do(req)
-	if err != nil {
-		return nil, fmt.Errorf("error making HTTP request: %s", err)
-	}
-	defer resp.Body.Close()
-
-	var body []byte
-	body, err = io.ReadAll(resp.Body)
-	if err != nil {
-		return nil, err
-	}
-	if resp.StatusCode < 200 || resp.StatusCode > 299 {
-		return nil, fmt.Errorf("unexpected status code: %d, response body: %s", resp.StatusCode, string(body))
-	}
-
-	var data any
-	if err = json.Unmarshal(body, &data); err != nil {
-		// if unable to unmarshal data, return result as string.
-		return string(body), nil
-	}
-	return data, nil
+	return source.RunRequest(req)
 }
 
 func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
@@ -48,8 +48,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
 type compatibleSource interface {
 	UseClientAuthorization() bool
 	GetAuthTokenHeaderName() string
-	LookerClient() *v4.LookerSDK
 	LookerApiSettings() *rtl.ApiSettings
+	GetLookerSDK(string) (*v4.LookerSDK, error)
 }
 
 type Config struct {
@@ -159,7 +159,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
visConfig := paramsMap["vis_config"].(map[string]any)
|
visConfig := paramsMap["vis_config"].(map[string]any)
|
||||||
wq.VisConfig = &visConfig
|
wq.VisConfig = &visConfig
|
||||||
|
|
||||||
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
|
sdk, err := source.GetLookerSDK(string(accessToken))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error getting sdk: %w", err)
|
return nil, fmt.Errorf("error getting sdk: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
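The same two-part change repeats across every Looker tool below: the `compatibleSource` interface swaps `LookerClient()` for a single `GetLookerSDK(string)` method, and each call site stops threading auth flags, settings, and client through a shared helper. A small sketch of the pattern (all type and method names here are illustrative stand-ins, not the project's actual types):

```go
package main

import "fmt"

// sdk stands in for the real *v4.LookerSDK.
type sdk struct{ authed bool }

// Before the refactor, every tool passed useClientOAuth, settings, and a
// shared client into a package-level helper. After it, the source owns
// that decision behind one method.
type source interface {
	GetSDK(accessToken string) (*sdk, error)
}

type lookerSource struct{ client *sdk }

func (s *lookerSource) GetSDK(accessToken string) (*sdk, error) {
	if accessToken != "" {
		// Per-request SDK bound to the end user's token.
		return &sdk{authed: true}, nil
	}
	if s.client == nil {
		return nil, fmt.Errorf("no configured client")
	}
	return s.client, nil
}

func main() {
	src := &lookerSource{client: &sdk{}}
	got, err := src.GetSDK("user-token")
	fmt.Println(got.authed, err) // true <nil>
}
```

Moving construction behind the interface shrinks every call site to one line and keeps the OAuth-vs-configured-client branching in a single place.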
@@ -21,7 +21,6 @@ import (
 	"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
 	"github.com/googleapis/genai-toolbox/internal/sources"
 	"github.com/googleapis/genai-toolbox/internal/tools"
-	"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
 	"github.com/googleapis/genai-toolbox/internal/util"
 	"github.com/googleapis/genai-toolbox/internal/util/parameters"
 
@@ -48,8 +47,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
 type compatibleSource interface {
 	UseClientAuthorization() bool
 	GetAuthTokenHeaderName() string
-	LookerClient() *v4.LookerSDK
 	LookerApiSettings() *rtl.ApiSettings
+	GetLookerSDK(string) (*v4.LookerSDK, error)
 }
 
 type Config struct {
@@ -192,7 +191,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 		req.Dimension = &dimension
 	}
 
-	sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
+	sdk, err := source.GetLookerSDK(string(accessToken))
 	if err != nil {
 		return nil, fmt.Errorf("error getting sdk: %w", err)
 	}
@@ -15,65 +15,17 @@ package lookercommon
 
 import (
 	"context"
-	"crypto/tls"
 	"fmt"
-	"net/http"
 	"net/url"
 	"strings"
 
-	"github.com/googleapis/genai-toolbox/internal/tools"
 	"github.com/googleapis/genai-toolbox/internal/util"
 	"github.com/googleapis/genai-toolbox/internal/util/parameters"
-	rtl "github.com/looker-open-source/sdk-codegen/go/rtl"
+	"github.com/looker-open-source/sdk-codegen/go/rtl"
 	v4 "github.com/looker-open-source/sdk-codegen/go/sdk/v4"
 	"github.com/thlib/go-timezone-local/tzlocal"
 )
 
-// Make types for RoundTripper
-type transportWithAuthHeader struct {
-	Base      http.RoundTripper
-	AuthToken tools.AccessToken
-}
-
-func (t *transportWithAuthHeader) RoundTrip(req *http.Request) (*http.Response, error) {
-	req.Header.Set("x-looker-appid", "go-sdk")
-	req.Header.Set("Authorization", string(t.AuthToken))
-	return t.Base.RoundTrip(req)
-}
-
-func GetLookerSDK(useClientOAuth bool, config *rtl.ApiSettings, client *v4.LookerSDK, accessToken tools.AccessToken) (*v4.LookerSDK, error) {
-
-	if useClientOAuth {
-		if accessToken == "" {
-			return nil, fmt.Errorf("no access token supplied with request")
-		}
-
-		session := rtl.NewAuthSession(*config)
-		// Configure base transport with TLS
-		transport := &http.Transport{
-			TLSClientConfig: &tls.Config{
-				InsecureSkipVerify: !config.VerifySsl,
-			},
-		}
-
-		// Build transport for end user token
-		session.Client = http.Client{
-			Transport: &transportWithAuthHeader{
-				Base:      transport,
-				AuthToken: accessToken,
-			},
-		}
-
-		// return SDK with new Transport
-		return v4.NewLookerSDK(session), nil
-	}
-
-	if client == nil {
-		return nil, fmt.Errorf("client id or client secret not valid")
-	}
-	return client, nil
-}
-
 const (
 	DimensionsFields = "fields(dimensions(name,type,label,label_short,description,synonyms,tags,hidden,suggestable,suggestions,suggest_dimension,suggest_explore))"
 	FiltersFields    = "fields(filters(name,type,label,label_short,description,synonyms,tags,hidden,suggestable,suggestions,suggest_dimension,suggest_explore))"
@@ -47,8 +47,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
 type compatibleSource interface {
 	UseClientAuthorization() bool
 	GetAuthTokenHeaderName() string
-	LookerClient() *v4.LookerSDK
 	LookerApiSettings() *rtl.ApiSettings
+	GetLookerSDK(string) (*v4.LookerSDK, error)
 }
 
 type Config struct {
@@ -116,7 +116,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 		return nil, err
 	}
 
-	sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
+	sdk, err := source.GetLookerSDK(string(accessToken))
 	if err != nil {
 		return nil, fmt.Errorf("error getting sdk: %w", err)
 	}
@@ -47,8 +47,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
 type compatibleSource interface {
 	UseClientAuthorization() bool
 	GetAuthTokenHeaderName() string
-	LookerClient() *v4.LookerSDK
 	LookerApiSettings() *rtl.ApiSettings
+	GetLookerSDK(string) (*v4.LookerSDK, error)
 }
 
 type Config struct {
@@ -117,7 +117,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 		return nil, err
 	}
 
-	sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
+	sdk, err := source.GetLookerSDK(string(accessToken))
 	if err != nil {
 		return nil, fmt.Errorf("error getting sdk: %w", err)
 	}
@@ -21,7 +21,6 @@ import (
 	"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
 	"github.com/googleapis/genai-toolbox/internal/sources"
 	"github.com/googleapis/genai-toolbox/internal/tools"
-	"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
 	"github.com/googleapis/genai-toolbox/internal/util"
 	"github.com/googleapis/genai-toolbox/internal/util/parameters"
 
@@ -48,8 +47,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
 type compatibleSource interface {
 	UseClientAuthorization() bool
 	GetAuthTokenHeaderName() string
-	LookerClient() *v4.LookerSDK
 	LookerApiSettings() *rtl.ApiSettings
+	GetLookerSDK(string) (*v4.LookerSDK, error)
 }
 
 type Config struct {
@@ -125,7 +124,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 		return nil, fmt.Errorf("'devMode' must be a boolean, got %T", mapParams["devMode"])
 	}
 
-	sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
+	sdk, err := source.GetLookerSDK(string(accessToken))
 	if err != nil {
 		return nil, fmt.Errorf("error getting sdk: %w", err)
 	}
@@ -22,7 +22,6 @@ import (
 	"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
 	"github.com/googleapis/genai-toolbox/internal/sources"
 	"github.com/googleapis/genai-toolbox/internal/tools"
-	"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
 	"github.com/googleapis/genai-toolbox/internal/util"
 	"github.com/googleapis/genai-toolbox/internal/util/parameters"
 
@@ -49,8 +48,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
 type compatibleSource interface {
 	UseClientAuthorization() bool
 	GetAuthTokenHeaderName() string
-	LookerClient() *v4.LookerSDK
 	LookerApiSettings() *rtl.ApiSettings
+	GetLookerSDK(string) (*v4.LookerSDK, error)
 	LookerSessionLength() int64
 }
 
@@ -137,7 +136,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 		contentId_ptr = nil
 	}
 
-	sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
+	sdk, err := source.GetLookerSDK(string(accessToken))
 	if err != nil {
 		return nil, fmt.Errorf("error getting sdk: %w", err)
 	}
@@ -21,7 +21,6 @@ import (
 	"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
 	"github.com/googleapis/genai-toolbox/internal/sources"
 	"github.com/googleapis/genai-toolbox/internal/tools"
-	"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
 	"github.com/googleapis/genai-toolbox/internal/util/parameters"
 
 	"github.com/looker-open-source/sdk-codegen/go/rtl"
@@ -47,8 +46,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
 type compatibleSource interface {
 	UseClientAuthorization() bool
 	GetAuthTokenHeaderName() string
-	LookerClient() *v4.LookerSDK
 	LookerApiSettings() *rtl.ApiSettings
+	GetLookerSDK(string) (*v4.LookerSDK, error)
 }
 
 type Config struct {
@@ -120,7 +119,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 		return nil, fmt.Errorf("'conn' must be a string, got %T", mapParams["conn"])
 	}
 
-	sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
+	sdk, err := source.GetLookerSDK(string(accessToken))
 	if err != nil {
 		return nil, fmt.Errorf("error getting sdk: %w", err)
 	}
@@ -21,7 +21,6 @@ import (
 	"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
 	"github.com/googleapis/genai-toolbox/internal/sources"
 	"github.com/googleapis/genai-toolbox/internal/tools"
-	"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
 	"github.com/googleapis/genai-toolbox/internal/util"
 	"github.com/googleapis/genai-toolbox/internal/util/parameters"
 
@@ -48,8 +47,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
 type compatibleSource interface {
 	UseClientAuthorization() bool
 	GetAuthTokenHeaderName() string
-	LookerClient() *v4.LookerSDK
 	LookerApiSettings() *rtl.ApiSettings
+	GetLookerSDK(string) (*v4.LookerSDK, error)
 }
 
 type Config struct {
@@ -119,7 +118,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 		return nil, fmt.Errorf("unable to get logger from ctx: %s", err)
 	}
 
-	sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
+	sdk, err := source.GetLookerSDK(string(accessToken))
 	if err != nil {
 		return nil, fmt.Errorf("error getting sdk: %w", err)
 	}
@@ -21,7 +21,6 @@ import (
 	"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
 	"github.com/googleapis/genai-toolbox/internal/sources"
 	"github.com/googleapis/genai-toolbox/internal/tools"
-	"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
 	"github.com/googleapis/genai-toolbox/internal/util/parameters"
 
 	"github.com/looker-open-source/sdk-codegen/go/rtl"
@@ -47,8 +46,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
 type compatibleSource interface {
 	UseClientAuthorization() bool
 	GetAuthTokenHeaderName() string
-	LookerClient() *v4.LookerSDK
 	LookerApiSettings() *rtl.ApiSettings
+	GetLookerSDK(string) (*v4.LookerSDK, error)
 }
 
 type Config struct {
@@ -122,7 +121,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 	}
 	db, _ := mapParams["db"].(string)
 
-	sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
+	sdk, err := source.GetLookerSDK(string(accessToken))
 	if err != nil {
 		return nil, fmt.Errorf("error getting sdk: %w", err)
 	}
@@ -21,7 +21,6 @@ import (
 	"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
 	"github.com/googleapis/genai-toolbox/internal/sources"
 	"github.com/googleapis/genai-toolbox/internal/tools"
-	"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
 	"github.com/googleapis/genai-toolbox/internal/util"
 	"github.com/googleapis/genai-toolbox/internal/util/parameters"
 
@@ -48,8 +47,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
 type compatibleSource interface {
 	UseClientAuthorization() bool
 	GetAuthTokenHeaderName() string
-	LookerClient() *v4.LookerSDK
 	LookerApiSettings() *rtl.ApiSettings
+	GetLookerSDK(string) (*v4.LookerSDK, error)
 }
 
 type Config struct {
@@ -137,7 +136,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 		return nil, fmt.Errorf("'tables' must be a string, got %T", mapParams["tables"])
 	}
 
-	sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
+	sdk, err := source.GetLookerSDK(string(accessToken))
 	if err != nil {
 		return nil, fmt.Errorf("error getting sdk: %w", err)
 	}
@@ -21,7 +21,6 @@ import (
 	"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
 	"github.com/googleapis/genai-toolbox/internal/sources"
 	"github.com/googleapis/genai-toolbox/internal/tools"
-	"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
 	"github.com/googleapis/genai-toolbox/internal/util"
 	"github.com/googleapis/genai-toolbox/internal/util/parameters"
 
@@ -48,8 +47,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
 type compatibleSource interface {
 	UseClientAuthorization() bool
 	GetAuthTokenHeaderName() string
-	LookerClient() *v4.LookerSDK
 	LookerApiSettings() *rtl.ApiSettings
+	GetLookerSDK(string) (*v4.LookerSDK, error)
 }
 
 type Config struct {
@@ -132,7 +131,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 		return nil, fmt.Errorf("'schema' must be a string, got %T", mapParams["schema"])
 	}
 
-	sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
+	sdk, err := source.GetLookerSDK(string(accessToken))
 	if err != nil {
 		return nil, fmt.Errorf("error getting sdk: %w", err)
 	}
@@ -21,7 +21,6 @@ import (
 	"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
 	"github.com/googleapis/genai-toolbox/internal/sources"
 	"github.com/googleapis/genai-toolbox/internal/tools"
-	"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
 	"github.com/googleapis/genai-toolbox/internal/util"
 	"github.com/googleapis/genai-toolbox/internal/util/parameters"
 
@@ -48,8 +47,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
 type compatibleSource interface {
 	UseClientAuthorization() bool
 	GetAuthTokenHeaderName() string
-	LookerClient() *v4.LookerSDK
 	LookerApiSettings() *rtl.ApiSettings
+	GetLookerSDK(string) (*v4.LookerSDK, error)
 }
 
 type Config struct {
@@ -141,7 +140,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 	limit := int64(paramsMap["limit"].(int))
 	offset := int64(paramsMap["offset"].(int))
 
-	sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
+	sdk, err := source.GetLookerSDK(string(accessToken))
 	if err != nil {
 		return nil, fmt.Errorf("error getting sdk: %w", err)
 	}
@@ -48,8 +48,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
 type compatibleSource interface {
 	UseClientAuthorization() bool
 	GetAuthTokenHeaderName() string
-	LookerClient() *v4.LookerSDK
 	LookerApiSettings() *rtl.ApiSettings
+	GetLookerSDK(string) (*v4.LookerSDK, error)
 	LookerShowHiddenFields() bool
 }
 
@@ -124,7 +124,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 		return nil, fmt.Errorf("error processing model or explore: %w", err)
 	}
 
-	sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
+	sdk, err := source.GetLookerSDK(string(accessToken))
 	if err != nil {
 		return nil, fmt.Errorf("error getting sdk: %w", err)
 	}
@@ -21,7 +21,6 @@ import (
 	"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
 	"github.com/googleapis/genai-toolbox/internal/sources"
 	"github.com/googleapis/genai-toolbox/internal/tools"
-	"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
 	"github.com/googleapis/genai-toolbox/internal/util"
 	"github.com/googleapis/genai-toolbox/internal/util/parameters"
 
@@ -48,8 +47,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
 type compatibleSource interface {
 	UseClientAuthorization() bool
 	GetAuthTokenHeaderName() string
-	LookerClient() *v4.LookerSDK
 	LookerApiSettings() *rtl.ApiSettings
+	GetLookerSDK(string) (*v4.LookerSDK, error)
 	LookerShowHiddenExplores() bool
 }
 
@@ -126,7 +125,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 		return nil, fmt.Errorf("'model' must be a string, got %T", mapParams["model"])
 	}
 
-	sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
+	sdk, err := source.GetLookerSDK(string(accessToken))
 	if err != nil {
 		return nil, fmt.Errorf("error getting sdk: %w", err)
 	}
@@ -48,8 +48,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
 type compatibleSource interface {
 	UseClientAuthorization() bool
 	GetAuthTokenHeaderName() string
-	LookerClient() *v4.LookerSDK
 	LookerApiSettings() *rtl.ApiSettings
+	GetLookerSDK(string) (*v4.LookerSDK, error)
 	LookerShowHiddenFields() bool
 }
 
@@ -125,7 +125,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 	}
 
 	fields := lookercommon.FiltersFields
-	sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
+	sdk, err := source.GetLookerSDK(string(accessToken))
 	if err != nil {
 		return nil, fmt.Errorf("error getting sdk: %w", err)
 	}
@@ -21,7 +21,6 @@ import (
 	"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
 	"github.com/googleapis/genai-toolbox/internal/sources"
 	"github.com/googleapis/genai-toolbox/internal/tools"
-	"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
 	"github.com/googleapis/genai-toolbox/internal/util"
 	"github.com/googleapis/genai-toolbox/internal/util/parameters"
 
@@ -48,8 +47,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
 type compatibleSource interface {
 	UseClientAuthorization() bool
 	GetAuthTokenHeaderName() string
-	LookerClient() *v4.LookerSDK
 	LookerApiSettings() *rtl.ApiSettings
+	GetLookerSDK(string) (*v4.LookerSDK, error)
 }
 
 type Config struct {
@@ -141,7 +140,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 	limit := int64(paramsMap["limit"].(int))
 	offset := int64(paramsMap["offset"].(int))
 
-	sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
+	sdk, err := source.GetLookerSDK(string(accessToken))
 	if err != nil {
 		return nil, fmt.Errorf("error getting sdk: %w", err)
 	}
@@ -48,8 +48,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
 type compatibleSource interface {
 	UseClientAuthorization() bool
 	GetAuthTokenHeaderName() string
-	LookerClient() *v4.LookerSDK
 	LookerApiSettings() *rtl.ApiSettings
+	GetLookerSDK(string) (*v4.LookerSDK, error)
 	LookerShowHiddenFields() bool
 }
 
@@ -125,7 +125,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 	}
 
 	fields := lookercommon.MeasuresFields
-	sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
+	sdk, err := source.GetLookerSDK(string(accessToken))
 	if err != nil {
 		return nil, fmt.Errorf("error getting sdk: %w", err)
 	}
@@ -21,7 +21,6 @@ import (
 	"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
 	"github.com/googleapis/genai-toolbox/internal/sources"
 	"github.com/googleapis/genai-toolbox/internal/tools"
-	"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
 	"github.com/googleapis/genai-toolbox/internal/util"
 	"github.com/googleapis/genai-toolbox/internal/util/parameters"
 
@@ -48,8 +47,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
 type compatibleSource interface {
 	UseClientAuthorization() bool
 	GetAuthTokenHeaderName() string
-	LookerClient() *v4.LookerSDK
 	LookerApiSettings() *rtl.ApiSettings
+	GetLookerSDK(string) (*v4.LookerSDK, error)
 	LookerShowHiddenModels() bool
 }
 
@@ -124,7 +123,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
 	excludeHidden := !source.LookerShowHiddenModels()
 	includeInternal := true
 
-	sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
+	sdk, err := source.GetLookerSDK(string(accessToken))
 	if err != nil {
 		return nil, fmt.Errorf("error getting sdk: %w", err)
 	}
@@ -48,8 +48,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
|||||||
type compatibleSource interface {
|
type compatibleSource interface {
|
||||||
UseClientAuthorization() bool
|
UseClientAuthorization() bool
|
||||||
GetAuthTokenHeaderName() string
|
GetAuthTokenHeaderName() string
|
||||||
LookerClient() *v4.LookerSDK
|
|
||||||
LookerApiSettings() *rtl.ApiSettings
|
LookerApiSettings() *rtl.ApiSettings
|
||||||
|
GetLookerSDK(string) (*v4.LookerSDK, error)
|
||||||
LookerShowHiddenFields() bool
|
LookerShowHiddenFields() bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -125,7 +125,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
}
|
}
|
||||||
|
|
||||||
fields := lookercommon.ParametersFields
|
fields := lookercommon.ParametersFields
|
||||||
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
|
sdk, err := source.GetLookerSDK(string(accessToken))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error getting sdk: %w", err)
|
return nil, fmt.Errorf("error getting sdk: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -48,8 +48,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
|||||||
type compatibleSource interface {
|
type compatibleSource interface {
|
||||||
UseClientAuthorization() bool
|
UseClientAuthorization() bool
|
||||||
GetAuthTokenHeaderName() string
|
GetAuthTokenHeaderName() string
|
||||||
LookerClient() *v4.LookerSDK
|
|
||||||
LookerApiSettings() *rtl.ApiSettings
|
LookerApiSettings() *rtl.ApiSettings
|
||||||
|
GetLookerSDK(string) (*v4.LookerSDK, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
@@ -121,7 +121,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
return nil, fmt.Errorf("unable to get logger from ctx: %s", err)
|
return nil, fmt.Errorf("unable to get logger from ctx: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
|
sdk, err := source.GetLookerSDK(string(accessToken))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error getting sdk: %w", err)
|
return nil, fmt.Errorf("error getting sdk: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ import (
|
|||||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/util"
|
"github.com/googleapis/genai-toolbox/internal/util"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
|
|
||||||
@@ -48,8 +47,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
|||||||
type compatibleSource interface {
|
type compatibleSource interface {
|
||||||
UseClientAuthorization() bool
|
UseClientAuthorization() bool
|
||||||
GetAuthTokenHeaderName() string
|
GetAuthTokenHeaderName() string
|
||||||
LookerClient() *v4.LookerSDK
|
|
||||||
LookerApiSettings() *rtl.ApiSettings
|
LookerApiSettings() *rtl.ApiSettings
|
||||||
|
GetLookerSDK(string) (*v4.LookerSDK, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
@@ -120,7 +119,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
return nil, fmt.Errorf("unable to get logger from ctx: %s", err)
|
return nil, fmt.Errorf("unable to get logger from ctx: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
|
sdk, err := source.GetLookerSDK(string(accessToken))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error getting sdk: %w", err)
|
return nil, fmt.Errorf("error getting sdk: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ import (
|
|||||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/util"
|
"github.com/googleapis/genai-toolbox/internal/util"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
|
|
||||||
@@ -48,8 +47,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
|||||||
type compatibleSource interface {
|
type compatibleSource interface {
|
||||||
UseClientAuthorization() bool
|
UseClientAuthorization() bool
|
||||||
GetAuthTokenHeaderName() string
|
GetAuthTokenHeaderName() string
|
||||||
LookerClient() *v4.LookerSDK
|
|
||||||
LookerApiSettings() *rtl.ApiSettings
|
LookerApiSettings() *rtl.ApiSettings
|
||||||
|
GetLookerSDK(string) (*v4.LookerSDK, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
@@ -119,7 +118,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
return nil, fmt.Errorf("unable to get logger from ctx: %s", err)
|
return nil, fmt.Errorf("unable to get logger from ctx: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
|
sdk, err := source.GetLookerSDK(string(accessToken))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error getting sdk: %w", err)
|
return nil, fmt.Errorf("error getting sdk: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
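The recurring change in these hunks narrows the tool-side `compatibleSource` interface: instead of exposing raw building blocks (`LookerClient()`, plus auth settings) and letting every tool assemble an SDK via `lookercommon.GetLookerSDK`, the source now owns SDK construction behind a single `GetLookerSDK(accessToken)` method. Below is a minimal sketch of that pattern using stand-in types — `lookerSDK` and `stubSource` are hypothetical placeholders for illustration, not the real Looker SDK or source implementation:

```go
package main

import "fmt"

// lookerSDK stands in for the real *v4.LookerSDK client (assumption:
// the actual type comes from the Looker Go SDK).
type lookerSDK struct{ authHeader string }

// After the refactor, tools depend on one factory method instead of
// on the client and API settings individually.
type compatibleSource interface {
	GetLookerSDK(accessToken string) (*lookerSDK, error)
}

// stubSource is a hypothetical source: it decides internally whether
// to use client authorization, so callers no longer need to know.
type stubSource struct{ useClientAuth bool }

func (s *stubSource) GetLookerSDK(accessToken string) (*lookerSDK, error) {
	if s.useClientAuth && accessToken == "" {
		return nil, fmt.Errorf("client authorization enabled but no access token supplied")
	}
	return &lookerSDK{authHeader: "Bearer " + accessToken}, nil
}

func main() {
	var src compatibleSource = &stubSource{useClientAuth: true}
	sdk, err := src.GetLookerSDK("token-123")
	if err != nil {
		panic(err)
	}
	fmt.Println(sdk.authHeader) // prints "Bearer token-123"
}
```

The payoff visible in the diff: each tool's `Invoke` shrinks to one call plus an error check, and the `lookercommon` import disappears from files that only needed it for SDK assembly.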