Compare commits


15 Commits

Author SHA1 Message Date
Juexin Wang
172dd6b19b refactor(tools/cloudgda): update to use google-cloud-go sdk types
Removes types.go and uses geminidataanalyticspb types with wrapper structs for YAML decoding.
2026-01-14 17:06:43 -08:00
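A minimal sketch of the wrapper-struct YAML decoding pattern this commit describes: decode the YAML block into a generic map, re-encode it as JSON, and let protojson populate the SDK type. The wrapper name and the choice of `QueryDataRequest` as the wrapped type are illustrative only, not the actual implementation.

```go
package cloudgda

import (
	"encoding/json"
	"fmt"

	"cloud.google.com/go/geminidataanalytics/apiv1beta/geminidataanalyticspb"
	"github.com/goccy/go-yaml"
	"google.golang.org/protobuf/encoding/protojson"
)

// requestWrapper is a hypothetical wrapper that lets a proto message from the
// google-cloud-go SDK be populated from a tools.yaml block.
type requestWrapper struct {
	Request *geminidataanalyticspb.QueryDataRequest
}

// UnmarshalYAML satisfies goccy/go-yaml's BytesUnmarshaler: YAML -> map ->
// JSON -> protojson into the SDK type.
func (w *requestWrapper) UnmarshalYAML(data []byte) error {
	var raw map[string]any
	if err := yaml.Unmarshal(data, &raw); err != nil {
		return fmt.Errorf("decode yaml: %w", err)
	}
	jsonBytes, err := json.Marshal(raw)
	if err != nil {
		return fmt.Errorf("re-encode as json: %w", err)
	}
	req := &geminidataanalyticspb.QueryDataRequest{}
	if err := protojson.Unmarshal(jsonBytes, req); err != nil {
		return fmt.Errorf("decode proto: %w", err)
	}
	w.Request = req
	return nil
}
```
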
Juexin Wang
9f5b04cf73 Refactor Cloud GDA source to support per-request client authorization 2026-01-14 16:10:50 -08:00
Juexin Wang
66d6b58c4f introduce Go SDK for GDA tool and source 2026-01-14 16:10:50 -08:00
Juexin Wang
6b02591703 refactor(tools/cloudgda)!: update description and parameter name for cloudgda tool (#2288)
- Refactors the 'cloud-gemini-data-analytics-query' tool to update its
default description with detailed tool and usage guidance.
- Appends the default description to the tools.yaml description whether or
not a custom description exists, since this guidance is always useful to
the agent in deciding how to use the tool.
- Renames the input parameter from 'prompt' to 'query' for better
consistency.
2026-01-14 23:54:43 +00:00
Eric Wang
8e0fb03483 feat(prebuilt/cloud-sql): Add create backup tool for Cloud SQL (#2141)
## Description

This pull request adds a new tool, cloud-sql-create-backup, which
enables taking a backup on a Cloud SQL instance from the toolbox using
the Cloud SQL Admin API. The tool supports optionally supplying a
location or description for the backup.

Tested:
<img width="1561" height="425" alt="image"
src="https://github.com/user-attachments/assets/c8984b07-5450-470a-9ac6-df16943e25e9"
/>


## PR Checklist

> Thank you for opening a Pull Request! Before submitting your PR, there are a
> few things you can do to make sure it goes smoothly:

- [x] Make sure you reviewed
  [CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md)
- [x] Make sure to open an issue as a
  [bug/issue](https://github.com/googleapis/genai-toolbox/issues/new/choose)
  before writing your code! That way we can discuss the change, evaluate
  designs, and agree on the general idea
- [x] Ensure the tests and linter pass
- [x] Code coverage does not decrease (if any source code was changed)
- [x] Appropriate docs were updated (if necessary)
- [x] Make sure to add `!` if this involves a breaking change

🛠️ Fixes #2140

---------

Co-authored-by: Yuan Teoh <45984206+Yuan325@users.noreply.github.com>
Co-authored-by: Averi Kitsch <akitsch@google.com>
2026-01-14 22:55:11 +00:00
Wenxin Du
2817dd1e5d docs: Add PR guidelines in 'CONTRIBUTING.md' (#2304)
Add PR guidelines - some best practices
2026-01-14 19:08:56 +00:00
Giuseppe Villani
68a218407e docs: add quickstart guide for MCP with Neo4j (#1774)
## Description

Samples for MCP with Neo4j for this page:
https://googleapis.github.io/genai-toolbox/samples/

## PR Checklist

> Thank you for opening a Pull Request! Before submitting your PR, there are a
> few things you can do to make sure it goes smoothly:

- [x] Make sure you reviewed
  [CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md)
- [ ] Make sure to open an issue as a
  [bug/issue](https://github.com/googleapis/genai-toolbox/issues/new/choose)
  before writing your code! That way we can discuss the change, evaluate
  designs, and agree on the general idea
- [x] Ensure the tests and linter pass
- [x] Code coverage does not decrease (if any source code was changed)
- [x] Appropriate docs were updated (if necessary)
- [x] Make sure to add `!` if this involves a breaking change

---------

Co-authored-by: Yuan Teoh <45984206+Yuan325@users.noreply.github.com>
2026-01-14 01:57:10 +00:00
Sahaja Reddy Pabbathi Reddy
d69792d843 chore: update dataplex aspecttypes integration tests (#2193)
## Description

Addresses an issue where Dataplex AspectTypes created during integration
tests were not consistently deleted. This accumulation led to exceeding
the AspectType quota.
To prevent this, the test setup now includes a step to list and delete
all existing AspectTypes within the test project and location *before*
attempting to create any new ones. This ensures a clean state for each
test run and avoids hitting the quota.


## PR Checklist

> Thank you for opening a Pull Request! Before submitting your PR, there are a
> few things you can do to make sure it goes smoothly:

- [ ] Make sure you reviewed
  [CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md)
- [ ] Make sure to open an issue as a
  [bug/issue](https://github.com/googleapis/genai-toolbox/issues/new/choose)
  before writing your code! That way we can discuss the change, evaluate
  designs, and agree on the general idea
- [ ] Ensure the tests and linter pass
- [ ] Code coverage does not decrease (if any source code was changed)
- [ ] Appropriate docs were updated (if necessary)
- [ ] Make sure to add `!` if this involves a breaking change

🛠️ Fixes #2057

Co-authored-by: Averi Kitsch <akitsch@google.com>
2026-01-13 08:24:33 -08:00
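A rough sketch of the kind of pre-test cleanup described above, assuming the catalog client from `cloud.google.com/go/dataplex/apiv1` that the repository already uses; the helper name and the project/location wiring are illustrative, not the actual test code.

```go
package dataplextest

import (
	"context"
	"fmt"

	dataplexapi "cloud.google.com/go/dataplex/apiv1"
	"cloud.google.com/go/dataplex/apiv1/dataplexpb"
	"google.golang.org/api/iterator"
)

// cleanupAspectTypes lists every AspectType under the test project/location
// and deletes it, so each test run starts from a clean state and stays under
// the AspectType quota.
func cleanupAspectTypes(ctx context.Context, client *dataplexapi.CatalogClient, project, location string) error {
	parent := fmt.Sprintf("projects/%s/locations/%s", project, location)
	it := client.ListAspectTypes(ctx, &dataplexpb.ListAspectTypesRequest{Parent: parent})
	for {
		at, err := it.Next()
		if err == iterator.Done {
			break
		}
		if err != nil {
			return fmt.Errorf("listing aspect types: %w", err)
		}
		op, err := client.DeleteAspectType(ctx, &dataplexpb.DeleteAspectTypeRequest{Name: at.GetName()})
		if err != nil {
			return fmt.Errorf("deleting %q: %w", at.GetName(), err)
		}
		if err := op.Wait(ctx); err != nil {
			return fmt.Errorf("waiting for delete of %q: %w", at.GetName(), err)
		}
	}
	return nil
}
```
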
Yuan Teoh
647b04d3a7 docs(tools/alloydbainl): only require psv installation when needed (#2283)
Update existing docs to install the PSV extension **only** when needed.
This tool can be used without installing PSV if NLConfigParam is not
used.

---------

Co-authored-by: Averi Kitsch <akitsch@google.com>
2026-01-12 16:57:59 -08:00
Yuan Teoh
030df9766f refactor(sources/looker): move source implementation in Invoke() function into Source (#2278)
Move source-related queries from the `Invoke()` function into the Source.

This is an effort to generalize tools to work with any Source that
implements a specific interface, providing a better segregation of roles
between Tools and Sources (see the sketch after this entry).

The Tool's role will be limited to the following:
* Resolving any pre-implementation steps or parameters (e.g. template
parameters)
* Retrieving the Source
* Calling the Source's implementation
2026-01-12 23:53:33 +00:00
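A hypothetical sketch of the kind of narrow Source interface this series of refactors points toward; the interface name and method set are assumptions for illustration, not the actual toolbox types.

```go
package example

import "context"

// lookerSource is a hypothetical narrow interface a tool could depend on.
// The tool no longer embeds query logic in Invoke(); it only resolves
// parameters, fetches the source, and delegates to the source's method.
type lookerSource interface {
	RunLookerQuery(ctx context.Context, model, explore string, fields []string) (any, error)
}

type queryTool struct {
	source lookerSource
}

// Invoke resolves parameters and delegates the source-specific work.
func (t queryTool) Invoke(ctx context.Context, params map[string]any) (any, error) {
	model, _ := params["model"].(string)
	explore, _ := params["explore"].(string)
	fields, _ := params["fields"].([]string)
	return t.source.RunLookerQuery(ctx, model, explore, fields)
}
```
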
Yuan Teoh
5dbf207162 refactor(sources/mongodb): move source implementation in Invoke() function into Source (#2277)
Move source-related queries from the `Invoke()` function into the Source.

This is an effort to generalize tools to work with any Source that
implements a specific interface, providing a better segregation of roles
between Tools and Sources.

The Tool's role will be limited to the following:
* Resolving any pre-implementation steps or parameters (e.g. template
parameters)
* Retrieving the Source
* Calling the Source's implementation
2026-01-12 15:25:39 -08:00
洪鈞閔 ( jasper )
9c3720e31d docs(README.md): fix container section version 0.11.0 => 0.24.0 (#2251)
## Description

> Should include a concise description of the changes (bug or feature), its
> impact, along with a summary of the solution

## PR Checklist

> Thank you for opening a Pull Request! Before submitting your PR, there are a
> few things you can do to make sure it goes smoothly:

- [x] Make sure you reviewed
  [CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md)
- [ ] Make sure to open an issue as a
  [bug/issue](https://github.com/googleapis/genai-toolbox/issues/new/choose)
  before writing your code! That way we can discuss the change, evaluate
  designs, and agree on the general idea
- [ ] Ensure the tests and linter pass
- [ ] Code coverage does not decrease (if any source code was changed)
- [ ] Appropriate docs were updated (if necessary)
- [ ] Make sure to add `!` if this involves a breaking change

🛠️ Fixes #<issue_number_goes_here>

Co-authored-by: Averi Kitsch <akitsch@google.com>
2026-01-12 20:17:15 +00:00
Yuan Teoh
3cd3c39d66 refactor(sources/firestore): move source implementation in Invoke() function into Source (#2275)
Move source-related queries from the `Invoke()` function into the Source.

This is an effort to generalize tools to work with any Source that
implements a specific interface, providing a better segregation of roles
between Tools and Sources.

The Tool's role will be limited to the following:
* Resolving any pre-implementation steps or parameters (e.g. template
parameters)
* Retrieving the Source
* Calling the Source's implementation
2026-01-12 18:43:35 +00:00
Yuan Teoh
0691a6f715 refactor: move source implementation in Invoke() function to Source (#2274)
Move source-related queries from the `Invoke()` function into the Source.

This PR addresses the following sources:
* dataplex
* http
* serverlessspark

This is an effort to generalize tools to work with any Source that
implements a specific interface, providing a better segregation of roles
between Tools and Sources.

The Tool's role will be limited to the following:
* Resolving any pre-implementation steps or parameters (e.g. template
parameters)
* Retrieving the Source
* Calling the Source's implementation
2026-01-12 18:16:32 +00:00
Yuan Teoh
467b96a23b refactor(sources/cloudhealthcare): move source implementation in Invoke() function to Source (#2273)
Move source-related queries from the `Invoke()` function into the Source.

This is an effort to generalize tools to work with any Source that
implements a specific interface, providing a better segregation of roles
between Tools and Sources.

The Tool's role will be limited to the following:
* Resolving any pre-implementation steps or parameters (e.g. template
parameters)
* Retrieving the Source
* Calling the Source's implementation
2026-01-12 17:51:58 +00:00
120 changed files with 2724 additions and 2422 deletions

View File

@@ -59,6 +59,13 @@ You can manually trigger the bot by commenting on your Pull Request:
* `/gemini summary`: Posts a summary of the changes in the pull request.
* `/gemini help`: Overview of the available commands
## Guidelines for Pull Requests
1. Please keep your PR small for more thorough review and easier updates. In case of regression, it also allows us to roll back a single feature instead of multiple ones.
1. For non-trivial changes, consider opening an issue and discussing it with the code owners first.
1. Provide a good PR description as a record of what change is being made and why it was made. Link to a GitHub issue if it exists.
1. Make sure your code is thoroughly tested with unit tests and integration tests. Remember to clean up the test instances properly in your code to avoid memory leaks.
## Adding a New Database Source or Tool
Please create an
@@ -110,6 +117,8 @@ implementation](https://github.com/googleapis/genai-toolbox/blob/main/internal/s
We recommend looking at an [example tool
implementation](https://github.com/googleapis/genai-toolbox/tree/main/internal/tools/postgres/postgressql).
Remember to keep your PRs small. For example, if you are contributing a new Source, only include one or two core Tools within the same PR, the rest of the Tools can come in subsequent PRs.
* **Create a new directory** under `internal/tools` for your tool type (e.g., `internal/tools/newdb/newdbtool`).
* **Define a configuration struct** for your tool in a file named `newdbtool.go`.
Create a `Config` struct and a `Tool` struct to store necessary parameters for
@@ -163,6 +172,8 @@ tools.
parameters][temp-param-doc]. Only run this test if template
parameters apply to your tool.
* **Add additional tests** for the tools that are not covered by the predefined tests. Every tool must be tested!
* **Add the new database to the integration test workflow** in
[integration.cloudbuild.yaml](.ci/integration.cloudbuild.yaml).
@@ -244,4 +255,4 @@ resources.
* **PR Description:** PR description should **always** be included. It should
include a concise description of the changes, it's impact, along with a
summary of the solution. If the PR is related to a specific issue, the issue
number should be mentioned in the PR description (e.g. `Fixes #1`).
number should be mentioned in the PR description (e.g. `Fixes #1`).
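Relating to the "Define a configuration struct" guidance in the CONTRIBUTING.md hunk above, a minimal hypothetical skeleton of the Config/Tool split; the type and field names are illustrative only, and the linked postgressql example shows the real pattern.

```go
package newdbtool

// Config is a hypothetical configuration struct populated from tools.yaml.
// Field names here are illustrative, not the actual toolbox schema.
type Config struct {
	Name        string `yaml:"name"`
	Kind        string `yaml:"kind"`
	Source      string `yaml:"source"`
	Description string `yaml:"description"`
}

// Tool holds the resolved configuration plus anything needed at invoke time,
// such as a handle to the initialized source and parameter definitions.
type Tool struct {
	Name        string
	Description string
	// source handle, parameter definitions, etc.
}
```
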

View File

@@ -272,7 +272,7 @@ To run Toolbox from binary:
To run the server after pulling the [container image](#installing-the-server):
```sh
export VERSION=0.11.0 # Use the version you pulled
export VERSION=0.24.0 # Use the version you pulled
docker run -p 5000:5000 \
-v $(pwd)/tools.yaml:/app/tools.yaml \
us-central1-docker.pkg.dev/database-toolbox/toolbox/toolbox:$VERSION \

View File

@@ -92,6 +92,7 @@ import (
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcaresearchdicomstudies"
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudmonitoring"
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcloneinstance"
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcreatebackup"
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcreatedatabase"
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcreateusers"
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlgetinstances"

View File

@@ -1493,7 +1493,7 @@ func TestPrebuiltTools(t *testing.T) {
wantToolset: server.ToolsetConfigs{
"cloud_sql_postgres_admin_tools": tools.ToolsetConfig{
Name: "cloud_sql_postgres_admin_tools",
ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "postgres_upgrade_precheck", "clone_instance"},
ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "postgres_upgrade_precheck", "clone_instance", "create_backup"},
},
},
},
@@ -1503,7 +1503,7 @@ func TestPrebuiltTools(t *testing.T) {
wantToolset: server.ToolsetConfigs{
"cloud_sql_mysql_admin_tools": tools.ToolsetConfig{
Name: "cloud_sql_mysql_admin_tools",
ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance"},
ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance", "create_backup"},
},
},
},
@@ -1513,7 +1513,7 @@ func TestPrebuiltTools(t *testing.T) {
wantToolset: server.ToolsetConfigs{
"cloud_sql_mssql_admin_tools": tools.ToolsetConfig{
Name: "cloud_sql_mssql_admin_tools",
ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance"},
ToolNames: []string{"create_instance", "get_instance", "list_instances", "create_database", "list_databases", "create_user", "wait_for_operation", "clone_instance", "create_backup"},
},
},
},

View File

@@ -48,6 +48,7 @@ instance, database and users:
* `roles/cloudsql.editor`: Provides permissions to manage existing resources.
* All `viewer` tools
* `create_database`
* `create_backup`
* `roles/cloudsql.admin`: Provides full control over all resources.
* All `editor` and `viewer` tools
* `create_instance`
@@ -299,6 +300,7 @@ instances and interacting with your database:
* **create_user**: Creates a new user in a Cloud SQL instance.
* **wait_for_operation**: Waits for a Cloud SQL operation to complete.
* **clone_instance**: Creates a clone of an existing Cloud SQL for SQL Server instance.
* **create_backup**: Creates a backup on a Cloud SQL instance.
{{< notice note >}}
Prebuilt tools are pre-1.0, so expect some tool changes between versions. LLMs

View File

@@ -48,6 +48,7 @@ database and users:
* `roles/cloudsql.editor`: Provides permissions to manage existing resources.
* All `viewer` tools
* `create_database`
* `create_backup`
* `roles/cloudsql.admin`: Provides full control over all resources.
* All `editor` and `viewer` tools
* `create_instance`
@@ -299,6 +300,7 @@ instances and interacting with your database:
* **create_user**: Creates a new user in a Cloud SQL instance.
* **wait_for_operation**: Waits for a Cloud SQL operation to complete.
* **clone_instance**: Creates a clone of an existing Cloud SQL for MySQL instance.
* **create_backup**: Creates a backup on a Cloud SQL instance.
{{< notice note >}}
Prebuilt tools are pre-1.0, so expect some tool changes between versions. LLMs

View File

@@ -48,6 +48,7 @@ instance, database and users:
* `roles/cloudsql.editor`: Provides permissions to manage existing resources.
* All `viewer` tools
* `create_database`
* `create_backup`
* `roles/cloudsql.admin`: Provides full control over all resources.
* All `editor` and `viewer` tools
* `create_instance`
@@ -299,6 +300,7 @@ instances and interacting with your database:
* **create_user**: Creates a new user in a Cloud SQL instance.
* **wait_for_operation**: Waits for a Cloud SQL operation to complete.
* **clone_instance**: Creates a clone of an existing Cloud SQL for PostgreSQL instance.
* **create_backup**: Creates a backup on a Cloud SQL instance.
{{< notice note >}}
Prebuilt tools are pre-1.0, so expect some tool changes between versions. LLMs

View File

@@ -101,9 +101,6 @@ authentication token.
TOOLBOX_URL = "https://your-toolbox-service-xyz.a.run.app"
# Initialize the client with the Cloud Run URL and Auth headers
# WARNING: There is a known issue that may cause it to fail with a 401 error
# on local machines.
# See https://github.com/googleapis/mcp-toolbox-sdk-python/issues/496
client = ToolboxSyncClient(
TOOLBOX_URL,
client_headers={"Authorization": auth_methods.get_google_id_token(TOOLBOX_URL)}

View File

@@ -211,9 +211,6 @@ from toolbox_core import ToolboxClient, auth_methods
# Replace with the Cloud Run service URL generated in the previous step
URL = "https://cloud-run-url.app"
# WARNING: There is a known issue that may cause it to fail with a 401 error
# on local machines.
# See https://github.com/googleapis/mcp-toolbox-sdk-python/issues/496
auth_token_provider = auth_methods.aget_google_id_token(URL) # can also use sync method
async def main():

View File

@@ -187,6 +187,7 @@ See [Usage Examples](../reference/cli.md#examples).
manage existing resources.
* All `viewer` tools
* `create_database`
* `create_backup`
* **Cloud SQL Admin** (`roles/cloudsql.admin`): Provides full control over
all resources.
* All `editor` and `viewer` tools
@@ -203,6 +204,7 @@ See [Usage Examples](../reference/cli.md#examples).
* `create_user`: Creates a new user in a Cloud SQL instance.
* `wait_for_operation`: Waits for a Cloud SQL operation to complete.
* `clone_instance`: Creates a clone for an existing Cloud SQL for MySQL instance.
* `create_backup`: Creates a backup on a Cloud SQL instance.
## Cloud SQL for PostgreSQL
@@ -275,6 +277,7 @@ See [Usage Examples](../reference/cli.md#examples).
manage existing resources.
* All `viewer` tools
* `create_database`
* `create_backup`
* **Cloud SQL Admin** (`roles/cloudsql.admin`): Provides full control over
all resources.
* All `editor` and `viewer` tools
@@ -290,6 +293,7 @@ See [Usage Examples](../reference/cli.md#examples).
* `create_user`: Creates a new user in a Cloud SQL instance.
* `wait_for_operation`: Waits for a Cloud SQL operation to complete.
* `clone_instance`: Creates a clone for an existing Cloud SQL for PostgreSQL instance.
* `create_backup`: Creates a backup on a Cloud SQL instance.
## Cloud SQL for SQL Server
@@ -336,6 +340,7 @@ See [Usage Examples](../reference/cli.md#examples).
manage existing resources.
* All `viewer` tools
* `create_database`
* `create_backup`
* **Cloud SQL Admin** (`roles/cloudsql.admin`): Provides full control over
all resources.
* All `editor` and `viewer` tools
@@ -351,6 +356,7 @@ See [Usage Examples](../reference/cli.md#examples).
* `create_user`: Creates a new user in a Cloud SQL instance.
* `wait_for_operation`: Waits for a Cloud SQL operation to complete.
* `clone_instance`: Creates a clone for an existing Cloud SQL for SQL Server instance.
* `create_backup`: Creates a backup on a Cloud SQL instance.
## Dataplex

View File

@@ -91,8 +91,8 @@ visible to the LLM.
https://cloud.google.com/alloydb/docs/parameterized-secure-views-overview
{{< notice tip >}} Make sure to enable the `parameterized_views` extension
before running this tool. You can do so by running this command in the AlloyDB
studio:
to utilize PSV feature (`nlConfigParameters`) with this tool. You can do so by
running this command in the AlloyDB studio:
```sql
CREATE EXTENSION IF NOT EXISTS parameterized_views;

View File

@@ -41,13 +41,13 @@ tools:
### Usage Flow
When using this tool, a `prompt` parameter containing a natural language query is provided to the tool (typically by an agent). The tool then interacts with the Gemini Data Analytics API using the context defined in your configuration.
When using this tool, a `query` parameter containing a natural language query is provided to the tool (typically by an agent). The tool then interacts with the Gemini Data Analytics API using the context defined in your configuration.
The structure of the response depends on the `generationOptions` configured in your tool definition (e.g., enabling `generateQueryResult` will include the SQL query results).
See [Data Analytics API REST documentation](https://clouddocs.devsite.corp.google.com/gemini/docs/conversational-analytics-api/reference/rest/v1alpha/projects.locations/queryData?rep_location=global) for details.
**Example Input Prompt:**
**Example Input Query:**
```text
How many accounts who have region in Prague are eligible for loans? A3 contains the data of region.

View File

@@ -0,0 +1,45 @@
---
title: cloud-sql-create-backup
type: docs
weight: 10
description: "Creates a backup on a Cloud SQL instance."
---
The `cloud-sql-create-backup` tool creates an on-demand backup on a Cloud SQL instance using the Cloud SQL Admin API.
{{< notice info >}}
This tool uses a `source` of kind `cloud-sql-admin`.
{{< /notice >}}
## Examples
Basic backup creation (current state)
```yaml
tools:
backup-creation-basic:
kind: cloud-sql-create-backup
source: cloud-sql-admin-source
description: "Creates a backup on the given Cloud SQL instance."
```
## Reference
### Tool Configuration
| **field** | **type** | **required** | **description** |
| -------------- | :------: | :----------: | ------------------------------------------------------------- |
| kind | string | true | Must be "cloud-sql-create-backup". |
| source | string | true | The name of the `cloud-sql-admin` source to use. |
| description | string | false | A description of the tool. |
### Tool Inputs
| **parameter** | **type** | **required** | **description** |
| -------------------------- | :------: | :----------: | ------------------------------------------------------------------------------- |
| project | string | true | The project ID. |
| instance | string | true | The name of the instance to take a backup on. Does not include the project ID. |
| location | string | false | (Optional) Location of the backup run. |
| backup_description | string | false | (Optional) The description of this backup run. |
## See Also
- [Cloud SQL Admin API documentation](https://cloud.google.com/sql/docs/mysql/admin-api)
- [Toolbox Cloud SQL tools documentation](../cloudsql)
- [Cloud SQL Backup API documentation](https://cloud.google.com/sql/docs/mysql/backup-recovery/backups)

View File

@@ -0,0 +1,7 @@
---
title: "Neo4j"
type: docs
weight: 1
description: >
How to get started with Toolbox using Neo4j.
---

View File

@@ -0,0 +1,141 @@
---
title: "Quickstart (MCP with Neo4j)"
type: docs
weight: 1
description: >
How to get started running Toolbox with MCP Inspector and Neo4j as the source.
---
## Overview
[Model Context Protocol](https://modelcontextprotocol.io) is an open protocol that standardizes how applications provide context to LLMs. Check out this page on how to [connect to Toolbox via MCP](../../how-to/connect_via_mcp.md).
## Step 1: Set up your Neo4j Database and Data
In this section, you'll set up a database and populate it with sample data for a movies-related agent. This guide assumes you have a running Neo4j instance, either locally or in the cloud.
1. **Populate the database with data.** To make this quickstart straightforward, we'll use the built-in Movies dataset available in Neo4j.
1. In your Neo4j Browser, run the following command to create and populate the database:
```cypher
:play movies
```
1. Follow the instructions to load the data. This will create a graph with `Movie`, `Person`, and `Actor` nodes and their relationships.
## Step 2: Install and configure Toolbox
In this section, we will install the MCP Toolbox, configure our tools in a `tools.yaml` file, and then run the Toolbox server.
1. **Install the Toolbox binary.** The simplest way to get started is to download the latest binary for your operating system.
1. Download the latest version of Toolbox as a binary:
```bash
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
curl -O https://storage.googleapis.com/genai-toolbox/v0.16.0/$OS/toolbox
```
1. Make the binary executable:
```bash
chmod +x toolbox
```
1. **Create the `tools.yaml` file.** This file defines your Neo4j source and the specific tools that will be exposed to your AI agent.
{{< notice tip >}}
Authentication for the Neo4j source uses standard username and password fields. For production use, it is highly recommended to use environment variables for sensitive information like passwords.
{{< /notice >}}
Write the following into a `tools.yaml` file:
```yaml
sources:
my-neo4j-source:
kind: neo4j
uri: bolt://localhost:7687
user: neo4j
password: my-password # Replace with your actual password
tools:
search-movies-by-actor:
kind: neo4j-cypher
source: my-neo4j-source
description: "Searches for movies an actor has appeared in based on their name. Useful for questions like 'What movies has Tom Hanks been in?'"
parameters:
- name: actor_name
type: string
description: The full name of the actor to search for.
statement: |
MATCH (p:Person {name: $actor_name}) -[:ACTED_IN]-> (m:Movie)
RETURN m.title AS title, m.year AS year, m.genre AS genre
get-actor-for-movie:
kind: neo4j-cypher
source: my-neo4j-source
description: "Finds the actors who starred in a specific movie. Useful for questions like 'Who acted in Inception?'"
parameters:
- name: movie_title
type: string
description: The exact title of the movie.
statement: |
MATCH (p:Person) -[:ACTED_IN]-> (m:Movie {title: $movie_title})
RETURN p.name AS actor
```
1. **Start the Toolbox server.** Run the Toolbox server, pointing to the `tools.yaml` file you created earlier.
```bash
./toolbox --tools-file "tools.yaml"
```
## Step 3: Connect to MCP Inspector
1. **Run the MCP Inspector:**
```bash
npx @modelcontextprotocol/inspector
```
1. Type `y` when it asks to install the inspector package.
1. It should show the following when the MCP Inspector is up and running (please take note of `<YOUR_SESSION_TOKEN>`):
```bash
Starting MCP inspector...
⚙️ Proxy server listening on localhost:6277
🔑 Session token: <YOUR_SESSION_TOKEN>
Use this token to authenticate requests or set DANGEROUSLY_OMIT_AUTH=true to disable auth
🚀 MCP Inspector is up and running at:
http://localhost:6274/?MCP_PROXY_AUTH_TOKEN=<YOUR_SESSION_TOKEN>
```
1. Open the above link in your browser.
1. For `Transport Type`, select `Streamable HTTP`.
1. For `URL`, type in `http://127.0.0.1:5000/mcp`.
1. For `Configuration` -> `Proxy Session Token`, make sure `<YOUR_SESSION_TOKEN>` is present.
1. Click `Connect`.
1. Select `List Tools`; you will see a list of the tools configured in `tools.yaml`.
1. Test out your tools here!

2
go.mod
View File

@@ -12,7 +12,7 @@ require (
cloud.google.com/go/dataplex v1.28.0
cloud.google.com/go/dataproc/v2 v2.15.0
cloud.google.com/go/firestore v1.20.0
cloud.google.com/go/geminidataanalytics v0.3.0
cloud.google.com/go/geminidataanalytics v0.5.0
cloud.google.com/go/longrunning v0.7.0
cloud.google.com/go/spanner v1.86.1
github.com/ClickHouse/clickhouse-go/v2 v2.40.3

4
go.sum
View File

@@ -311,8 +311,8 @@ cloud.google.com/go/gaming v1.6.0/go.mod h1:YMU1GEvA39Qt3zWGyAVA9bpYz/yAhTvaQ1t2
cloud.google.com/go/gaming v1.7.0/go.mod h1:LrB8U7MHdGgFG851iHAfqUdLcKBdQ55hzXy9xBJz0+w=
cloud.google.com/go/gaming v1.8.0/go.mod h1:xAqjS8b7jAVW0KFYeRUxngo9My3f33kFmua++Pi+ggM=
cloud.google.com/go/gaming v1.9.0/go.mod h1:Fc7kEmCObylSWLO334NcO+O9QMDyz+TKC4v1D7X+Bc0=
cloud.google.com/go/geminidataanalytics v0.3.0 h1:2Wi/kqFb5OLuEGH7q+/miE19VTqK1MYHjBEHENap9HI=
cloud.google.com/go/geminidataanalytics v0.3.0/go.mod h1:QRc0b6ywyc3Z7S3etFgslz7hippkW/jRvtops5rKqIg=
cloud.google.com/go/geminidataanalytics v0.5.0 h1:+1usY81Cb+hE8BokpqCM7EgJtRCKzUKx7FvrHbT5hCA=
cloud.google.com/go/geminidataanalytics v0.5.0/go.mod h1:QRc0b6ywyc3Z7S3etFgslz7hippkW/jRvtops5rKqIg=
cloud.google.com/go/gkebackup v0.2.0/go.mod h1:XKvv/4LfG829/B8B7xRkk8zRrOEbKtEam6yNfuQNH60=
cloud.google.com/go/gkebackup v0.3.0/go.mod h1:n/E671i1aOQvUxT541aTkCwExO/bTer2HDlj4TsBRAo=
cloud.google.com/go/gkebackup v0.4.0/go.mod h1:byAyBGUwYGEEww7xsbnUTBHIYcOPy/PgUWUtOeRm9Vg=

View File

@@ -43,6 +43,9 @@ tools:
clone_instance:
kind: cloud-sql-clone-instance
source: cloud-sql-admin-source
create_backup:
kind: cloud-sql-create-backup
source: cloud-sql-admin-source
toolsets:
cloud_sql_mssql_admin_tools:
@@ -54,3 +57,4 @@ toolsets:
- create_user
- wait_for_operation
- clone_instance
- create_backup

View File

@@ -43,6 +43,9 @@ tools:
clone_instance:
kind: cloud-sql-clone-instance
source: cloud-sql-admin-source
create_backup:
kind: cloud-sql-create-backup
source: cloud-sql-admin-source
toolsets:
cloud_sql_mysql_admin_tools:
@@ -54,3 +57,4 @@ toolsets:
- create_user
- wait_for_operation
- clone_instance
- create_backup

View File

@@ -46,6 +46,9 @@ tools:
postgres_upgrade_precheck:
kind: postgres-upgrade-precheck
source: cloud-sql-admin-source
create_backup:
kind: cloud-sql-create-backup
source: cloud-sql-admin-source
toolsets:
cloud_sql_postgres_admin_tools:
@@ -58,3 +61,4 @@ toolsets:
- wait_for_operation
- postgres_upgrade_precheck
- clone_instance
- create_backup

View File

@@ -14,23 +14,20 @@
package cloudgda
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
geminidataanalytics "cloud.google.com/go/geminidataanalytics/apiv1beta"
"cloud.google.com/go/geminidataanalytics/apiv1beta/geminidataanalyticspb"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/util"
"go.opentelemetry.io/otel/trace"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
"google.golang.org/api/option"
)
const SourceKind string = "cloud-gemini-data-analytics"
const Endpoint string = "https://geminidataanalytics.googleapis.com"
// validate interface
var _ sources.SourceConfig = Config{}
@@ -67,29 +64,19 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
return nil, fmt.Errorf("error in User Agent retrieval: %s", err)
}
var client *http.Client
if r.UseClientOAuth {
client = &http.Client{
Transport: util.NewUserAgentRoundTripper(ua, http.DefaultTransport),
}
} else {
// Use Application Default Credentials
// Scope: "https://www.googleapis.com/auth/cloud-platform" is generally sufficient for GDA
creds, err := google.FindDefaultCredentials(ctx, "https://www.googleapis.com/auth/cloud-platform")
if err != nil {
return nil, fmt.Errorf("failed to find default credentials: %w", err)
}
baseClient := oauth2.NewClient(ctx, creds.TokenSource)
baseClient.Transport = util.NewUserAgentRoundTripper(ua, baseClient.Transport)
client = baseClient
}
s := &Source{
Config: r,
Client: client,
BaseURL: Endpoint,
userAgent: ua,
}
if !r.UseClientOAuth {
client, err := geminidataanalytics.NewDataChatClient(ctx, option.WithUserAgent(ua))
if err != nil {
return nil, fmt.Errorf("failed to create DataChatClient: %w", err)
}
s.Client = client
}
return s, nil
}
@@ -97,8 +84,7 @@ var _ sources.Source = &Source{}
type Source struct {
Config
Client *http.Client
BaseURL string
Client *geminidataanalytics.DataChatClient
userAgent string
}
@@ -114,63 +100,34 @@ func (s *Source) GetProjectID() string {
return s.ProjectID
}
func (s *Source) GetBaseURL() string {
return s.BaseURL
}
func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Client, error) {
if s.UseClientOAuth {
if accessToken == "" {
return nil, fmt.Errorf("client-side OAuth is enabled but no access token was provided")
}
token := &oauth2.Token{AccessToken: accessToken}
baseClient := oauth2.NewClient(ctx, oauth2.StaticTokenSource(token))
baseClient.Transport = util.NewUserAgentRoundTripper(s.userAgent, baseClient.Transport)
return baseClient, nil
}
return s.Client, nil
}
func (s *Source) UseClientAuthorization() bool {
return s.UseClientOAuth
}
func (s *Source) RunQuery(ctx context.Context, tokenStr string, bodyBytes []byte) (any, error) {
// The API endpoint itself always uses the "global" location.
apiLocation := "global"
apiParent := fmt.Sprintf("projects/%s/locations/%s", s.GetProjectID(), apiLocation)
apiURL := fmt.Sprintf("%s/v1beta/%s:queryData", s.GetBaseURL(), apiParent)
client, err := s.GetClient(ctx, tokenStr)
func (s *Source) RunQuery(ctx context.Context, tokenStr string, req *geminidataanalyticspb.QueryDataRequest) (*geminidataanalyticspb.QueryDataResponse, error) {
client, cleanup, err := s.GetClient(ctx, tokenStr)
if err != nil {
return nil, fmt.Errorf("failed to get HTTP client: %w", err)
return nil, err
}
defer cleanup()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(bodyBytes))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to execute request: %w", err)
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(respBody))
}
var result map[string]any
if err := json.Unmarshal(respBody, &result); err != nil {
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
}
return result, nil
return client.QueryData(ctx, req)
}
func (s *Source) GetClient(ctx context.Context, tokenStr string) (*geminidataanalytics.DataChatClient, func(), error) {
if s.UseClientOAuth {
if tokenStr == "" {
return nil, nil, fmt.Errorf("client-side OAuth is enabled but no access token was provided")
}
token := &oauth2.Token{AccessToken: tokenStr}
client, err := geminidataanalytics.NewDataChatClient(ctx,
option.WithUserAgent(s.userAgent),
option.WithTokenSource(oauth2.StaticTokenSource(token)),
)
if err != nil {
return nil, nil, fmt.Errorf("failed to create per-request DataChatClient: %w", err)
}
return client, func() { client.Close() }, nil
}
return s.Client, func() {}, nil
}

View File

@@ -181,11 +181,9 @@ func TestInitialize(t *testing.T) {
if gdaSrc.Client == nil && !tc.wantClientOAuth {
t.Fatal("expected non-nil HTTP client for ADC, got nil")
}
// When client OAuth is true, the source's client should be initialized with a base HTTP client
// that includes the user agent round tripper, but not the OAuth token. The token-aware
// client is created by GetClient.
if gdaSrc.Client == nil && tc.wantClientOAuth {
t.Fatal("expected non-nil HTTP client for client OAuth config, got nil")
// When client OAuth is true, the source's client should be nil.
if gdaSrc.Client != nil && tc.wantClientOAuth {
t.Fatal("expected nil HTTP client for client OAuth config, got non-nil")
}
// Test UseClientAuthorization method
@@ -195,15 +193,16 @@ func TestInitialize(t *testing.T) {
// Test GetClient with accessToken for client OAuth scenarios
if tc.wantClientOAuth {
client, err := gdaSrc.GetClient(ctx, "dummy-token")
client, cleanup, err := gdaSrc.GetClient(ctx, "dummy-token")
if err != nil {
t.Fatalf("GetClient with token failed: %v", err)
}
defer cleanup()
if client == nil {
t.Fatal("expected non-nil HTTP client from GetClient with token, got nil")
}
// Ensure passing empty token with UseClientOAuth enabled returns error
_, err = gdaSrc.GetClient(ctx, "")
_, _, err = gdaSrc.GetClient(ctx, "")
if err == nil || err.Error() != "client-side OAuth is enabled but no access token was provided" {
t.Errorf("expected 'client-side OAuth is enabled but no access token was provided' error, got: %v", err)
}

View File

@@ -16,8 +16,12 @@ package cloudhealthcare
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
@@ -255,3 +259,299 @@ func (s *Source) IsDICOMStoreAllowed(storeID string) bool {
func (s *Source) UseClientAuthorization() bool {
return s.UseClientOAuth
}
func parseResults(resp *http.Response) (any, error) {
respBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("could not read response: %w", err)
}
if resp.StatusCode > 299 {
return nil, fmt.Errorf("status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
}
var jsonMap map[string]interface{}
if err := json.Unmarshal(respBytes, &jsonMap); err != nil {
return nil, fmt.Errorf("could not unmarshal response as json: %w", err)
}
return jsonMap, nil
}
func (s *Source) getService(tokenStr string) (*healthcare.Service, error) {
svc := s.Service()
var err error
// Initialize new service if using user OAuth token
if s.UseClientAuthorization() {
svc, err = s.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
}
return svc, nil
}
func (s *Source) FHIRFetchPage(ctx context.Context, url, tokenStr string) (any, error) {
var httpClient *http.Client
if s.UseClientAuthorization() {
ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: tokenStr})
httpClient = oauth2.NewClient(ctx, ts)
} else {
// The source.Service() object holds a client with the default credentials.
// However, the client is not exported, so we have to create a new one.
var err error
httpClient, err = google.DefaultClient(ctx, healthcare.CloudHealthcareScope)
if err != nil {
return nil, fmt.Errorf("failed to create default http client: %w", err)
}
}
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, fmt.Errorf("failed to create http request: %w", err)
}
req.Header.Set("Accept", "application/fhir+json;charset=utf-8")
resp, err := httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to get fhir page from %q: %w", url, err)
}
defer resp.Body.Close()
return parseResults(resp)
}
func (s *Source) FHIRPatientEverything(storeID, patientID, tokenStr string, opts []googleapi.CallOption) (any, error) {
svc, err := s.getService(tokenStr)
if err != nil {
return nil, err
}
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/Patient/%s", s.Project(), s.Region(), s.DatasetID(), storeID, patientID)
resp, err := svc.Projects.Locations.Datasets.FhirStores.Fhir.PatientEverything(name).Do(opts...)
if err != nil {
return nil, fmt.Errorf("failed to call patient everything for %q: %w", name, err)
}
defer resp.Body.Close()
return parseResults(resp)
}
func (s *Source) FHIRPatientSearch(storeID, tokenStr string, opts []googleapi.CallOption) (any, error) {
svc, err := s.getService(tokenStr)
if err != nil {
return nil, err
}
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
resp, err := svc.Projects.Locations.Datasets.FhirStores.Fhir.SearchType(name, "Patient", &healthcare.SearchResourcesRequest{ResourceType: "Patient"}).Do(opts...)
if err != nil {
return nil, fmt.Errorf("failed to search patient resources: %w", err)
}
defer resp.Body.Close()
return parseResults(resp)
}
func (s *Source) GetDataset(tokenStr string) (*healthcare.Dataset, error) {
svc, err := s.getService(tokenStr)
if err != nil {
return nil, err
}
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", s.Project(), s.Region(), s.DatasetID())
dataset, err := svc.Projects.Locations.Datasets.Get(datasetName).Do()
if err != nil {
return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
}
return dataset, nil
}
func (s *Source) GetFHIRResource(storeID, resType, resID, tokenStr string) (any, error) {
svc, err := s.getService(tokenStr)
if err != nil {
return nil, err
}
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/%s/%s", s.Project(), s.Region(), s.DatasetID(), storeID, resType, resID)
call := svc.Projects.Locations.Datasets.FhirStores.Fhir.Read(name)
call.Header().Set("Content-Type", "application/fhir+json;charset=utf-8")
resp, err := call.Do()
if err != nil {
return nil, fmt.Errorf("failed to get fhir resource %q: %w", name, err)
}
defer resp.Body.Close()
return parseResults(resp)
}
func (s *Source) GetDICOMStore(storeID, tokenStr string) (*healthcare.DicomStore, error) {
svc, err := s.getService(tokenStr)
if err != nil {
return nil, err
}
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
store, err := svc.Projects.Locations.Datasets.DicomStores.Get(storeName).Do()
if err != nil {
return nil, fmt.Errorf("failed to get DICOM store %q: %w", storeName, err)
}
return store, nil
}
func (s *Source) GetFHIRStore(storeID, tokenStr string) (*healthcare.FhirStore, error) {
svc, err := s.getService(tokenStr)
if err != nil {
return nil, err
}
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
store, err := svc.Projects.Locations.Datasets.FhirStores.Get(storeName).Do()
if err != nil {
return nil, fmt.Errorf("failed to get FHIR store %q: %w", storeName, err)
}
return store, nil
}
func (s *Source) GetDICOMStoreMetrics(storeID, tokenStr string) (*healthcare.DicomStoreMetrics, error) {
svc, err := s.getService(tokenStr)
if err != nil {
return nil, err
}
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
store, err := svc.Projects.Locations.Datasets.DicomStores.GetDICOMStoreMetrics(storeName).Do()
if err != nil {
return nil, fmt.Errorf("failed to get metrics for DICOM store %q: %w", storeName, err)
}
return store, nil
}
func (s *Source) GetFHIRStoreMetrics(storeID, tokenStr string) (*healthcare.FhirStoreMetrics, error) {
svc, err := s.getService(tokenStr)
if err != nil {
return nil, err
}
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
store, err := svc.Projects.Locations.Datasets.FhirStores.GetFHIRStoreMetrics(storeName).Do()
if err != nil {
return nil, fmt.Errorf("failed to get metrics for FHIR store %q: %w", storeName, err)
}
return store, nil
}
func (s *Source) ListDICOMStores(tokenStr string) ([]*healthcare.DicomStore, error) {
svc, err := s.getService(tokenStr)
if err != nil {
return nil, err
}
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", s.Project(), s.Region(), s.DatasetID())
stores, err := svc.Projects.Locations.Datasets.DicomStores.List(datasetName).Do()
if err != nil {
return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
}
var filtered []*healthcare.DicomStore
for _, store := range stores.DicomStores {
if len(s.AllowedDICOMStores()) == 0 {
filtered = append(filtered, store)
continue
}
if len(store.Name) == 0 {
continue
}
parts := strings.Split(store.Name, "/")
if _, ok := s.AllowedDICOMStores()[parts[len(parts)-1]]; ok {
filtered = append(filtered, store)
}
}
return filtered, nil
}
func (s *Source) ListFHIRStores(tokenStr string) ([]*healthcare.FhirStore, error) {
svc, err := s.getService(tokenStr)
if err != nil {
return nil, err
}
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", s.Project(), s.Region(), s.DatasetID())
stores, err := svc.Projects.Locations.Datasets.FhirStores.List(datasetName).Do()
if err != nil {
return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
}
var filtered []*healthcare.FhirStore
for _, store := range stores.FhirStores {
if len(s.AllowedFHIRStores()) == 0 {
filtered = append(filtered, store)
continue
}
if len(store.Name) == 0 {
continue
}
parts := strings.Split(store.Name, "/")
if _, ok := s.AllowedFHIRStores()[parts[len(parts)-1]]; ok {
filtered = append(filtered, store)
}
}
return filtered, nil
}
func (s *Source) RetrieveRenderedDICOMInstance(storeID, study, series, sop string, frame int, tokenStr string) (any, error) {
svc, err := s.getService(tokenStr)
if err != nil {
return nil, err
}
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
dicomWebPath := fmt.Sprintf("studies/%s/series/%s/instances/%s/frames/%d/rendered", study, series, sop, frame)
call := svc.Projects.Locations.Datasets.DicomStores.Studies.Series.Instances.Frames.RetrieveRendered(name, dicomWebPath)
call.Header().Set("Accept", "image/jpeg")
resp, err := call.Do()
if err != nil {
return nil, fmt.Errorf("unable to retrieve dicom instance rendered image: %w", err)
}
defer resp.Body.Close()
respBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("could not read response: %w", err)
}
if resp.StatusCode > 299 {
return nil, fmt.Errorf("RetrieveRendered: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
}
base64String := base64.StdEncoding.EncodeToString(respBytes)
return base64String, nil
}
func (s *Source) SearchDICOM(toolKind, storeID, dicomWebPath, tokenStr string, opts []googleapi.CallOption) (any, error) {
svc, err := s.getService(tokenStr)
if err != nil {
return nil, err
}
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", s.Project(), s.Region(), s.DatasetID(), storeID)
var resp *http.Response
switch toolKind {
case "cloud-healthcare-search-dicom-instances":
resp, err = svc.Projects.Locations.Datasets.DicomStores.SearchForInstances(name, dicomWebPath).Do(opts...)
case "cloud-healthcare-search-dicom-series":
resp, err = svc.Projects.Locations.Datasets.DicomStores.SearchForSeries(name, dicomWebPath).Do(opts...)
case "cloud-healthcare-search-dicom-studies":
resp, err = svc.Projects.Locations.Datasets.DicomStores.SearchForStudies(name, dicomWebPath).Do(opts...)
default:
return nil, fmt.Errorf("incompatible tool kind: %s", toolKind)
}
if err != nil {
return nil, fmt.Errorf("failed to search dicom series: %w", err)
}
defer resp.Body.Close()
respBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("could not read response: %w", err)
}
if resp.StatusCode > 299 {
return nil, fmt.Errorf("search: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
}
if len(respBytes) == 0 {
return []interface{}{}, nil
}
var result []interface{}
if err := json.Unmarshal(respBytes, &result); err != nil {
return nil, fmt.Errorf("could not unmarshal response as list: %w", err)
}
return result, nil
}

View File

@@ -352,6 +352,28 @@ func (s *Source) GetWaitForOperations(ctx context.Context, service *sqladmin.Ser
return nil, nil
}
func (s *Source) InsertBackupRun(ctx context.Context, project, instance, location, backupDescription, accessToken string) (any, error) {
backupRun := &sqladmin.BackupRun{}
if location != "" {
backupRun.Location = location
}
if backupDescription != "" {
backupRun.Description = backupDescription
}
service, err := s.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
resp, err := service.BackupRuns.Insert(project, instance, backupRun).Do()
if err != nil {
return nil, fmt.Errorf("error creating backup: %w", err)
}
return resp, nil
}
func generateCloudSQLConnectionMessage(ctx context.Context, source *Source, logger log.Logger, opResponse map[string]any, connectionMessageTemplate string) (string, bool) {
operationType, ok := opResponse["operationType"].(string)
if !ok || operationType != "CREATE_DATABASE" {

View File

@@ -19,6 +19,8 @@ import (
"fmt"
dataplexapi "cloud.google.com/go/dataplex/apiv1"
"cloud.google.com/go/dataplex/apiv1/dataplexpb"
"github.com/cenkalti/backoff/v5"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/util"
@@ -121,3 +123,101 @@ func initDataplexConnection(
}
return client, nil
}
func (s *Source) LookupEntry(ctx context.Context, name string, view int, aspectTypes []string, entry string) (*dataplexpb.Entry, error) {
viewMap := map[int]dataplexpb.EntryView{
1: dataplexpb.EntryView_BASIC,
2: dataplexpb.EntryView_FULL,
3: dataplexpb.EntryView_CUSTOM,
4: dataplexpb.EntryView_ALL,
}
req := &dataplexpb.LookupEntryRequest{
Name: name,
View: viewMap[view],
AspectTypes: aspectTypes,
Entry: entry,
}
result, err := s.CatalogClient().LookupEntry(ctx, req)
if err != nil {
return nil, err
}
return result, nil
}
func (s *Source) searchRequest(ctx context.Context, query string, pageSize int, orderBy string) (*dataplexapi.SearchEntriesResultIterator, error) {
// Create SearchEntriesRequest with the provided parameters
req := &dataplexpb.SearchEntriesRequest{
Query: query,
Name: fmt.Sprintf("projects/%s/locations/global", s.ProjectID()),
PageSize: int32(pageSize),
OrderBy: orderBy,
SemanticSearch: true,
}
// Perform the search using the CatalogClient - this will return an iterator
it := s.CatalogClient().SearchEntries(ctx, req)
if it == nil {
return nil, fmt.Errorf("failed to create search entries iterator for project %q", s.ProjectID())
}
return it, nil
}
func (s *Source) SearchAspectTypes(ctx context.Context, query string, pageSize int, orderBy string) ([]*dataplexpb.AspectType, error) {
q := query + " type=projects/dataplex-types/locations/global/entryTypes/aspecttype"
it, err := s.searchRequest(ctx, q, pageSize, orderBy)
if err != nil {
return nil, err
}
// Iterate through the search results and call GetAspectType for each result using the resource name
var results []*dataplexpb.AspectType
for {
entry, err := it.Next()
if err != nil {
break
}
// Create an instance of exponential backoff with default values for retrying GetAspectType calls
// InitialInterval, RandomizationFactor, Multiplier, MaxInterval = 500 ms, 0.5, 1.5, 60 s
getAspectBackOff := backoff.NewExponentialBackOff()
resourceName := entry.DataplexEntry.GetEntrySource().Resource
getAspectTypeReq := &dataplexpb.GetAspectTypeRequest{
Name: resourceName,
}
operation := func() (*dataplexpb.AspectType, error) {
aspectType, err := s.CatalogClient().GetAspectType(ctx, getAspectTypeReq)
if err != nil {
return nil, fmt.Errorf("failed to get aspect type for entry %q: %w", resourceName, err)
}
return aspectType, nil
}
// Retry the GetAspectType operation with exponential backoff
aspectType, err := backoff.Retry(ctx, operation, backoff.WithBackOff(getAspectBackOff))
if err != nil {
return nil, fmt.Errorf("failed to get aspect type after retries for entry %q: %w", resourceName, err)
}
results = append(results, aspectType)
}
return results, nil
}
func (s *Source) SearchEntries(ctx context.Context, query string, pageSize int, orderBy string) ([]*dataplexpb.SearchEntriesResult, error) {
it, err := s.searchRequest(ctx, query, pageSize, orderBy)
if err != nil {
return nil, err
}
var results []*dataplexpb.SearchEntriesResult
for {
entry, err := it.Next()
if err != nil {
break
}
results = append(results, entry)
}
return results, nil
}

View File

@@ -16,7 +16,10 @@ package firestore
import (
"context"
"encoding/base64"
"fmt"
"strings"
"time"
"cloud.google.com/go/firestore"
"github.com/goccy/go-yaml"
@@ -25,6 +28,7 @@ import (
"go.opentelemetry.io/otel/trace"
"google.golang.org/api/firebaserules/v1"
"google.golang.org/api/option"
"google.golang.org/genproto/googleapis/type/latlng"
)
const SourceKind string = "firestore"
@@ -113,6 +117,476 @@ func (s *Source) GetDatabaseId() string {
return s.Database
}
// FirestoreValueToJSON converts a Firestore value to a simplified JSON representation
// This removes type information and returns plain values
func FirestoreValueToJSON(value any) any {
if value == nil {
return nil
}
switch v := value.(type) {
case time.Time:
return v.Format(time.RFC3339Nano)
case *latlng.LatLng:
return map[string]any{
"latitude": v.Latitude,
"longitude": v.Longitude,
}
case []byte:
return base64.StdEncoding.EncodeToString(v)
case []any:
result := make([]any, len(v))
for i, item := range v {
result[i] = FirestoreValueToJSON(item)
}
return result
case map[string]any:
result := make(map[string]any)
for k, val := range v {
result[k] = FirestoreValueToJSON(val)
}
return result
case *firestore.DocumentRef:
return v.Path
default:
return value
}
}
// BuildQuery constructs the Firestore query from parameters
func (s *Source) BuildQuery(collectionPath string, filter firestore.EntityFilter, selectFields []string, field string, direction firestore.Direction, limit int, analyzeQuery bool) (*firestore.Query, error) {
collection := s.FirestoreClient().Collection(collectionPath)
query := collection.Query
// Process and apply filters if template is provided
if filter != nil {
query = query.WhereEntity(filter)
}
if len(selectFields) > 0 {
query = query.Select(selectFields...)
}
if field != "" {
query = query.OrderBy(field, direction)
}
query = query.Limit(limit)
// Apply analyze options if enabled
if analyzeQuery {
query = query.WithRunOptions(firestore.ExplainOptions{
Analyze: true,
})
}
return &query, nil
}
// QueryResult represents a document result from the query
type QueryResult struct {
ID string `json:"id"`
Path string `json:"path"`
Data map[string]any `json:"data"`
CreateTime any `json:"createTime,omitempty"`
UpdateTime any `json:"updateTime,omitempty"`
ReadTime any `json:"readTime,omitempty"`
}
// QueryResponse represents the full response including optional metrics
type QueryResponse struct {
Documents []QueryResult `json:"documents"`
ExplainMetrics map[string]any `json:"explainMetrics,omitempty"`
}
// ExecuteQuery runs the query and formats the results
func (s *Source) ExecuteQuery(ctx context.Context, query *firestore.Query, analyzeQuery bool) (any, error) {
docIterator := query.Documents(ctx)
docs, err := docIterator.GetAll()
if err != nil {
return nil, fmt.Errorf("failed to execute query: %w", err)
}
// Convert results to structured format
results := make([]QueryResult, len(docs))
for i, doc := range docs {
results[i] = QueryResult{
ID: doc.Ref.ID,
Path: doc.Ref.Path,
Data: doc.Data(),
CreateTime: doc.CreateTime,
UpdateTime: doc.UpdateTime,
ReadTime: doc.ReadTime,
}
}
// Return with explain metrics if requested
if analyzeQuery {
explainMetrics, err := getExplainMetrics(docIterator)
if err == nil && explainMetrics != nil {
response := QueryResponse{
Documents: results,
ExplainMetrics: explainMetrics,
}
return response, nil
}
}
return results, nil
}
// getExplainMetrics extracts explain metrics from the query iterator
func getExplainMetrics(docIterator *firestore.DocumentIterator) (map[string]any, error) {
explainMetrics, err := docIterator.ExplainMetrics()
if err != nil || explainMetrics == nil {
return nil, err
}
metricsData := make(map[string]any)
// Add plan summary if available
if explainMetrics.PlanSummary != nil {
planSummary := make(map[string]any)
planSummary["indexesUsed"] = explainMetrics.PlanSummary.IndexesUsed
metricsData["planSummary"] = planSummary
}
// Add execution stats if available
if explainMetrics.ExecutionStats != nil {
executionStats := make(map[string]any)
executionStats["resultsReturned"] = explainMetrics.ExecutionStats.ResultsReturned
executionStats["readOperations"] = explainMetrics.ExecutionStats.ReadOperations
if explainMetrics.ExecutionStats.ExecutionDuration != nil {
executionStats["executionDuration"] = explainMetrics.ExecutionStats.ExecutionDuration.String()
}
if explainMetrics.ExecutionStats.DebugStats != nil {
executionStats["debugStats"] = *explainMetrics.ExecutionStats.DebugStats
}
metricsData["executionStats"] = executionStats
}
return metricsData, nil
}
func (s *Source) GetDocuments(ctx context.Context, documentPaths []string) ([]any, error) {
// Create document references from paths
docRefs := make([]*firestore.DocumentRef, len(documentPaths))
for i, path := range documentPaths {
docRefs[i] = s.FirestoreClient().Doc(path)
}
// Get all documents
snapshots, err := s.FirestoreClient().GetAll(ctx, docRefs)
if err != nil {
return nil, fmt.Errorf("failed to get documents: %w", err)
}
// Convert snapshots to response data
results := make([]any, len(snapshots))
for i, snapshot := range snapshots {
docData := make(map[string]any)
docData["path"] = documentPaths[i]
docData["exists"] = snapshot.Exists()
if snapshot.Exists() {
docData["data"] = snapshot.Data()
docData["createTime"] = snapshot.CreateTime
docData["updateTime"] = snapshot.UpdateTime
docData["readTime"] = snapshot.ReadTime
}
results[i] = docData
}
return results, nil
}
func (s *Source) AddDocuments(ctx context.Context, collectionPath string, documentData any, returnData bool) (map[string]any, error) {
// Get the collection reference
collection := s.FirestoreClient().Collection(collectionPath)
// Add the document to the collection
docRef, writeResult, err := collection.Add(ctx, documentData)
if err != nil {
return nil, fmt.Errorf("failed to add document: %w", err)
}
// Build the response
response := map[string]any{
"documentPath": docRef.Path,
"createTime": writeResult.UpdateTime.Format("2006-01-02T15:04:05.999999999Z"),
}
// Add document data if requested
if returnData {
// Fetch the updated document to return the current state
snapshot, err := docRef.Get(ctx)
if err != nil {
return nil, fmt.Errorf("failed to retrieve updated document: %w", err)
}
// Convert the document data back to simple JSON format
simplifiedData := FirestoreValueToJSON(snapshot.Data())
response["documentData"] = simplifiedData
}
return response, nil
}
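// Illustrative sketch: adding a single document and asking for the stored
// state back. The collection path and field values are hypothetical.
//
//    resp, err := s.AddDocuments(ctx, "users", map[string]any{"name": "Ada", "active": true}, true)
//    if err != nil {
//        return nil, err
//    }
//    // resp["documentPath"] and resp["createTime"] identify the new document;
//    // resp["documentData"] holds the simplified JSON form when requested.
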
func (s *Source) UpdateDocument(ctx context.Context, documentPath string, updates []firestore.Update, documentData any, returnData bool) (map[string]any, error) {
// Get the document reference
docRef := s.FirestoreClient().Doc(documentPath)
// Prepare update data
var writeResult *firestore.WriteResult
var writeErr error
if len(updates) > 0 {
writeResult, writeErr = docRef.Update(ctx, updates)
} else {
writeResult, writeErr = docRef.Set(ctx, documentData, firestore.MergeAll)
}
if writeErr != nil {
return nil, fmt.Errorf("failed to update document: %w", writeErr)
}
// Build the response
response := map[string]any{
"documentPath": docRef.Path,
"updateTime": writeResult.UpdateTime.Format("2006-01-02T15:04:05.999999999Z"),
}
// Add document data if requested
if returnData {
// Fetch the updated document to return the current state
snapshot, err := docRef.Get(ctx)
if err != nil {
return nil, fmt.Errorf("failed to retrieve updated document: %w", err)
}
// Convert the document data to simple JSON format
simplifiedData := FirestoreValueToJSON(snapshot.Data())
response["documentData"] = simplifiedData
}
return response, nil
}
func (s *Source) DeleteDocuments(ctx context.Context, documentPaths []string) ([]any, error) {
// Create a BulkWriter to handle multiple deletions efficiently
bulkWriter := s.FirestoreClient().BulkWriter(ctx)
// Keep track of jobs for each document
jobs := make([]*firestore.BulkWriterJob, len(documentPaths))
// Add all delete operations to the BulkWriter
for i, path := range documentPaths {
docRef := s.FirestoreClient().Doc(path)
job, err := bulkWriter.Delete(docRef)
if err != nil {
return nil, fmt.Errorf("failed to add delete operation for document %q: %w", path, err)
}
jobs[i] = job
}
// End the BulkWriter to execute all operations
bulkWriter.End()
// Collect results
results := make([]any, len(documentPaths))
for i, job := range jobs {
docData := make(map[string]any)
docData["path"] = documentPaths[i]
// Wait for the job to complete and get the result
_, err := job.Results()
if err != nil {
docData["success"] = false
docData["error"] = err.Error()
} else {
docData["success"] = true
}
results[i] = docData
}
return results, nil
}
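// Illustrative sketch: deleting two hypothetical documents and inspecting the
// per-document outcome reported by the BulkWriter.
//
//    results, err := s.DeleteDocuments(ctx, []string{"users/alice", "users/bob"})
//    if err != nil {
//        return nil, err
//    }
//    for _, r := range results {
//        entry := r.(map[string]any)
//        _ = entry["success"] // false entries also carry an "error" message
//    }
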
func (s *Source) ListCollections(ctx context.Context, parentPath string) ([]any, error) {
var collectionRefs []*firestore.CollectionRef
var err error
if parentPath != "" {
// List subcollections of the specified document
docRef := s.FirestoreClient().Doc(parentPath)
collectionRefs, err = docRef.Collections(ctx).GetAll()
if err != nil {
return nil, fmt.Errorf("failed to list subcollections of document %q: %w", parentPath, err)
}
} else {
// List root collections
collectionRefs, err = s.FirestoreClient().Collections(ctx).GetAll()
if err != nil {
return nil, fmt.Errorf("failed to list root collections: %w", err)
}
}
// Convert collection references to response data
results := make([]any, len(collectionRefs))
for i, collRef := range collectionRefs {
collData := make(map[string]any)
collData["id"] = collRef.ID
collData["path"] = collRef.Path
// If this is a subcollection, include parent information
if collRef.Parent != nil {
collData["parent"] = collRef.Parent.Path
}
results[i] = collData
}
return results, nil
}
func (s *Source) GetRules(ctx context.Context) (any, error) {
// Get the latest release for Firestore
releaseName := fmt.Sprintf("projects/%s/releases/cloud.firestore/%s", s.GetProjectId(), s.GetDatabaseId())
release, err := s.FirebaseRulesClient().Projects.Releases.Get(releaseName).Context(ctx).Do()
if err != nil {
return nil, fmt.Errorf("failed to get latest Firestore release: %w", err)
}
if release.RulesetName == "" {
return nil, fmt.Errorf("no active Firestore rules were found in project '%s' and database '%s'", s.GetProjectId(), s.GetDatabaseId())
}
// Get the ruleset content
ruleset, err := s.FirebaseRulesClient().Projects.Rulesets.Get(release.RulesetName).Context(ctx).Do()
if err != nil {
return nil, fmt.Errorf("failed to get ruleset content: %w", err)
}
if ruleset.Source == nil || len(ruleset.Source.Files) == 0 {
return nil, fmt.Errorf("no rules files found in ruleset")
}
return ruleset, nil
}
// SourcePosition represents the location of an issue in the source
type SourcePosition struct {
FileName string `json:"fileName,omitempty"`
Line int64 `json:"line"` // 1-based
Column int64 `json:"column"` // 1-based
CurrentOffset int64 `json:"currentOffset"` // 0-based, inclusive start
EndOffset int64 `json:"endOffset"` // 0-based, exclusive end
}
// Issue represents a validation issue in the rules
type Issue struct {
SourcePosition SourcePosition `json:"sourcePosition"`
Description string `json:"description"`
Severity string `json:"severity"`
}
// ValidationResult represents the result of rules validation
type ValidationResult struct {
Valid bool `json:"valid"`
IssueCount int `json:"issueCount"`
FormattedIssues string `json:"formattedIssues,omitempty"`
RawIssues []Issue `json:"rawIssues,omitempty"`
}
func (s *Source) ValidateRules(ctx context.Context, sourceParam string) (any, error) {
// Create test request
testRequest := &firebaserules.TestRulesetRequest{
Source: &firebaserules.Source{
Files: []*firebaserules.File{
{
Name: "firestore.rules",
Content: sourceParam,
},
},
},
// No test cases are needed; this request is for validation only
TestSuite: &firebaserules.TestSuite{
TestCases: []*firebaserules.TestCase{},
},
}
// Call the test API
projectName := fmt.Sprintf("projects/%s", s.GetProjectId())
response, err := s.FirebaseRulesClient().Projects.Test(projectName, testRequest).Context(ctx).Do()
if err != nil {
return nil, fmt.Errorf("failed to validate rules: %w", err)
}
// Process the response
if len(response.Issues) == 0 {
return ValidationResult{
Valid: true,
IssueCount: 0,
FormattedIssues: "✓ No errors detected. Rules are valid.",
}, nil
}
// Convert issues to our format
issues := make([]Issue, len(response.Issues))
for i, issue := range response.Issues {
issues[i] = Issue{
Description: issue.Description,
Severity: issue.Severity,
SourcePosition: SourcePosition{
FileName: issue.SourcePosition.FileName,
Line: issue.SourcePosition.Line,
Column: issue.SourcePosition.Column,
CurrentOffset: issue.SourcePosition.CurrentOffset,
EndOffset: issue.SourcePosition.EndOffset,
},
}
}
// Format issues
sourceLines := strings.Split(sourceParam, "\n")
var formattedOutput []string
formattedOutput = append(formattedOutput, fmt.Sprintf("Found %d issue(s) in rules source:\n", len(issues)))
for _, issue := range issues {
issueString := fmt.Sprintf("%s: %s [Ln %d, Col %d]",
issue.Severity,
issue.Description,
issue.SourcePosition.Line,
issue.SourcePosition.Column)
if issue.SourcePosition.Line > 0 {
lineIndex := int(issue.SourcePosition.Line - 1) // 0-based index
if lineIndex >= 0 && lineIndex < len(sourceLines) {
errorLine := sourceLines[lineIndex]
issueString += fmt.Sprintf("\n```\n%s", errorLine)
// Add carets if we have column and offset information
if issue.SourcePosition.Column > 0 &&
issue.SourcePosition.CurrentOffset >= 0 &&
issue.SourcePosition.EndOffset > issue.SourcePosition.CurrentOffset {
startColumn := int(issue.SourcePosition.Column - 1) // 0-based
errorTokenLength := int(issue.SourcePosition.EndOffset - issue.SourcePosition.CurrentOffset)
if startColumn >= 0 && errorTokenLength > 0 && startColumn <= len(errorLine) {
padding := strings.Repeat(" ", startColumn)
carets := strings.Repeat("^", errorTokenLength)
issueString += fmt.Sprintf("\n%s%s", padding, carets)
}
}
issueString += "\n```"
}
}
formattedOutput = append(formattedOutput, issueString)
}
formattedIssues := strings.Join(formattedOutput, "\n\n")
return ValidationResult{
Valid: false,
IssueCount: len(issues),
FormattedIssues: formattedIssues,
RawIssues: issues,
}, nil
}
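// Illustrative sketch: validating a small, made-up rules source. A real
// caller would pass user-supplied rules text.
//
//    rules := `rules_version = '2';
//    service cloud.firestore {
//      match /databases/{database}/documents {
//        match /{document=**} {
//          allow read: if true;
//        }
//      }
//    }`
//    res, err := s.ValidateRules(ctx, rules)
//    if err != nil {
//        return nil, err
//    }
//    // res is a ValidationResult; when issues are found, FormattedIssues carries
//    // a human-readable summary with line/column markers and caret highlights.
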
func initFirestoreConnection(
ctx context.Context,
tracer trace.Tracer,

View File

@@ -16,6 +16,7 @@ package firestore_test
import (
"testing"
"time"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
@@ -128,3 +129,37 @@ func TestFailParseFromYamlFirestore(t *testing.T) {
})
}
}
func TestFirestoreValueToJSON_RoundTrip(t *testing.T) {
// Test round-trip conversion
original := map[string]any{
"name": "Test",
"count": int64(42),
"price": 19.99,
"active": true,
"tags": []any{"tag1", "tag2"},
"metadata": map[string]any{
"created": time.Now(),
},
"nullField": nil,
}
// Convert to JSON representation
jsonRepresentation := firestore.FirestoreValueToJSON(original)
// Verify types are simplified
jsonMap, ok := jsonRepresentation.(map[string]any)
if !ok {
t.Fatalf("Expected map, got %T", jsonRepresentation)
}
// Time should be converted to string
metadata, ok := jsonMap["metadata"].(map[string]any)
if !ok {
t.Fatalf("metadata should be a map, got %T", jsonMap["metadata"])
}
_, ok = metadata["created"].(string)
if !ok {
t.Errorf("created should be a string, got %T", metadata["created"])
}
}

View File

@@ -16,7 +16,9 @@ package http
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"time"
@@ -143,3 +145,28 @@ func (s *Source) HttpQueryParams() map[string]string {
func (s *Source) Client() *http.Client {
return s.client
}
func (s *Source) RunRequest(req *http.Request) (any, error) {
// Make request and fetch response
resp, err := s.Client().Do(req)
if err != nil {
return nil, fmt.Errorf("error making HTTP request: %s", err)
}
defer resp.Body.Close()
var body []byte
body, err = io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
if resp.StatusCode < 200 || resp.StatusCode > 299 {
return nil, fmt.Errorf("unexpected status code: %d, response body: %s", resp.StatusCode, string(body))
}
var data any
if err = json.Unmarshal(body, &data); err != nil {
// if unable to unmarshal data, return result as string.
return string(body), nil
}
return data, nil
}
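// Illustrative sketch: issuing a GET through RunRequest. The URL is
// hypothetical; headers and query parameters configured on the source would
// normally be applied by the calling tool before this point.
//
//    req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://example.com/api/items", nil)
//    if err != nil {
//        return nil, err
//    }
//    req.Header.Set("Accept", "application/json")
//    result, err := s.RunRequest(req)
//    // result is decoded JSON when possible, otherwise the raw body as a string.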

View File

@@ -15,7 +15,9 @@ package looker
import (
"context"
"crypto/tls"
"fmt"
"net/http"
"strings"
"time"
@@ -208,6 +210,49 @@ func (s *Source) LookerSessionLength() int64 {
return s.SessionLength
}
// transportWithAuthHeader is an http.RoundTripper that injects the Looker app ID and Authorization headers into each request.
type transportWithAuthHeader struct {
Base http.RoundTripper
AuthToken string
}
func (t *transportWithAuthHeader) RoundTrip(req *http.Request) (*http.Response, error) {
req.Header.Set("x-looker-appid", "go-sdk")
req.Header.Set("Authorization", t.AuthToken)
return t.Base.RoundTrip(req)
}
func (s *Source) GetLookerSDK(accessToken string) (*v4.LookerSDK, error) {
if s.UseClientAuthorization() {
if accessToken == "" {
return nil, fmt.Errorf("no access token supplied with request")
}
session := rtl.NewAuthSession(*s.LookerApiSettings())
// Configure base transport with TLS
transport := &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: !s.LookerApiSettings().VerifySsl,
},
}
// Build transport for end user token
session.Client = http.Client{
Transport: &transportWithAuthHeader{
Base: transport,
AuthToken: accessToken,
},
}
// return SDK with new Transport
return v4.NewLookerSDK(session), nil
}
if s.LookerClient() == nil {
return nil, fmt.Errorf("client id or client secret not valid")
}
return s.LookerClient(), nil
}
func initGoogleCloudConnection(ctx context.Context) (oauth2.TokenSource, error) {
cred, err := google.FindDefaultCredentials(ctx, geminidataanalytics.DefaultAuthScopes()...)
if err != nil {

View File

@@ -16,11 +16,14 @@ package mongodb
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/util"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"go.opentelemetry.io/otel/trace"
@@ -93,6 +96,201 @@ func (s *Source) MongoClient() *mongo.Client {
return s.Client
}
func parseData(ctx context.Context, cur *mongo.Cursor) ([]any, error) {
var data = []any{}
err := cur.All(ctx, &data)
if err != nil {
return nil, err
}
var final []any
for _, item := range data {
tmp, _ := bson.MarshalExtJSON(item, false, false)
var tmp2 any
err = json.Unmarshal(tmp, &tmp2)
if err != nil {
return nil, err
}
final = append(final, tmp2)
}
return final, err
}
func (s *Source) Aggregate(ctx context.Context, pipelineString string, canonical, readOnly bool, database, collection string) ([]any, error) {
var pipeline = []bson.M{}
err := bson.UnmarshalExtJSON([]byte(pipelineString), canonical, &pipeline)
if err != nil {
return nil, err
}
if readOnly {
// Fail if the pipeline contains a $merge or $out stage
for _, stage := range pipeline {
for key := range stage {
if key == "$merge" || key == "$out" {
return nil, fmt.Errorf("this is not a read-only pipeline: %+v", stage)
}
}
}
}
cur, err := s.MongoClient().Database(database).Collection(collection).Aggregate(ctx, pipeline)
if err != nil {
return nil, err
}
defer cur.Close(ctx)
res, err := parseData(ctx, cur)
if err != nil {
return nil, err
}
if res == nil {
return []any{}, nil
}
return res, err
}
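// Illustrative sketch: running a read-only aggregation. The database,
// collection, and pipeline below are hypothetical; with readOnly set, a
// pipeline containing $merge or $out is rejected.
//
//    pipeline := `[{"$match": {"status": "active"}}, {"$count": "total"}]`
//    docs, err := s.Aggregate(ctx, pipeline, false, true, "appdb", "users")
//    if err != nil {
//        return nil, err
//    }
//    _ = docs // each element is a document decoded from extended JSON
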
func (s *Source) Find(ctx context.Context, filterString, database, collection string, opts *options.FindOptions) ([]any, error) {
var filter = bson.D{}
err := bson.UnmarshalExtJSON([]byte(filterString), false, &filter)
if err != nil {
return nil, err
}
cur, err := s.MongoClient().Database(database).Collection(collection).Find(ctx, filter, opts)
if err != nil {
return nil, err
}
defer cur.Close(ctx)
return parseData(ctx, cur)
}
func (s *Source) FindOne(ctx context.Context, filterString, database, collection string, opts *options.FindOneOptions) ([]any, error) {
var filter = bson.D{}
err := bson.UnmarshalExtJSON([]byte(filterString), false, &filter)
if err != nil {
return nil, err
}
res := s.MongoClient().Database(database).Collection(collection).FindOne(ctx, filter, opts)
if res.Err() != nil {
return nil, res.Err()
}
var data any
err = res.Decode(&data)
if err != nil {
return nil, err
}
var final []any
tmp, _ := bson.MarshalExtJSON(data, false, false)
var tmp2 any
err = json.Unmarshal(tmp, &tmp2)
if err != nil {
return nil, err
}
final = append(final, tmp2)
return final, err
}
func (s *Source) InsertMany(ctx context.Context, jsonData string, canonical bool, database, collection string) ([]any, error) {
var data = []any{}
err := bson.UnmarshalExtJSON([]byte(jsonData), canonical, &data)
if err != nil {
return nil, err
}
res, err := s.MongoClient().Database(database).Collection(collection).InsertMany(ctx, data, options.InsertMany())
if err != nil {
return nil, err
}
return res.InsertedIDs, nil
}
func (s *Source) InsertOne(ctx context.Context, jsonData string, canonical bool, database, collection string) (any, error) {
var data any
err := bson.UnmarshalExtJSON([]byte(jsonData), canonical, &data)
if err != nil {
return nil, err
}
res, err := s.MongoClient().Database(database).Collection(collection).InsertOne(ctx, data, options.InsertOne())
if err != nil {
return nil, err
}
return res.InsertedID, nil
}
func (s *Source) UpdateMany(ctx context.Context, filterString string, canonical bool, updateString, database, collection string, upsert bool) ([]any, error) {
var filter = bson.D{}
err := bson.UnmarshalExtJSON([]byte(filterString), canonical, &filter)
if err != nil {
return nil, fmt.Errorf("unable to unmarshal filter string: %w", err)
}
var update = bson.D{}
err = bson.UnmarshalExtJSON([]byte(updateString), false, &update)
if err != nil {
return nil, fmt.Errorf("unable to unmarshal update string: %w", err)
}
res, err := s.MongoClient().Database(database).Collection(collection).UpdateMany(ctx, filter, update, options.Update().SetUpsert(upsert))
if err != nil {
return nil, fmt.Errorf("error updating collection: %w", err)
}
return []any{res.ModifiedCount, res.UpsertedCount, res.MatchedCount}, nil
}
func (s *Source) UpdateOne(ctx context.Context, filterString string, canonical bool, updateString, database, collection string, upsert bool) (any, error) {
var filter = bson.D{}
err := bson.UnmarshalExtJSON([]byte(filterString), false, &filter)
if err != nil {
return nil, fmt.Errorf("unable to unmarshal filter string: %w", err)
}
var update = bson.D{}
err = bson.UnmarshalExtJSON([]byte(updateString), canonical, &update)
if err != nil {
return nil, fmt.Errorf("unable to unmarshal update string: %w", err)
}
res, err := s.MongoClient().Database(database).Collection(collection).UpdateOne(ctx, filter, update, options.Update().SetUpsert(upsert))
if err != nil {
return nil, fmt.Errorf("error updating collection: %w", err)
}
return res.ModifiedCount, nil
}
func (s *Source) DeleteMany(ctx context.Context, filterString, database, collection string) (any, error) {
var filter = bson.D{}
err := bson.UnmarshalExtJSON([]byte(filterString), false, &filter)
if err != nil {
return nil, err
}
res, err := s.MongoClient().Database(database).Collection(collection).DeleteMany(ctx, filter, options.Delete())
if err != nil {
return nil, err
}
if res.DeletedCount == 0 {
return nil, errors.New("no document found")
}
return res.DeletedCount, nil
}
func (s *Source) DeleteOne(ctx context.Context, filterString, database, collection string) (any, error) {
var filter = bson.D{}
err := bson.UnmarshalExtJSON([]byte(filterString), false, &filter)
if err != nil {
return nil, err
}
res, err := s.MongoClient().Database(database).Collection(collection).DeleteOne(ctx, filter, options.Delete())
if err != nil {
return nil, err
}
return res.DeletedCount, nil
}
func initMongoDBClient(ctx context.Context, tracer trace.Tracer, name, uri string) (*mongo.Client, error) {
// Start a tracing span
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)

View File

@@ -16,15 +16,21 @@ package serverlessspark
import (
"context"
"encoding/json"
"fmt"
"time"
dataproc "cloud.google.com/go/dataproc/v2/apiv1"
"cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
longrunning "cloud.google.com/go/longrunning/autogen"
"cloud.google.com/go/longrunning/autogen/longrunningpb"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/util"
"go.opentelemetry.io/otel/trace"
"google.golang.org/api/iterator"
"google.golang.org/api/option"
"google.golang.org/protobuf/encoding/protojson"
)
const SourceKind string = "serverless-spark"
@@ -121,3 +127,168 @@ func (s *Source) Close() error {
}
return nil
}
func (s *Source) CancelOperation(ctx context.Context, operation string) (any, error) {
req := &longrunningpb.CancelOperationRequest{
Name: fmt.Sprintf("projects/%s/locations/%s/operations/%s", s.GetProject(), s.GetLocation(), operation),
}
client, err := s.GetOperationsClient(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get operations client: %w", err)
}
err = client.CancelOperation(ctx, req)
if err != nil {
return nil, fmt.Errorf("failed to cancel operation: %w", err)
}
return fmt.Sprintf("Cancelled [%s].", operation), nil
}
func (s *Source) CreateBatch(ctx context.Context, batch *dataprocpb.Batch) (map[string]any, error) {
req := &dataprocpb.CreateBatchRequest{
Parent: fmt.Sprintf("projects/%s/locations/%s", s.GetProject(), s.GetLocation()),
Batch: batch,
}
client := s.GetBatchControllerClient()
op, err := client.CreateBatch(ctx, req)
if err != nil {
return nil, fmt.Errorf("failed to create batch: %w", err)
}
meta, err := op.Metadata()
if err != nil {
return nil, fmt.Errorf("failed to get create batch op metadata: %w", err)
}
projectID, location, batchID, err := ExtractBatchDetails(meta.GetBatch())
if err != nil {
return nil, fmt.Errorf("error extracting batch details from name %q: %v", meta.GetBatch(), err)
}
consoleUrl := BatchConsoleURL(projectID, location, batchID)
logsUrl := BatchLogsURL(projectID, location, batchID, meta.GetCreateTime().AsTime(), time.Time{})
wrappedResult := map[string]any{
"opMetadata": meta,
"consoleUrl": consoleUrl,
"logsUrl": logsUrl,
}
return wrappedResult, nil
}
// ListBatchesResponse is the response from the list batches API.
type ListBatchesResponse struct {
Batches []Batch `json:"batches"`
NextPageToken string `json:"nextPageToken"`
}
// Batch represents a single batch job.
type Batch struct {
Name string `json:"name"`
UUID string `json:"uuid"`
State string `json:"state"`
Creator string `json:"creator"`
CreateTime string `json:"createTime"`
Operation string `json:"operation"`
ConsoleURL string `json:"consoleUrl"`
LogsURL string `json:"logsUrl"`
}
func (s *Source) ListBatches(ctx context.Context, ps *int, pt, filter string) (any, error) {
client := s.GetBatchControllerClient()
parent := fmt.Sprintf("projects/%s/locations/%s", s.GetProject(), s.GetLocation())
req := &dataprocpb.ListBatchesRequest{
Parent: parent,
OrderBy: "create_time desc",
}
if ps != nil {
req.PageSize = int32(*ps)
}
if pt != "" {
req.PageToken = pt
}
if filter != "" {
req.Filter = filter
}
it := client.ListBatches(ctx, req)
pager := iterator.NewPager(it, int(req.PageSize), req.PageToken)
var batchPbs []*dataprocpb.Batch
nextPageToken, err := pager.NextPage(&batchPbs)
if err != nil {
return nil, fmt.Errorf("failed to list batches: %w", err)
}
batches, err := ToBatches(batchPbs)
if err != nil {
return nil, err
}
return ListBatchesResponse{Batches: batches, NextPageToken: nextPageToken}, nil
}
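// Illustrative sketch: paging through batches. The page size is hypothetical;
// an empty page token starts from the newest batch.
//
//    pageSize := 20
//    resp, err := s.ListBatches(ctx, &pageSize, "", "")
//    if err != nil {
//        return nil, err
//    }
//    page := resp.(ListBatchesResponse)
//    _ = page.NextPageToken // pass back in to fetch the next page
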
// ToBatches converts a slice of protobuf Batch messages to a slice of Batch structs.
func ToBatches(batchPbs []*dataprocpb.Batch) ([]Batch, error) {
batches := make([]Batch, 0, len(batchPbs))
for _, batchPb := range batchPbs {
consoleUrl, err := BatchConsoleURLFromProto(batchPb)
if err != nil {
return nil, fmt.Errorf("error generating console url: %v", err)
}
logsUrl, err := BatchLogsURLFromProto(batchPb)
if err != nil {
return nil, fmt.Errorf("error generating logs url: %v", err)
}
batch := Batch{
Name: batchPb.Name,
UUID: batchPb.Uuid,
State: batchPb.State.Enum().String(),
Creator: batchPb.Creator,
CreateTime: batchPb.CreateTime.AsTime().Format(time.RFC3339),
Operation: batchPb.Operation,
ConsoleURL: consoleUrl,
LogsURL: logsUrl,
}
batches = append(batches, batch)
}
return batches, nil
}
func (s *Source) GetBatch(ctx context.Context, name string) (map[string]any, error) {
client := s.GetBatchControllerClient()
req := &dataprocpb.GetBatchRequest{
Name: fmt.Sprintf("projects/%s/locations/%s/batches/%s", s.GetProject(), s.GetLocation(), name),
}
batchPb, err := client.GetBatch(ctx, req)
if err != nil {
return nil, fmt.Errorf("failed to get batch: %w", err)
}
jsonBytes, err := protojson.Marshal(batchPb)
if err != nil {
return nil, fmt.Errorf("failed to marshal batch to JSON: %w", err)
}
var result map[string]any
if err := json.Unmarshal(jsonBytes, &result); err != nil {
return nil, fmt.Errorf("failed to unmarshal batch JSON: %w", err)
}
consoleUrl, err := BatchConsoleURLFromProto(batchPb)
if err != nil {
return nil, fmt.Errorf("error generating console url: %v", err)
}
logsUrl, err := BatchLogsURLFromProto(batchPb)
if err != nil {
return nil, fmt.Errorf("error generating logs url: %v", err)
}
wrappedResult := map[string]any{
"consoleUrl": consoleUrl,
"logsUrl": logsUrl,
"batch": result,
}
return wrappedResult, nil
}

View File

@@ -1,10 +1,10 @@
// Copyright 2025 Google LLC
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package common
package serverlessspark
import (
"fmt"
@@ -23,13 +23,13 @@ import (
"cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
)
var batchFullNameRegex = regexp.MustCompile(`projects/(?P<project>[^/]+)/locations/(?P<location>[^/]+)/batches/(?P<batch_id>[^/]+)`)
const (
logTimeBufferBefore = 1 * time.Minute
logTimeBufferAfter = 10 * time.Minute
)
var batchFullNameRegex = regexp.MustCompile(`projects/(?P<project>[^/]+)/locations/(?P<location>[^/]+)/batches/(?P<batch_id>[^/]+)`)
// ExtractBatchDetails extracts the project ID, location, and batch ID from a fully qualified batch name.
func ExtractBatchDetails(batchName string) (projectID, location, batchID string, err error) {
matches := batchFullNameRegex.FindStringSubmatch(batchName)
@@ -39,26 +39,6 @@ func ExtractBatchDetails(batchName string) (projectID, location, batchID string,
return matches[1], matches[2], matches[3], nil
}
// BatchConsoleURLFromProto builds a URL to the Google Cloud Console linking to the batch summary page.
func BatchConsoleURLFromProto(batchPb *dataprocpb.Batch) (string, error) {
projectID, location, batchID, err := ExtractBatchDetails(batchPb.GetName())
if err != nil {
return "", err
}
return BatchConsoleURL(projectID, location, batchID), nil
}
// BatchLogsURLFromProto builds a URL to the Google Cloud Console showing Cloud Logging for the given batch and time range.
func BatchLogsURLFromProto(batchPb *dataprocpb.Batch) (string, error) {
projectID, location, batchID, err := ExtractBatchDetails(batchPb.GetName())
if err != nil {
return "", err
}
createTime := batchPb.GetCreateTime().AsTime()
stateTime := batchPb.GetStateTime().AsTime()
return BatchLogsURL(projectID, location, batchID, createTime, stateTime), nil
}
// BatchConsoleURL builds a URL to the Google Cloud Console linking to the batch summary page.
func BatchConsoleURL(projectID, location, batchID string) string {
return fmt.Sprintf("https://console.cloud.google.com/dataproc/batches/%s/%s/summary?project=%s", location, batchID, projectID)
@@ -89,3 +69,23 @@ resource.labels.batch_id="%s"`
return "https://console.cloud.google.com/logs/viewer?" + v.Encode()
}
// BatchConsoleURLFromProto builds a URL to the Google Cloud Console linking to the batch summary page.
func BatchConsoleURLFromProto(batchPb *dataprocpb.Batch) (string, error) {
projectID, location, batchID, err := ExtractBatchDetails(batchPb.GetName())
if err != nil {
return "", err
}
return BatchConsoleURL(projectID, location, batchID), nil
}
// BatchLogsURLFromProto builds a URL to the Google Cloud Console showing Cloud Logging for the given batch and time range.
func BatchLogsURLFromProto(batchPb *dataprocpb.Batch) (string, error) {
projectID, location, batchID, err := ExtractBatchDetails(batchPb.GetName())
if err != nil {
return "", err
}
createTime := batchPb.GetCreateTime().AsTime()
stateTime := batchPb.GetStateTime().AsTime()
return BatchLogsURL(projectID, location, batchID, createTime, stateTime), nil
}

View File

@@ -1,10 +1,10 @@
// Copyright 2025 Google LLC
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
@@ -12,19 +12,20 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package common
package serverlessspark_test
import (
"testing"
"time"
"cloud.google.com/go/dataproc/v2/apiv1/dataprocpb"
"github.com/googleapis/genai-toolbox/internal/sources/serverlessspark"
"google.golang.org/protobuf/types/known/timestamppb"
)
func TestExtractBatchDetails_Success(t *testing.T) {
batchName := "projects/my-project/locations/us-central1/batches/my-batch"
projectID, location, batchID, err := ExtractBatchDetails(batchName)
projectID, location, batchID, err := serverlessspark.ExtractBatchDetails(batchName)
if err != nil {
t.Errorf("ExtractBatchDetails() error = %v, want no error", err)
return
@@ -45,7 +46,7 @@ func TestExtractBatchDetails_Success(t *testing.T) {
func TestExtractBatchDetails_Failure(t *testing.T) {
batchName := "invalid-name"
_, _, _, err := ExtractBatchDetails(batchName)
_, _, _, err := serverlessspark.ExtractBatchDetails(batchName)
wantErr := "failed to parse batch name: invalid-name"
if err == nil || err.Error() != wantErr {
t.Errorf("ExtractBatchDetails() error = %v, want %v", err, wantErr)
@@ -53,7 +54,7 @@ func TestExtractBatchDetails_Failure(t *testing.T) {
}
func TestBatchConsoleURL(t *testing.T) {
got := BatchConsoleURL("my-project", "us-central1", "my-batch")
got := serverlessspark.BatchConsoleURL("my-project", "us-central1", "my-batch")
want := "https://console.cloud.google.com/dataproc/batches/us-central1/my-batch/summary?project=my-project"
if got != want {
t.Errorf("BatchConsoleURL() = %v, want %v", got, want)
@@ -63,7 +64,7 @@ func TestBatchConsoleURL(t *testing.T) {
func TestBatchLogsURL(t *testing.T) {
startTime := time.Date(2025, 10, 1, 5, 0, 0, 0, time.UTC)
endTime := time.Date(2025, 10, 1, 6, 0, 0, 0, time.UTC)
got := BatchLogsURL("my-project", "us-central1", "my-batch", startTime, endTime)
got := serverlessspark.BatchLogsURL("my-project", "us-central1", "my-batch", startTime, endTime)
want := "https://console.cloud.google.com/logs/viewer?advancedFilter=" +
"resource.type%3D%22cloud_dataproc_batch%22" +
"%0Aresource.labels.project_id%3D%22my-project%22" +
@@ -82,7 +83,7 @@ func TestBatchConsoleURLFromProto(t *testing.T) {
batchPb := &dataprocpb.Batch{
Name: "projects/my-project/locations/us-central1/batches/my-batch",
}
got, err := BatchConsoleURLFromProto(batchPb)
got, err := serverlessspark.BatchConsoleURLFromProto(batchPb)
if err != nil {
t.Fatalf("BatchConsoleURLFromProto() error = %v", err)
}
@@ -100,7 +101,7 @@ func TestBatchLogsURLFromProto(t *testing.T) {
CreateTime: timestamppb.New(createTime),
StateTime: timestamppb.New(stateTime),
}
got, err := BatchLogsURLFromProto(batchPb)
got, err := serverlessspark.BatchLogsURLFromProto(batchPb)
if err != nil {
t.Fatalf("BatchLogsURLFromProto() error = %v", err)
}

View File

@@ -19,15 +19,32 @@ import (
"encoding/json"
"fmt"
"cloud.google.com/go/geminidataanalytics/apiv1beta/geminidataanalyticspb"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/protobuf/encoding/protojson"
)
const kind string = "cloud-gemini-data-analytics-query"
// Guidance is the tool guidance string.
const Guidance = `Tool guidance:
Inputs:
1. query: A natural language formulation of a database query.
Outputs: (all optional)
1. disambiguation_question: Clarification questions or comments where the tool needs the user's input.
2. generated_query: The generated query for the user query.
3. intent_explanation: An explanation for why the tool produced ` + "`generated_query`" + `.
4. query_result: The result of executing ` + "`generated_query`" + `.
5. natural_language_answer: The natural language answer that summarizes the ` + "`query`" + ` and ` + "`query_result`" + `.
Usage guidance:
1. If ` + "`disambiguation_question`" + ` is produced, then solicit the needed inputs from the user and try the tool with a new ` + "`query`" + ` that has the needed clarification.
2. If ` + "`natural_language_answer`" + ` is produced, use ` + "`intent_explanation`" + ` and ` + "`generated_query`" + ` to see if you need to clarify any assumptions for the user.`
func init() {
if !tools.Register(kind, newConfig) {
panic(fmt.Sprintf("tool kind %q already registered", kind))
@@ -45,7 +62,49 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
GetProjectID() string
UseClientAuthorization() bool
RunQuery(context.Context, string, []byte) (any, error)
RunQuery(context.Context, string, *geminidataanalyticspb.QueryDataRequest) (*geminidataanalyticspb.QueryDataResponse, error)
}
// QueryDataContext wraps geminidataanalyticspb.QueryDataContext to support YAML decoding via protojson.
type QueryDataContext struct {
*geminidataanalyticspb.QueryDataContext
}
func (q *QueryDataContext) UnmarshalYAML(b []byte) error {
var raw map[string]any
if err := yaml.Unmarshal(b, &raw); err != nil {
return err
}
jsonBytes, err := json.Marshal(raw)
if err != nil {
return fmt.Errorf("failed to marshal context map: %w", err)
}
q.QueryDataContext = &geminidataanalyticspb.QueryDataContext{}
if err := protojson.Unmarshal(jsonBytes, q.QueryDataContext); err != nil {
return fmt.Errorf("failed to unmarshal context to proto: %w", err)
}
return nil
}
// GenerationOptions wraps geminidataanalyticspb.GenerationOptions to support YAML decoding via protojson.
type GenerationOptions struct {
*geminidataanalyticspb.GenerationOptions
}
func (g *GenerationOptions) UnmarshalYAML(b []byte) error {
var raw map[string]any
if err := yaml.Unmarshal(b, &raw); err != nil {
return err
}
jsonBytes, err := json.Marshal(raw)
if err != nil {
return fmt.Errorf("failed to marshal generation options map: %w", err)
}
g.GenerationOptions = &geminidataanalyticspb.GenerationOptions{}
if err := protojson.Unmarshal(jsonBytes, g.GenerationOptions); err != nil {
return fmt.Errorf("failed to unmarshal generation options to proto: %w", err)
}
return nil
}
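// Illustrative sketch: how the wrapper decodes YAML into the proto via
// protojson. The field names follow the proto's JSON names; the values are
// hypothetical.
//
//    var opts GenerationOptions
//    raw := []byte("generateQueryResult: true\ngenerateNaturalLanguageAnswer: true")
//    if err := yaml.Unmarshal(raw, &opts); err != nil {
//        return err
//    }
//    _ = opts.GenerationOptions.GetGenerateQueryResult() // true
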
type Config struct {
@@ -68,19 +127,28 @@ func (cfg Config) ToolConfigKind() string {
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// Define the parameters for the Gemini Data Analytics Query API
// The prompt is the only input parameter.
// The query is the only input parameter.
allParameters := parameters.Parameters{
parameters.NewStringParameterWithRequired("prompt", "The natural language question to ask.", true),
parameters.NewStringParameterWithRequired("query", "A natural language formulation of a database query.", true),
}
// The input and outputs are for tool guidance, usage guidance is for multi-turn interaction.
guidance := Guidance
if cfg.Description != "" {
cfg.Description += "\n\n" + guidance
} else {
cfg.Description = guidance
}
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil)
return Tool{
t := Tool{
Config: cfg,
AllParams: allParameters,
manifest: tools.Manifest{Description: cfg.Description, Parameters: allParameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}, nil
}
return t, nil
}
// validate interface
@@ -105,9 +173,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
paramsMap := params.AsMap()
prompt, ok := paramsMap["prompt"].(string)
query, ok := paramsMap["query"].(string)
if !ok {
return nil, fmt.Errorf("prompt parameter not found or not a string")
return nil, fmt.Errorf("query parameter not found or not a string")
}
// Parse the access token if provided
@@ -123,18 +191,20 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
// The parent in the request payload uses the tool's configured location.
payloadParent := fmt.Sprintf("projects/%s/locations/%s", source.GetProjectID(), t.Location)
payload := &QueryDataRequest{
Parent: payloadParent,
Prompt: prompt,
Context: t.Context,
GenerationOptions: t.GenerationOptions,
req := &geminidataanalyticspb.QueryDataRequest{
Parent: payloadParent,
Prompt: query,
}
bodyBytes, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("failed to marshal request payload: %w", err)
if t.Context != nil {
req.Context = t.Context.QueryDataContext
}
return source.RunQuery(ctx, tokenStr, bodyBytes)
if t.GenerationOptions != nil {
req.GenerationOptions = t.GenerationOptions.GenerationOptions
}
return source.RunQuery(ctx, tokenStr, req)
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -16,19 +16,16 @@ package cloudgda_test
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"testing"
"cloud.google.com/go/geminidataanalytics/apiv1beta/geminidataanalyticspb"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/server/resources"
"github.com/googleapis/genai-toolbox/internal/sources"
cloudgdasrc "github.com/googleapis/genai-toolbox/internal/sources/cloudgda"
"github.com/googleapis/genai-toolbox/internal/testutils"
"github.com/googleapis/genai-toolbox/internal/tools"
cloudgdatool "github.com/googleapis/genai-toolbox/internal/tools/cloudgda"
@@ -74,23 +71,29 @@ func TestParseFromYaml(t *testing.T) {
Location: "us-central1",
AuthRequired: []string{},
Context: &cloudgdatool.QueryDataContext{
DatasourceReferences: &cloudgdatool.DatasourceReferences{
SpannerReference: &cloudgdatool.SpannerReference{
DatabaseReference: &cloudgdatool.SpannerDatabaseReference{
ProjectID: "cloud-db-nl2sql",
Region: "us-central1",
InstanceID: "evalbench",
DatabaseID: "financial",
Engine: cloudgdatool.SpannerEngineGoogleSQL,
},
AgentContextReference: &cloudgdatool.AgentContextReference{
ContextSetID: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates",
QueryDataContext: &geminidataanalyticspb.QueryDataContext{
DatasourceReferences: &geminidataanalyticspb.DatasourceReferences{
References: &geminidataanalyticspb.DatasourceReferences_SpannerReference{
SpannerReference: &geminidataanalyticspb.SpannerReference{
DatabaseReference: &geminidataanalyticspb.SpannerDatabaseReference{
ProjectId: "cloud-db-nl2sql",
Region: "us-central1",
InstanceId: "evalbench",
DatabaseId: "financial",
Engine: geminidataanalyticspb.SpannerDatabaseReference_GOOGLE_SQL,
},
AgentContextReference: &geminidataanalyticspb.AgentContextReference{
ContextSetId: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates",
},
},
},
},
},
},
GenerationOptions: &cloudgdatool.GenerationOptions{
GenerateQueryResult: true,
GenerationOptions: &geminidataanalyticspb.GenerationOptions{
GenerateQueryResult: true,
},
},
},
},
@@ -108,68 +111,63 @@ func TestParseFromYaml(t *testing.T) {
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got.Tools) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Tools)
if !cmp.Equal(tc.want, got.Tools, cmpopts.IgnoreUnexported(geminidataanalyticspb.QueryDataContext{}, geminidataanalyticspb.DatasourceReferences{}, geminidataanalyticspb.SpannerReference{}, geminidataanalyticspb.SpannerDatabaseReference{}, geminidataanalyticspb.AgentContextReference{}, geminidataanalyticspb.GenerationOptions{}, geminidataanalyticspb.DatasourceReferences_SpannerReference{})) {
t.Errorf("incorrect parse: want %v, got %v", tc.want, got.Tools)
}
})
}
}
// authRoundTripper is a mock http.RoundTripper that adds a dummy Authorization header.
type authRoundTripper struct {
Token string
Next http.RoundTripper
// fakeSource implements the compatibleSource interface for testing.
type fakeSource struct {
projectID string
useClientOAuth bool
expectedQuery string
expectedParent string
response *geminidataanalyticspb.QueryDataResponse
}
func (rt *authRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
newReq := *req
newReq.Header = make(http.Header)
for k, v := range req.Header {
newReq.Header[k] = v
}
newReq.Header.Set("Authorization", rt.Token)
if rt.Next == nil {
return http.DefaultTransport.RoundTrip(&newReq)
}
return rt.Next.RoundTrip(&newReq)
func (f *fakeSource) GetProjectID() string {
return f.projectID
}
type mockSource struct {
kind string
client *http.Client // Can be used to inject a specific client
baseURL string // BaseURL is needed to implement sources.Source.BaseURL
config cloudgdasrc.Config // to return from ToConfig
func (f *fakeSource) UseClientAuthorization() bool {
return f.useClientOAuth
}
func (m *mockSource) SourceKind() string { return m.kind }
func (m *mockSource) ToConfig() sources.SourceConfig { return m.config }
func (m *mockSource) GetClient(ctx context.Context, token string) (*http.Client, error) {
if m.client != nil {
return m.client, nil
}
// Default client for testing if not explicitly set
transport := &http.Transport{}
authTransport := &authRoundTripper{
Token: "Bearer test-access-token", // Dummy token
Next: transport,
}
return &http.Client{Transport: authTransport}, nil
func (f *fakeSource) SourceKind() string {
return "fake-gda-source"
}
func (m *mockSource) UseClientAuthorization() bool { return false }
func (m *mockSource) Initialize(ctx context.Context, tracer interface{}) (sources.Source, error) {
return m, nil
func (f *fakeSource) ToConfig() sources.SourceConfig {
return nil
}
func (f *fakeSource) Initialize(ctx context.Context, tracer interface{}) (sources.Source, error) {
return f, nil
}
func (f *fakeSource) RunQuery(ctx context.Context, token string, req *geminidataanalyticspb.QueryDataRequest) (*geminidataanalyticspb.QueryDataResponse, error) {
if req.Prompt != f.expectedQuery {
return nil, fmt.Errorf("unexpected query: got %q, want %q", req.Prompt, f.expectedQuery)
}
if req.Parent != f.expectedParent {
return nil, fmt.Errorf("unexpected parent: got %q, want %q", req.Parent, f.expectedParent)
}
// Basic validation of context/options could be added here if needed,
// but the test case mainly checks if they are passed correctly via successful invocation.
return f.response, nil
}
func (m *mockSource) BaseURL() string { return m.baseURL }
func TestInitialize(t *testing.T) {
t.Parallel()
// Minimal fake source
fake := &fakeSource{projectID: "test-project"}
srcs := map[string]sources.Source{
"gda-api-source": &cloudgdasrc.Source{
Config: cloudgdasrc.Config{Name: "gda-api-source", Kind: cloudgdasrc.SourceKind, ProjectID: "test-project"},
Client: &http.Client{},
BaseURL: cloudgdasrc.Endpoint,
},
"gda-api-source": fake,
}
tcs := []struct {
@@ -188,9 +186,6 @@ func TestInitialize(t *testing.T) {
},
}
// Add an incompatible source for testing
srcs["incompatible-source"] = &mockSource{kind: "another-kind"}
for _, tc := range tcs {
tc := tc
t.Run(tc.desc, func(t *testing.T) {
@@ -207,92 +202,27 @@ func TestInitialize(t *testing.T) {
func TestInvoke(t *testing.T) {
t.Parallel()
// Mock the HTTP client and server for Invoke testing
serverMux := http.NewServeMux()
// Update expected URL path to include the location "us-central1"
serverMux.HandleFunc("/v1beta/projects/test-project/locations/global:queryData", func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("expected POST method, got %s", r.Method)
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
if r.Header.Get("Content-Type") != "application/json" {
t.Errorf("expected Content-Type application/json, got %s", r.Header.Get("Content-Type"))
http.Error(w, "Bad request", http.StatusBadRequest)
return
}
// Read and unmarshal the request body
bodyBytes, err := io.ReadAll(r.Body)
if err != nil {
t.Errorf("failed to read request body: %v", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
var reqPayload cloudgdatool.QueryDataRequest
if err := json.Unmarshal(bodyBytes, &reqPayload); err != nil {
t.Errorf("failed to unmarshal request payload: %v", err)
http.Error(w, "Bad request", http.StatusBadRequest)
return
}
projectID := "test-project"
location := "us-central1"
query := "How many accounts who have region in Prague are eligible for loans?"
expectedParent := fmt.Sprintf("projects/%s/locations/%s", projectID, location)
// Verify expected fields
if r.Header.Get("Authorization") == "" {
t.Errorf("expected Authorization header, got empty")
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
if reqPayload.Prompt != "How many accounts who have region in Prague are eligible for loans?" {
t.Errorf("unexpected prompt: %s", reqPayload.Prompt)
}
// Verify payload's parent uses the tool's configured location
if reqPayload.Parent != fmt.Sprintf("projects/%s/locations/%s", "test-project", "us-central1") {
t.Errorf("unexpected payload parent: got %q, want %q", reqPayload.Parent, fmt.Sprintf("projects/%s/locations/%s", "test-project", "us-central1"))
}
// Verify context from config
if reqPayload.Context == nil ||
reqPayload.Context.DatasourceReferences == nil ||
reqPayload.Context.DatasourceReferences.SpannerReference == nil ||
reqPayload.Context.DatasourceReferences.SpannerReference.DatabaseReference == nil ||
reqPayload.Context.DatasourceReferences.SpannerReference.DatabaseReference.ProjectID != "cloud-db-nl2sql" {
t.Errorf("unexpected context: %v", reqPayload.Context)
}
// Verify generation options from config
if reqPayload.GenerationOptions == nil || !reqPayload.GenerationOptions.GenerateQueryResult {
t.Errorf("unexpected generation options: %v", reqPayload.GenerationOptions)
}
// Simulate a successful response
resp := map[string]any{
"queryResult": "SELECT count(*) FROM accounts WHERE region = 'Prague' AND eligible_for_loans = true;",
"naturalLanguageAnswer": "There are 5 accounts in Prague eligible for loans.",
}
_ = json.NewEncoder(w).Encode(resp)
})
mockServer := httptest.NewServer(serverMux)
defer mockServer.Close()
ctx := testutils.ContextWithUserAgent(context.Background(), "test-user-agent")
// Create an authenticated client that uses the mock server
authTransport := &authRoundTripper{
Token: "Bearer test-access-token",
Next: mockServer.Client().Transport,
// Prepare expected response
expectedResp := &geminidataanalyticspb.QueryDataResponse{
GeneratedQuery: "SELECT count(*) FROM accounts WHERE region = 'Prague' AND eligible_for_loans = true;",
NaturalLanguageAnswer: "There are 5 accounts in Prague eligible for loans.",
}
authClient := &http.Client{Transport: authTransport}
// Create a real cloudgdasrc.Source but inject the authenticated client
mockGdaSource := &cloudgdasrc.Source{
Config: cloudgdasrc.Config{Name: "mock-gda-source", Kind: cloudgdasrc.SourceKind, ProjectID: "test-project"},
Client: authClient,
BaseURL: mockServer.URL,
fake := &fakeSource{
projectID: projectID,
expectedQuery: query,
expectedParent: expectedParent,
response: expectedResp,
}
srcs := map[string]sources.Source{
"mock-gda-source": mockGdaSource,
"mock-gda-source": fake,
}
// Initialize the tool config with context
@@ -301,25 +231,31 @@ func TestInvoke(t *testing.T) {
Kind: "cloud-gemini-data-analytics-query",
Source: "mock-gda-source",
Description: "Query Gemini Data Analytics",
Location: "us-central1", // Set location for the test
Location: location,
Context: &cloudgdatool.QueryDataContext{
DatasourceReferences: &cloudgdatool.DatasourceReferences{
SpannerReference: &cloudgdatool.SpannerReference{
DatabaseReference: &cloudgdatool.SpannerDatabaseReference{
ProjectID: "cloud-db-nl2sql",
Region: "us-central1",
InstanceID: "evalbench",
DatabaseID: "financial",
Engine: cloudgdatool.SpannerEngineGoogleSQL,
},
AgentContextReference: &cloudgdatool.AgentContextReference{
ContextSetID: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates",
QueryDataContext: &geminidataanalyticspb.QueryDataContext{
DatasourceReferences: &geminidataanalyticspb.DatasourceReferences{
References: &geminidataanalyticspb.DatasourceReferences_SpannerReference{
SpannerReference: &geminidataanalyticspb.SpannerReference{
DatabaseReference: &geminidataanalyticspb.SpannerDatabaseReference{
ProjectId: "cloud-db-nl2sql",
Region: "us-central1",
InstanceId: "evalbench",
DatabaseId: "financial",
Engine: geminidataanalyticspb.SpannerDatabaseReference_GOOGLE_SQL,
},
AgentContextReference: &geminidataanalyticspb.AgentContextReference{
ContextSetId: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates",
},
},
},
},
},
},
GenerationOptions: &cloudgdatool.GenerationOptions{
GenerateQueryResult: true,
GenerationOptions: &geminidataanalyticspb.GenerationOptions{
GenerateQueryResult: true,
},
},
}
@@ -328,26 +264,27 @@ func TestInvoke(t *testing.T) {
t.Fatalf("failed to initialize tool: %v", err)
}
// Prepare parameters for invocation - ONLY prompt
// Prepare parameters for invocation - ONLY query
params := parameters.ParamValues{
{Name: "prompt", Value: "How many accounts who have region in Prague are eligible for loans?"},
{Name: "query", Value: query},
}
resourceMgr := resources.NewResourceManager(srcs, nil, nil, nil, nil, nil, nil)
ctx := testutils.ContextWithUserAgent(context.Background(), "test-user-agent")
// Invoke the tool
result, err := tool.Invoke(ctx, resourceMgr, params, "") // No accessToken needed for ADC client
result, err := tool.Invoke(ctx, resourceMgr, params, "")
if err != nil {
t.Fatalf("tool invocation failed: %v", err)
}
// Validate the result
expectedResult := map[string]any{
"queryResult": "SELECT count(*) FROM accounts WHERE region = 'Prague' AND eligible_for_loans = true;",
"naturalLanguageAnswer": "There are 5 accounts in Prague eligible for loans.",
gotResp, ok := result.(*geminidataanalyticspb.QueryDataResponse)
if !ok {
t.Fatalf("expected result type *geminidataanalyticspb.QueryDataResponse, got %T", result)
}
if !cmp.Equal(expectedResult, result) {
t.Errorf("unexpected result: got %v, want %v", result, expectedResult)
if diff := cmp.Diff(expectedResp, gotResp, cmpopts.IgnoreUnexported(geminidataanalyticspb.QueryDataResponse{})); diff != "" {
t.Errorf("unexpected result mismatch (-want +got):\n%s", diff)
}
}

View File

@@ -1,116 +0,0 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cloudgda
// See full service definition at: https://github.com/googleapis/googleapis/blob/master/google/cloud/geminidataanalytics/v1beta/data_chat_service.proto
// QueryDataRequest represents the JSON body for the queryData API
type QueryDataRequest struct {
Parent string `json:"parent"`
Prompt string `json:"prompt"`
Context *QueryDataContext `json:"context,omitempty"`
GenerationOptions *GenerationOptions `json:"generationOptions,omitempty"`
}
// QueryDataContext reflects the proto definition for the query context.
type QueryDataContext struct {
DatasourceReferences *DatasourceReferences `json:"datasourceReferences,omitempty" yaml:"datasourceReferences,omitempty"`
}
// DatasourceReferences reflects the proto definition for datasource references, using a oneof.
type DatasourceReferences struct {
SpannerReference *SpannerReference `json:"spannerReference,omitempty" yaml:"spannerReference,omitempty"`
AlloyDBReference *AlloyDBReference `json:"alloydb,omitempty" yaml:"alloydb,omitempty"`
CloudSQLReference *CloudSQLReference `json:"cloudSqlReference,omitempty" yaml:"cloudSqlReference,omitempty"`
}
// SpannerReference reflects the proto definition for Spanner database reference.
type SpannerReference struct {
DatabaseReference *SpannerDatabaseReference `json:"databaseReference,omitempty" yaml:"databaseReference,omitempty"`
AgentContextReference *AgentContextReference `json:"agentContextReference,omitempty" yaml:"agentContextReference,omitempty"`
}
// SpannerDatabaseReference reflects the proto definition for a Spanner database reference.
type SpannerDatabaseReference struct {
Engine SpannerEngine `json:"engine,omitempty" yaml:"engine,omitempty"`
ProjectID string `json:"projectId,omitempty" yaml:"projectId,omitempty"`
Region string `json:"region,omitempty" yaml:"region,omitempty"`
InstanceID string `json:"instanceId,omitempty" yaml:"instanceId,omitempty"`
DatabaseID string `json:"databaseId,omitempty" yaml:"databaseId,omitempty"`
TableIDs []string `json:"tableIds,omitempty" yaml:"tableIds,omitempty"`
}
// SpannerEngine represents the engine of the Spanner instance.
type SpannerEngine string
const (
SpannerEngineUnspecified SpannerEngine = "ENGINE_UNSPECIFIED"
SpannerEngineGoogleSQL SpannerEngine = "GOOGLE_SQL"
SpannerEnginePostgreSQL SpannerEngine = "POSTGRESQL"
)
// AlloyDBReference reflects the proto definition for an AlloyDB database reference.
type AlloyDBReference struct {
DatabaseReference *AlloyDBDatabaseReference `json:"databaseReference,omitempty" yaml:"databaseReference,omitempty"`
AgentContextReference *AgentContextReference `json:"agentContextReference,omitempty" yaml:"agentContextReference,omitempty"`
}
// AlloyDBDatabaseReference reflects the proto definition for an AlloyDB database reference.
type AlloyDBDatabaseReference struct {
ProjectID string `json:"projectId,omitempty" yaml:"projectId,omitempty"`
Region string `json:"region,omitempty" yaml:"region,omitempty"`
ClusterID string `json:"clusterId,omitempty" yaml:"clusterId,omitempty"`
InstanceID string `json:"instanceId,omitempty" yaml:"instanceId,omitempty"`
DatabaseID string `json:"databaseId,omitempty" yaml:"databaseId,omitempty"`
TableIDs []string `json:"tableIds,omitempty" yaml:"tableIds,omitempty"`
}
// CloudSQLReference reflects the proto definition for a Cloud SQL database reference.
type CloudSQLReference struct {
DatabaseReference *CloudSQLDatabaseReference `json:"databaseReference,omitempty" yaml:"databaseReference,omitempty"`
AgentContextReference *AgentContextReference `json:"agentContextReference,omitempty" yaml:"agentContextReference,omitempty"`
}
// CloudSQLDatabaseReference reflects the proto definition for a Cloud SQL database reference.
type CloudSQLDatabaseReference struct {
Engine CloudSQLEngine `json:"engine,omitempty" yaml:"engine,omitempty"`
ProjectID string `json:"projectId,omitempty" yaml:"projectId,omitempty"`
Region string `json:"region,omitempty" yaml:"region,omitempty"`
InstanceID string `json:"instanceId,omitempty" yaml:"instanceId,omitempty"`
DatabaseID string `json:"databaseId,omitempty" yaml:"databaseId,omitempty"`
TableIDs []string `json:"tableIds,omitempty" yaml:"tableIds,omitempty"`
}
// CloudSQLEngine represents the engine of the Cloud SQL instance.
type CloudSQLEngine string
const (
CloudSQLEngineUnspecified CloudSQLEngine = "ENGINE_UNSPECIFIED"
CloudSQLEnginePostgreSQL CloudSQLEngine = "POSTGRESQL"
CloudSQLEngineMySQL CloudSQLEngine = "MYSQL"
)
// AgentContextReference reflects the proto definition for agent context.
type AgentContextReference struct {
ContextSetID string `json:"contextSetId,omitempty" yaml:"contextSetId,omitempty"`
}
// GenerationOptions reflects the proto definition for generation options.
type GenerationOptions struct {
GenerateQueryResult bool `json:"generateQueryResult" yaml:"generateQueryResult"`
GenerateNaturalLanguageAnswer bool `json:"generateNaturalLanguageAnswer" yaml:"generateNaturalLanguageAnswer"`
GenerateExplanation bool `json:"generateExplanation" yaml:"generateExplanation"`
GenerateDisambiguationQuestion bool `json:"generateDisambiguationQuestion" yaml:"generateDisambiguationQuestion"`
}

View File

@@ -16,22 +16,13 @@ package fhirfetchpage
import (
"context"
"encoding/json"
"fmt"
"io"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/healthcare/v1"
"net/http"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
)
const kind string = "cloud-healthcare-fhir-fetch-page"
@@ -54,13 +45,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
}
type compatibleSource interface {
Project() string
Region() string
DatasetID() string
AllowedFHIRStores() map[string]struct{}
Service() *healthcare.Service
ServiceCreator() healthcareds.HealthcareServiceCreator
UseClientAuthorization() bool
FHIRFetchPage(context.Context, string, string) (any, error)
}
type Config struct {
@@ -118,48 +104,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", pageURLKey)
}
var httpClient *http.Client
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: tokenStr})
httpClient = oauth2.NewClient(ctx, ts)
} else {
// The source.Service() object holds a client with the default credentials.
// However, the client is not exported, so we have to create a new one.
var err error
httpClient, err = google.DefaultClient(ctx, healthcare.CloudHealthcareScope)
if err != nil {
return nil, fmt.Errorf("failed to create default http client: %w", err)
}
}
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("failed to create http request: %w", err)
return nil, fmt.Errorf("error parsing access token: %w", err)
}
req.Header.Set("Accept", "application/fhir+json;charset=utf-8")
resp, err := httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to get fhir page from %q: %w", url, err)
}
defer resp.Body.Close()
respBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("could not read response: %w", err)
}
if resp.StatusCode > 299 {
return nil, fmt.Errorf("read: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
}
var jsonMap map[string]interface{}
if err := json.Unmarshal([]byte(string(respBytes)), &jsonMap); err != nil {
return nil, fmt.Errorf("could not unmarshal response as json: %w", err)
}
return jsonMap, nil
return source.FHIRFetchPage(ctx, url, tokenStr)
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
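The page-fetch logic removed above now has to live behind `source.FHIRFetchPage(ctx, url, tokenStr)`. A minimal sketch of what that source-side method could look like, assuming the source keeps a `useClientOAuth` flag and otherwise mirrors the HTTP and JSON handling deleted from the tool (the actual source change is not shown in this section):

```go
// Illustrative only: names, package layout, and receiver fields are assumptions.
package cloudhealthcare

import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"

	"golang.org/x/oauth2"
	"golang.org/x/oauth2/google"
	"google.golang.org/api/healthcare/v1"
)

// Source stands in for the real source struct; useClientOAuth is assumed.
type Source struct {
	useClientOAuth bool
}

// FHIRFetchPage fetches a FHIR bundle page URL, using the caller's OAuth token
// when per-request client authorization is enabled, otherwise ADC.
func (s *Source) FHIRFetchPage(ctx context.Context, url, tokenStr string) (any, error) {
	var httpClient *http.Client
	if s.useClientOAuth {
		ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: tokenStr})
		httpClient = oauth2.NewClient(ctx, ts)
	} else {
		var err error
		httpClient, err = google.DefaultClient(ctx, healthcare.CloudHealthcareScope)
		if err != nil {
			return nil, fmt.Errorf("failed to create default http client: %w", err)
		}
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create http request: %w", err)
	}
	req.Header.Set("Accept", "application/fhir+json;charset=utf-8")
	resp, err := httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to get fhir page from %q: %w", url, err)
	}
	defer resp.Body.Close()
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("could not read response: %w", err)
	}
	if resp.StatusCode > 299 {
		return nil, fmt.Errorf("read: status %d %s: %s", resp.StatusCode, resp.Status, body)
	}
	var jsonMap map[string]any
	if err := json.Unmarshal(body, &jsonMap); err != nil {
		return nil, fmt.Errorf("could not unmarshal response as json: %w", err)
	}
	return jsonMap, nil
}
```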

View File

@@ -16,20 +16,16 @@ package fhirpatienteverything
import (
"context"
"encoding/json"
"fmt"
"io"
"strings"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/googleapi"
"google.golang.org/api/healthcare/v1"
)
const kind string = "cloud-healthcare-fhir-patient-everything"
@@ -54,13 +50,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
}
type compatibleSource interface {
Project() string
Region() string
DatasetID() string
AllowedFHIRStores() map[string]struct{}
Service() *healthcare.Service
ServiceCreator() healthcareds.HealthcareServiceCreator
UseClientAuthorization() bool
FHIRPatientEverything(string, string, string, []googleapi.CallOption) (any, error)
}
type Config struct {
@@ -139,20 +131,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", patientIDKey)
}
svc := source.Service()
// Initialize new service if using user OAuth token
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/Patient/%s", source.Project(), source.Region(), source.DatasetID(), storeID, patientID)
var opts []googleapi.CallOption
if val, ok := params.AsMap()[typeFilterKey]; ok {
types, ok := val.([]any)
@@ -176,25 +159,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
opts = append(opts, googleapi.QueryParameter("_since", sinceStr))
}
}
resp, err := svc.Projects.Locations.Datasets.FhirStores.Fhir.PatientEverything(name).Do(opts...)
if err != nil {
return nil, fmt.Errorf("failed to call patient everything for %q: %w", name, err)
}
defer resp.Body.Close()
respBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("could not read response: %w", err)
}
if resp.StatusCode > 299 {
return nil, fmt.Errorf("patient-everything: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
}
var jsonMap map[string]interface{}
if err := json.Unmarshal([]byte(string(respBytes)), &jsonMap); err != nil {
return nil, fmt.Errorf("could not unmarshal response as json: %w", err)
}
return jsonMap, nil
return source.FHIRPatientEverything(storeID, patientID, tokenStr, opts)
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -16,20 +16,16 @@ package fhirpatientsearch
import (
"context"
"encoding/json"
"fmt"
"io"
"strings"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/googleapi"
"google.golang.org/api/healthcare/v1"
)
const kind string = "cloud-healthcare-fhir-patient-search"
@@ -70,13 +66,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
}
type compatibleSource interface {
Project() string
Region() string
DatasetID() string
AllowedFHIRStores() map[string]struct{}
Service() *healthcare.Service
ServiceCreator() healthcareds.HealthcareServiceCreator
UseClientAuthorization() bool
FHIRPatientSearch(string, string, []googleapi.CallOption) (any, error)
}
type Config struct {
@@ -169,17 +161,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, err
}
svc := source.Service()
// Initialize new service if using user OAuth token
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
var summary bool
@@ -248,26 +232,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if summary {
opts = append(opts, googleapi.QueryParameter("_summary", "text"))
}
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
resp, err := svc.Projects.Locations.Datasets.FhirStores.Fhir.SearchType(name, "Patient", &healthcare.SearchResourcesRequest{ResourceType: "Patient"}).Do(opts...)
if err != nil {
return nil, fmt.Errorf("failed to search patient resources: %w", err)
}
defer resp.Body.Close()
respBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("could not read response: %w", err)
}
if resp.StatusCode > 299 {
return nil, fmt.Errorf("search: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
}
var jsonMap map[string]interface{}
if err := json.Unmarshal([]byte(string(respBytes)), &jsonMap); err != nil {
return nil, fmt.Errorf("could not unmarshal response as json: %w", err)
}
return jsonMap, nil
return source.FHIRPatientSearch(storeID, tokenStr, opts)
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -21,7 +21,6 @@ import (
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/healthcare/v1"
@@ -44,12 +43,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
}
type compatibleSource interface {
Project() string
Region() string
DatasetID() string
Service() *healthcare.Service
ServiceCreator() healthcareds.HealthcareServiceCreator
UseClientAuthorization() bool
GetDataset(string) (*healthcare.Dataset, error)
}
type Config struct {
@@ -100,27 +95,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
svc := source.Service()
// Initialize new service if using user OAuth token
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
}
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", source.Project(), source.Region(), source.DatasetID())
dataset, err := svc.Projects.Locations.Datasets.Get(datasetName).Do()
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
return nil, fmt.Errorf("error parsing access token: %w", err)
}
return dataset, nil
return source.GetDataset(tokenStr)
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -21,7 +21,6 @@ import (
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
@@ -45,13 +44,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
}
type compatibleSource interface {
Project() string
Region() string
DatasetID() string
AllowedDICOMStores() map[string]struct{}
Service() *healthcare.Service
ServiceCreator() healthcareds.HealthcareServiceCreator
UseClientAuthorization() bool
GetDICOMStore(string, string) (*healthcare.DicomStore, error)
}
type Config struct {
@@ -117,31 +112,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
if err != nil {
return nil, err
}
svc := source.Service()
// Initialize new service if using user OAuth token
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
}
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
store, err := svc.Projects.Locations.Datasets.DicomStores.Get(storeName).Do()
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("failed to get DICOM store %q: %w", storeName, err)
return nil, fmt.Errorf("error parsing access token: %w", err)
}
return store, nil
return source.GetDICOMStore(storeID, tokenStr)
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -21,7 +21,6 @@ import (
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
@@ -45,13 +44,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
}
type compatibleSource interface {
Project() string
Region() string
DatasetID() string
AllowedDICOMStores() map[string]struct{}
Service() *healthcare.Service
ServiceCreator() healthcareds.HealthcareServiceCreator
UseClientAuthorization() bool
GetDICOMStoreMetrics(string, string) (*healthcare.DicomStoreMetrics, error)
}
type Config struct {
@@ -117,31 +112,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
if err != nil {
return nil, err
}
svc := source.Service()
// Initialize new service if using user OAuth token
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
}
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
store, err := svc.Projects.Locations.Datasets.DicomStores.GetDICOMStoreMetrics(storeName).Do()
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("failed to get metrics for DICOM store %q: %w", storeName, err)
return nil, fmt.Errorf("error parsing access token: %w", err)
}
return store, nil
return source.GetDICOMStoreMetrics(storeID, tokenStr)
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -16,18 +16,14 @@ package getfhirresource
import (
"context"
"encoding/json"
"fmt"
"io"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/healthcare/v1"
)
const kind string = "cloud-healthcare-get-fhir-resource"
@@ -51,13 +47,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
}
type compatibleSource interface {
Project() string
Region() string
DatasetID() string
AllowedFHIRStores() map[string]struct{}
Service() *healthcare.Service
ServiceCreator() healthcareds.HealthcareServiceCreator
UseClientAuthorization() bool
GetFHIRResource(string, string, string, string) (any, error)
}
type Config struct {
@@ -134,46 +126,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if !ok {
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", typeKey)
}
resID, ok := params.AsMap()[idKey].(string)
if !ok {
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected a string", idKey)
}
svc := source.Service()
// Initialize new service if using user OAuth token
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
}
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s/fhir/%s/%s", source.Project(), source.Region(), source.DatasetID(), storeID, resType, resID)
call := svc.Projects.Locations.Datasets.FhirStores.Fhir.Read(name)
call.Header().Set("Content-Type", "application/fhir+json;charset=utf-8")
resp, err := call.Do()
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("failed to get fhir resource %q: %w", name, err)
return nil, fmt.Errorf("error parsing access token: %w", err)
}
defer resp.Body.Close()
respBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("could not read response: %w", err)
}
if resp.StatusCode > 299 {
return nil, fmt.Errorf("read: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
}
var jsonMap map[string]interface{}
if err := json.Unmarshal([]byte(string(respBytes)), &jsonMap); err != nil {
return nil, fmt.Errorf("could not unmarshal response as json: %w", err)
}
return jsonMap, nil
return source.GetFHIRResource(storeID, resType, resID, tokenStr)
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -21,7 +21,6 @@ import (
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
@@ -45,13 +44,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
}
type compatibleSource interface {
Project() string
Region() string
DatasetID() string
AllowedFHIRStores() map[string]struct{}
Service() *healthcare.Service
ServiceCreator() healthcareds.HealthcareServiceCreator
UseClientAuthorization() bool
GetFHIRStore(string, string) (*healthcare.FhirStore, error)
}
type Config struct {
@@ -117,31 +112,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores())
if err != nil {
return nil, err
}
svc := source.Service()
// Initialize new service if using user OAuth token
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
}
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
store, err := svc.Projects.Locations.Datasets.FhirStores.Get(storeName).Do()
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("failed to get FHIR store %q: %w", storeName, err)
return nil, fmt.Errorf("error parsing access token: %w", err)
}
return store, nil
return source.GetFHIRStore(storeID, tokenStr)
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -21,7 +21,6 @@ import (
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
@@ -45,13 +44,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
}
type compatibleSource interface {
Project() string
Region() string
DatasetID() string
AllowedFHIRStores() map[string]struct{}
Service() *healthcare.Service
ServiceCreator() healthcareds.HealthcareServiceCreator
UseClientAuthorization() bool
GetFHIRStoreMetrics(string, string) (*healthcare.FhirStoreMetrics, error)
}
type Config struct {
@@ -117,31 +112,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedFHIRStores())
if err != nil {
return nil, err
}
svc := source.Service()
// Initialize new service if using user OAuth token
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
}
storeName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/fhirStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
store, err := svc.Projects.Locations.Datasets.FhirStores.GetFHIRStoreMetrics(storeName).Do()
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("failed to get metrics for FHIR store %q: %w", storeName, err)
return nil, fmt.Errorf("error parsing access token: %w", err)
}
return store, nil
return source.GetFHIRStoreMetrics(storeID, tokenStr)
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -17,12 +17,10 @@ package listdicomstores
import (
"context"
"fmt"
"strings"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/healthcare/v1"
@@ -45,13 +43,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
}
type compatibleSource interface {
Project() string
Region() string
DatasetID() string
AllowedDICOMStores() map[string]struct{}
Service() *healthcare.Service
ServiceCreator() healthcareds.HealthcareServiceCreator
UseClientAuthorization() bool
ListDICOMStores(tokenStr string) ([]*healthcare.DicomStore, error)
}
type Config struct {
@@ -102,41 +95,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
svc := source.Service()
// Initialize new service if using user OAuth token
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
}
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", source.Project(), source.Region(), source.DatasetID())
stores, err := svc.Projects.Locations.Datasets.DicomStores.List(datasetName).Do()
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
return nil, fmt.Errorf("error parsing access token: %w", err)
}
var filtered []*healthcare.DicomStore
for _, store := range stores.DicomStores {
if len(source.AllowedDICOMStores()) == 0 {
filtered = append(filtered, store)
continue
}
if len(store.Name) == 0 {
continue
}
parts := strings.Split(store.Name, "/")
if _, ok := source.AllowedDICOMStores()[parts[len(parts)-1]]; ok {
filtered = append(filtered, store)
}
}
return filtered, nil
return source.ListDICOMStores(tokenStr)
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
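The allow-list filtering and the per-request service construction removed above are now the source's responsibility via `ListDICOMStores(tokenStr)`. A sketch under assumed receiver fields (`service`, `serviceCreator`, `useClientOAuth`, `allowedDICOMStores`); only the API call and the filtering mirror the deleted tool code:

```go
// Illustrative only: the real source implementation is not part of this diff.
package cloudhealthcare

import (
	"fmt"
	"strings"

	"google.golang.org/api/healthcare/v1"
)

type dicomSource struct {
	project, region, datasetID string
	useClientOAuth             bool
	allowedDICOMStores         map[string]struct{}
	service                    *healthcare.Service
	serviceCreator             func(token string) (*healthcare.Service, error)
}

// ListDICOMStores lists the dataset's DICOM stores, honoring per-request
// client authorization and the configured allow-list.
func (s *dicomSource) ListDICOMStores(tokenStr string) ([]*healthcare.DicomStore, error) {
	svc := s.service
	if s.useClientOAuth {
		var err error
		svc, err = s.serviceCreator(tokenStr)
		if err != nil {
			return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
		}
	}
	datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", s.project, s.region, s.datasetID)
	stores, err := svc.Projects.Locations.Datasets.DicomStores.List(datasetName).Do()
	if err != nil {
		return nil, fmt.Errorf("failed to list DICOM stores in %q: %w", datasetName, err)
	}
	var filtered []*healthcare.DicomStore
	for _, store := range stores.DicomStores {
		// An empty allow-list means every store in the dataset is permitted.
		if len(s.allowedDICOMStores) == 0 {
			filtered = append(filtered, store)
			continue
		}
		if store.Name == "" {
			continue
		}
		parts := strings.Split(store.Name, "/")
		if _, ok := s.allowedDICOMStores[parts[len(parts)-1]]; ok {
			filtered = append(filtered, store)
		}
	}
	return filtered, nil
}
```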

View File

@@ -17,12 +17,10 @@ package listfhirstores
import (
"context"
"fmt"
"strings"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/healthcare/v1"
@@ -45,13 +43,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
}
type compatibleSource interface {
Project() string
Region() string
DatasetID() string
AllowedFHIRStores() map[string]struct{}
Service() *healthcare.Service
ServiceCreator() healthcareds.HealthcareServiceCreator
UseClientAuthorization() bool
ListFHIRStores(string) ([]*healthcare.FhirStore, error)
}
type Config struct {
@@ -102,41 +95,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
svc := source.Service()
// Initialize new service if using user OAuth token
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
}
datasetName := fmt.Sprintf("projects/%s/locations/%s/datasets/%s", source.Project(), source.Region(), source.DatasetID())
stores, err := svc.Projects.Locations.Datasets.FhirStores.List(datasetName).Do()
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("failed to get dataset %q: %w", datasetName, err)
return nil, fmt.Errorf("error parsing access token: %w", err)
}
var filtered []*healthcare.FhirStore
for _, store := range stores.FhirStores {
if len(source.AllowedFHIRStores()) == 0 {
filtered = append(filtered, store)
continue
}
if len(store.Name) == 0 {
continue
}
parts := strings.Split(store.Name, "/")
if _, ok := source.AllowedFHIRStores()[parts[len(parts)-1]]; ok {
filtered = append(filtered, store)
}
}
return filtered, nil
return source.ListFHIRStores(tokenStr)
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -16,18 +16,14 @@ package retrieverendereddicominstance
import (
"context"
"encoding/base64"
"fmt"
"io"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/healthcare/v1"
)
const kind string = "cloud-healthcare-retrieve-rendered-dicom-instance"
@@ -53,13 +49,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
}
type compatibleSource interface {
Project() string
Region() string
DatasetID() string
AllowedDICOMStores() map[string]struct{}
Service() *healthcare.Service
ServiceCreator() healthcareds.HealthcareServiceCreator
UseClientAuthorization() bool
RetrieveRenderedDICOMInstance(string, string, string, string, int, string) (any, error)
}
type Config struct {
@@ -135,20 +127,10 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
svc := source.Service()
// Initialize new service if using user OAuth token
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
study, ok := params.AsMap()[studyInstanceUIDKey].(string)
if !ok {
return nil, fmt.Errorf("invalid '%s' parameter; expected a string", studyInstanceUIDKey)
@@ -165,25 +147,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if !ok {
return nil, fmt.Errorf("invalid '%s' parameter; expected an integer", frameNumberKey)
}
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
dicomWebPath := fmt.Sprintf("studies/%s/series/%s/instances/%s/frames/%d/rendered", study, series, sop, frame)
call := svc.Projects.Locations.Datasets.DicomStores.Studies.Series.Instances.Frames.RetrieveRendered(name, dicomWebPath)
call.Header().Set("Accept", "image/jpeg")
resp, err := call.Do()
if err != nil {
return nil, fmt.Errorf("unable to retrieve dicom instance rendered image: %w", err)
}
defer resp.Body.Close()
respBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("could not read response: %w", err)
}
if resp.StatusCode > 299 {
return nil, fmt.Errorf("RetrieveRendered: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
}
base64String := base64.StdEncoding.EncodeToString(respBytes)
return base64String, nil
return source.RetrieveRenderedDICOMInstance(storeID, study, series, sop, frame, tokenStr)
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -16,20 +16,16 @@ package searchdicominstances
import (
"context"
"encoding/json"
"fmt"
"io"
"strings"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/googleapi"
"google.golang.org/api/healthcare/v1"
)
const kind string = "cloud-healthcare-search-dicom-instances"
@@ -60,13 +56,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
}
type compatibleSource interface {
Project() string
Region() string
DatasetID() string
AllowedDICOMStores() map[string]struct{}
Service() *healthcare.Service
ServiceCreator() healthcareds.HealthcareServiceCreator
UseClientAuthorization() bool
SearchDICOM(string, string, string, string, []googleapi.CallOption) (any, error)
}
type Config struct {
@@ -144,23 +136,13 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
if err != nil {
return nil, err
}
svc := source.Service()
// Initialize new service if using user OAuth token
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
opts, err := common.ParseDICOMSearchParameters(params, []string{sopInstanceUIDKey, patientNameKey, patientIDKey, accessionNumberKey, referringPhysicianNameKey, studyDateKey, modalityKey})
@@ -191,29 +173,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
}
}
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
resp, err := svc.Projects.Locations.Datasets.DicomStores.SearchForInstances(name, dicomWebPath).Do(opts...)
if err != nil {
return nil, fmt.Errorf("failed to search dicom instances: %w", err)
}
defer resp.Body.Close()
respBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("could not read response: %w", err)
}
if resp.StatusCode > 299 {
return nil, fmt.Errorf("search: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
}
if len(respBytes) == 0 {
return []interface{}{}, nil
}
var result []interface{}
if err := json.Unmarshal([]byte(string(respBytes)), &result); err != nil {
return nil, fmt.Errorf("could not unmarshal response as list: %w", err)
}
return result, nil
return source.SearchDICOM(t.Kind, storeID, dicomWebPath, tokenStr, opts)
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -16,18 +16,15 @@ package searchdicomseries
import (
"context"
"encoding/json"
"fmt"
"io"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/healthcare/v1"
"google.golang.org/api/googleapi"
)
const kind string = "cloud-healthcare-search-dicom-series"
@@ -57,13 +54,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
}
type compatibleSource interface {
Project() string
Region() string
DatasetID() string
AllowedDICOMStores() map[string]struct{}
Service() *healthcare.Service
ServiceCreator() healthcareds.HealthcareServiceCreator
UseClientAuthorization() bool
SearchDICOM(string, string, string, string, []googleapi.CallOption) (any, error)
}
type Config struct {
@@ -145,18 +138,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
svc := source.Service()
// Initialize new service if using user OAuth token
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
opts, err := common.ParseDICOMSearchParameters(params, []string{seriesInstanceUIDKey, patientNameKey, patientIDKey, accessionNumberKey, referringPhysicianNameKey, studyDateKey, modalityKey})
@@ -174,29 +158,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
dicomWebPath = fmt.Sprintf("studies/%s/series", id)
}
}
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
resp, err := svc.Projects.Locations.Datasets.DicomStores.SearchForSeries(name, dicomWebPath).Do(opts...)
if err != nil {
return nil, fmt.Errorf("failed to search dicom series: %w", err)
}
defer resp.Body.Close()
respBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("could not read response: %w", err)
}
if resp.StatusCode > 299 {
return nil, fmt.Errorf("search: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
}
if len(respBytes) == 0 {
return []interface{}{}, nil
}
var result []interface{}
if err := json.Unmarshal([]byte(string(respBytes)), &result); err != nil {
return nil, fmt.Errorf("could not unmarshal response as list: %w", err)
}
return result, nil
return source.SearchDICOM(t.Kind, storeID, dicomWebPath, tokenStr, opts)
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -16,18 +16,15 @@ package searchdicomstudies
import (
"context"
"encoding/json"
"fmt"
"io"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
healthcareds "github.com/googleapis/genai-toolbox/internal/sources/cloudhealthcare"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/common"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/healthcare/v1"
"google.golang.org/api/googleapi"
)
const kind string = "cloud-healthcare-search-dicom-studies"
@@ -55,13 +52,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
}
type compatibleSource interface {
Project() string
Region() string
DatasetID() string
AllowedDICOMStores() map[string]struct{}
Service() *healthcare.Service
ServiceCreator() healthcareds.HealthcareServiceCreator
UseClientAuthorization() bool
SearchDICOM(string, string, string, string, []googleapi.CallOption) (any, error)
}
type Config struct {
@@ -136,51 +129,20 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
storeID, err := common.ValidateAndFetchStoreID(params, source.AllowedDICOMStores())
if err != nil {
return nil, err
}
svc := source.Service()
// Initialize new service if using user OAuth token
if source.UseClientAuthorization() {
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
svc, err = source.ServiceCreator()(tokenStr)
if err != nil {
return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
}
tokenStr, err := accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
opts, err := common.ParseDICOMSearchParameters(params, []string{studyInstanceUIDKey, patientNameKey, patientIDKey, accessionNumberKey, referringPhysicianNameKey, studyDateKey})
if err != nil {
return nil, err
}
name := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", source.Project(), source.Region(), source.DatasetID(), storeID)
resp, err := svc.Projects.Locations.Datasets.DicomStores.SearchForStudies(name, "studies").Do(opts...)
if err != nil {
return nil, fmt.Errorf("failed to search dicom studies: %w", err)
}
defer resp.Body.Close()
respBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("could not read response: %w", err)
}
if resp.StatusCode > 299 {
return nil, fmt.Errorf("search: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
}
if len(respBytes) == 0 {
return []interface{}{}, nil
}
var result []interface{}
if err := json.Unmarshal([]byte(string(respBytes)), &result); err != nil {
return nil, fmt.Errorf("could not unmarshal response as list: %w", err)
}
return result, nil
dicomWebPath := "studies"
return source.SearchDICOM(t.Kind, storeID, dicomWebPath, tokenStr, opts)
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
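All three DICOM search tools now funnel into a single `SearchDICOM(kind, storeID, dicomWebPath, tokenStr, opts)` method, which is why each tool passes `t.Kind`. One plausible shape for that method, assuming the source switches on the tool kind to pick the DICOMweb search call; the response handling mirrors the code deleted from the tools:

```go
// Illustrative only: the dispatch on tool kind and the receiver fields are
// assumptions; only the API calls and response handling come from the diff.
package cloudhealthcare

import (
	"encoding/json"
	"fmt"
	"io"
	"net/http"

	"google.golang.org/api/googleapi"
	"google.golang.org/api/healthcare/v1"
)

type dicomStoreSource struct {
	project, region, datasetID string
	useClientOAuth             bool
	service                    *healthcare.Service
	serviceCreator             func(token string) (*healthcare.Service, error)
}

// SearchDICOM runs a DICOMweb search, choosing the studies/series/instances
// endpoint from the calling tool's kind.
func (s *dicomStoreSource) SearchDICOM(toolKind, storeID, dicomWebPath, tokenStr string, opts []googleapi.CallOption) (any, error) {
	svc := s.service
	if s.useClientOAuth {
		var err error
		svc, err = s.serviceCreator(tokenStr)
		if err != nil {
			return nil, fmt.Errorf("error creating service from OAuth access token: %w", err)
		}
	}
	parent := fmt.Sprintf("projects/%s/locations/%s/datasets/%s/dicomStores/%s", s.project, s.region, s.datasetID, storeID)
	var resp *http.Response
	var err error
	switch toolKind {
	case "cloud-healthcare-search-dicom-studies":
		resp, err = svc.Projects.Locations.Datasets.DicomStores.SearchForStudies(parent, dicomWebPath).Do(opts...)
	case "cloud-healthcare-search-dicom-series":
		resp, err = svc.Projects.Locations.Datasets.DicomStores.SearchForSeries(parent, dicomWebPath).Do(opts...)
	default:
		resp, err = svc.Projects.Locations.Datasets.DicomStores.SearchForInstances(parent, dicomWebPath).Do(opts...)
	}
	if err != nil {
		return nil, fmt.Errorf("failed to search %q: %w", dicomWebPath, err)
	}
	defer resp.Body.Close()
	respBytes, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("could not read response: %w", err)
	}
	if resp.StatusCode > 299 {
		return nil, fmt.Errorf("search: status %d %s: %s", resp.StatusCode, resp.Status, respBytes)
	}
	if len(respBytes) == 0 {
		return []any{}, nil
	}
	var result []any
	if err := json.Unmarshal(respBytes, &result); err != nil {
		return nil, fmt.Errorf("could not unmarshal response as list: %w", err)
	}
	return result, nil
}
```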

View File

@@ -0,0 +1,180 @@
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cloudsqlcreatebackup
import (
"context"
"fmt"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/sqladmin/v1"
)
const kind string = "cloud-sql-create-backup"
var _ tools.ToolConfig = Config{}
type compatibleSource interface {
GetDefaultProject() string
GetService(context.Context, string) (*sqladmin.Service, error)
UseClientAuthorization() bool
InsertBackupRun(ctx context.Context, project, instance, location, backupDescription, accessToken string) (any, error)
}
// Config defines the configuration for the create-backup tool.
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Description string `yaml:"description"`
Source string `yaml:"source" validate:"required"`
AuthRequired []string `yaml:"authRequired"`
}
func init() {
if !tools.Register(kind, newConfig) {
panic(fmt.Sprintf("tool kind %q already registered", kind))
}
}
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
actual := Config{Name: name}
if err := decoder.DecodeContext(ctx, &actual); err != nil {
return nil, err
}
return actual, nil
}
// ToolConfigKind returns the kind of the tool.
func (cfg Config) ToolConfigKind() string {
return kind
}
// Initialize initializes the tool from the configuration.
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
}
project := s.GetDefaultProject()
var projectParam parameters.Parameter
if project != "" {
projectParam = parameters.NewStringParameterWithDefault("project", project, "The GCP project ID. This is pre-configured; do not ask for it unless the user explicitly provides a different one.")
} else {
projectParam = parameters.NewStringParameter("project", "The project ID")
}
allParameters := parameters.Parameters{
projectParam,
parameters.NewStringParameter("instance", "Cloud SQL instance ID. This does not include the project ID."),
// Location and backup_description are optional.
parameters.NewStringParameterWithRequired("location", "Location of the backup run.", false),
parameters.NewStringParameterWithRequired("backup_description", "The description of this backup run.", false),
}
paramManifest := allParameters.Manifest()
description := cfg.Description
if description == "" {
description = "Creates a backup on a Cloud SQL instance."
}
mcpManifest := tools.GetMcpManifest(cfg.Name, description, cfg.AuthRequired, allParameters, nil)
return Tool{
Config: cfg,
AllParams: allParameters,
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}, nil
}
// Tool represents the create-backup tool.
type Tool struct {
Config
AllParams parameters.Parameters `yaml:"allParams"`
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
return t.Config
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
project, ok := paramsMap["project"].(string)
if !ok {
return nil, fmt.Errorf("error casting 'project' parameter: %v", paramsMap["project"])
}
instance, ok := paramsMap["instance"].(string)
if !ok {
return nil, fmt.Errorf("error casting 'instance' parameter: %v", paramsMap["instance"])
}
location, _ := paramsMap["location"].(string)
description, _ := paramsMap["backup_description"].(string)
return source.InsertBackupRun(ctx, project, instance, location, description, string(accessToken))
}
// ParseParams parses the parameters for the tool.
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
return parameters.ParseParams(t.AllParams, data, claims)
}
func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) {
return parameters.EmbedParams(ctx, t.AllParams, paramValues, embeddingModelsMap, nil)
}
// Manifest returns the tool's manifest.
func (t Tool) Manifest() tools.Manifest {
return t.manifest
}
// McpManifest returns the tool's MCP manifest.
func (t Tool) McpManifest() tools.McpManifest {
return t.mcpManifest
}
// Authorized checks if the tool is authorized.
func (t Tool) Authorized(verifiedAuthServices []string) bool {
return true
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
return false, err
}
return source.UseClientAuthorization(), nil
}
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}
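The tool delegates the actual API call to `source.InsertBackupRun(...)`. A minimal sketch of a matching source-side implementation, assuming a `getService` helper that honors client authorization; the package name and receiver fields are illustrative and not taken from this diff:

```go
// Illustrative only: a possible source-side InsertBackupRun for the
// cloud-sql-create-backup tool.
package cloudsqladmin

import (
	"context"
	"fmt"

	"google.golang.org/api/sqladmin/v1"
)

type Source struct {
	defaultProject string
	// getService returns a sqladmin.Service, optionally built from the
	// caller's OAuth access token when client authorization is enabled.
	getService func(ctx context.Context, accessToken string) (*sqladmin.Service, error)
}

// InsertBackupRun starts an on-demand backup on the given instance, with an
// optional location and description.
func (s *Source) InsertBackupRun(ctx context.Context, project, instance, location, backupDescription, accessToken string) (any, error) {
	svc, err := s.getService(ctx, accessToken)
	if err != nil {
		return nil, fmt.Errorf("error getting Cloud SQL Admin service: %w", err)
	}
	backup := &sqladmin.BackupRun{
		Location:    location,
		Description: backupDescription,
	}
	op, err := svc.BackupRuns.Insert(project, instance, backup).Context(ctx).Do()
	if err != nil {
		return nil, fmt.Errorf("failed to create backup for instance %q in project %q: %w", instance, project, err)
	}
	return op, nil
}
```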

View File

@@ -0,0 +1,72 @@
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cloudsqlcreatebackup_test
import (
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/testutils"
"github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlcreatebackup"
)
func TestParseFromYaml(t *testing.T) {
ctx, err := testutils.ContextWithNewLogger()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
tcs := []struct {
desc string
in string
want server.ToolConfigs
}{
{
desc: "basic example",
in: `
tools:
create-backup-tool:
kind: cloud-sql-create-backup
description: a test description
source: a-source
`,
want: server.ToolConfigs{
"create-backup-tool": cloudsqlcreatebackup.Config{
Name: "create-backup-tool",
Kind: "cloud-sql-create-backup",
Description: "a test description",
Source: "a-source",
AuthRequired: []string{},
},
},
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Tools server.ToolConfigs `yaml:"tools"`
}{}
// Parse contents
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
t.Fatalf("incorrect parse: diff %v", diff)
}
})
}
}

View File

@@ -18,7 +18,6 @@ import (
"context"
"fmt"
dataplexapi "cloud.google.com/go/dataplex/apiv1"
dataplexpb "cloud.google.com/go/dataplex/apiv1/dataplexpb"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
@@ -44,7 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
}
type compatibleSource interface {
CatalogClient() *dataplexapi.CatalogClient
LookupEntry(context.Context, string, int, []string, string) (*dataplexpb.Entry, error)
}
type Config struct {
@@ -118,12 +117,6 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
paramsMap := params.AsMap()
viewMap := map[int]dataplexpb.EntryView{
1: dataplexpb.EntryView_BASIC,
2: dataplexpb.EntryView_FULL,
3: dataplexpb.EntryView_CUSTOM,
4: dataplexpb.EntryView_ALL,
}
name, _ := paramsMap["name"].(string)
entry, _ := paramsMap["entry"].(string)
view, _ := paramsMap["view"].(int)
@@ -132,19 +125,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("can't convert aspectTypes to array of strings: %s", err)
}
aspectTypes := aspectTypeSlice.([]string)
req := &dataplexpb.LookupEntryRequest{
Name: name,
View: viewMap[view],
AspectTypes: aspectTypes,
Entry: entry,
}
result, err := source.CatalogClient().LookupEntry(ctx, req)
if err != nil {
return nil, err
}
return result, nil
return source.LookupEntry(ctx, name, view, aspectTypes, entry)
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -18,9 +18,7 @@ import (
"context"
"fmt"
dataplexapi "cloud.google.com/go/dataplex/apiv1"
dataplexpb "cloud.google.com/go/dataplex/apiv1/dataplexpb"
"github.com/cenkalti/backoff/v5"
"cloud.google.com/go/dataplex/apiv1/dataplexpb"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
@@ -45,8 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
}
type compatibleSource interface {
CatalogClient() *dataplexapi.CatalogClient
ProjectID() string
SearchAspectTypes(context.Context, string, int, string) ([]*dataplexpb.AspectType, error)
}
type Config struct {
@@ -101,61 +98,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
// Invoke the tool with the provided parameters
paramsMap := params.AsMap()
query, _ := paramsMap["query"].(string)
pageSize := int32(paramsMap["pageSize"].(int))
pageSize, _ := paramsMap["pageSize"].(int)
orderBy, _ := paramsMap["orderBy"].(string)
// Create SearchEntriesRequest with the provided parameters
req := &dataplexpb.SearchEntriesRequest{
Query: query + " type=projects/dataplex-types/locations/global/entryTypes/aspecttype",
Name: fmt.Sprintf("projects/%s/locations/global", source.ProjectID()),
PageSize: pageSize,
OrderBy: orderBy,
SemanticSearch: true,
}
// Perform the search using the CatalogClient - this will return an iterator
it := source.CatalogClient().SearchEntries(ctx, req)
if it == nil {
return nil, fmt.Errorf("failed to create search entries iterator for project %q", source.ProjectID())
}
// Create an instance of exponential backoff with default values for retrying GetAspectType calls
// InitialInterval, RandomizationFactor, Multiplier, MaxInterval = 500 ms, 0.5, 1.5, 60 s
getAspectBackOff := backoff.NewExponentialBackOff()
// Iterate through the search results and call GetAspectType for each result using the resource name
var results []*dataplexpb.AspectType
for {
entry, err := it.Next()
if err != nil {
break
}
resourceName := entry.DataplexEntry.GetEntrySource().Resource
getAspectTypeReq := &dataplexpb.GetAspectTypeRequest{
Name: resourceName,
}
operation := func() (*dataplexpb.AspectType, error) {
aspectType, err := source.CatalogClient().GetAspectType(ctx, getAspectTypeReq)
if err != nil {
return nil, fmt.Errorf("failed to get aspect type for entry %q: %w", resourceName, err)
}
return aspectType, nil
}
// Retry the GetAspectType operation with exponential backoff
aspectType, err := backoff.Retry(ctx, operation, backoff.WithBackOff(getAspectBackOff))
if err != nil {
return nil, fmt.Errorf("failed to get aspect type after retries for entry %q: %w", resourceName, err)
}
results = append(results, aspectType)
}
return results, nil
return source.SearchAspectTypes(ctx, query, pageSize, orderBy)
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -18,8 +18,7 @@ import (
"context"
"fmt"
dataplexapi "cloud.google.com/go/dataplex/apiv1"
dataplexpb "cloud.google.com/go/dataplex/apiv1/dataplexpb"
"cloud.google.com/go/dataplex/apiv1/dataplexpb"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
@@ -44,8 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
}
type compatibleSource interface {
CatalogClient() *dataplexapi.CatalogClient
ProjectID() string
SearchEntries(context.Context, string, int, string) ([]*dataplexpb.SearchEntriesResult, error)
}
type Config struct {
@@ -100,34 +98,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
query, _ := paramsMap["query"].(string)
pageSize := int32(paramsMap["pageSize"].(int))
pageSize, _ := paramsMap["pageSize"].(int)
orderBy, _ := paramsMap["orderBy"].(string)
req := &dataplexpb.SearchEntriesRequest{
Query: query,
Name: fmt.Sprintf("projects/%s/locations/global", source.ProjectID()),
PageSize: pageSize,
OrderBy: orderBy,
SemanticSearch: true,
}
it := source.CatalogClient().SearchEntries(ctx, req)
if it == nil {
return nil, fmt.Errorf("failed to create search entries iterator for project %q", source.ProjectID())
}
var results []*dataplexpb.SearchEntriesResult
for {
entry, err := it.Next()
if err != nil {
break
}
results = append(results, entry)
}
return results, nil
return source.SearchEntries(ctx, query, pageSize, orderBy)
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -48,6 +48,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
FirestoreClient() *firestoreapi.Client
AddDocuments(context.Context, string, any, bool) (map[string]any, error)
}
type Config struct {
@@ -134,24 +135,20 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
mapParams := params.AsMap()
// Get collection path
collectionPath, ok := mapParams[collectionPathKey].(string)
if !ok || collectionPath == "" {
return nil, fmt.Errorf("invalid or missing '%s' parameter", collectionPathKey)
}
// Validate collection path
if err := util.ValidateCollectionPath(collectionPath); err != nil {
return nil, fmt.Errorf("invalid collection path: %w", err)
}
// Get document data
documentDataRaw, ok := mapParams[documentDataKey]
if !ok {
return nil, fmt.Errorf("invalid or missing '%s' parameter", documentDataKey)
}
// Convert the document data from JSON format to Firestore format
// The client is passed to handle referenceValue types
documentData, err := util.JSONToFirestoreValue(documentDataRaw, source.FirestoreClient())
@@ -164,30 +161,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if val, ok := mapParams[returnDocumentDataKey].(bool); ok {
returnData = val
}
// Get the collection reference
collection := source.FirestoreClient().Collection(collectionPath)
// Add the document to the collection
docRef, writeResult, err := collection.Add(ctx, documentData)
if err != nil {
return nil, fmt.Errorf("failed to add document: %w", err)
}
// Build the response
response := map[string]any{
"documentPath": docRef.Path,
"createTime": writeResult.UpdateTime.Format("2006-01-02T15:04:05.999999999Z"),
}
// Add document data if requested
if returnData {
// Convert the document data back to simple JSON format
simplifiedData := util.FirestoreValueToJSON(documentData)
response["documentData"] = simplifiedData
}
return response, nil
return source.AddDocuments(ctx, collectionPath, documentData, returnData)
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -46,6 +46,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
FirestoreClient() *firestoreapi.Client
DeleteDocuments(context.Context, []string) ([]any, error)
}
type Config struct {
@@ -104,7 +105,6 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if !ok {
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected an array", documentPathsKey)
}
if len(documentPathsRaw) == 0 {
return nil, fmt.Errorf("'%s' parameter cannot be empty", documentPathsKey)
}
@@ -126,45 +126,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid document path at index %d: %w", i, err)
}
}
// Create a BulkWriter to handle multiple deletions efficiently
bulkWriter := source.FirestoreClient().BulkWriter(ctx)
// Keep track of jobs for each document
jobs := make([]*firestoreapi.BulkWriterJob, len(documentPaths))
// Add all delete operations to the BulkWriter
for i, path := range documentPaths {
docRef := source.FirestoreClient().Doc(path)
job, err := bulkWriter.Delete(docRef)
if err != nil {
return nil, fmt.Errorf("failed to add delete operation for document %q: %w", path, err)
}
jobs[i] = job
}
// End the BulkWriter to execute all operations
bulkWriter.End()
// Collect results
results := make([]any, len(documentPaths))
for i, job := range jobs {
docData := make(map[string]any)
docData["path"] = documentPaths[i]
// Wait for the job to complete and get the result
_, err := job.Results()
if err != nil {
docData["success"] = false
docData["error"] = err.Error()
} else {
docData["success"] = true
}
results[i] = docData
}
return results, nil
return source.DeleteDocuments(ctx, documentPaths)
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
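The BulkWriter orchestration removed above moves behind `source.DeleteDocuments(ctx, documentPaths)`. A sketch of that method under an assumed `Source` receiver holding the Firestore client; the batching and per-document result reporting mirror the deleted tool code:

```go
// Illustrative only: the Source receiver and its client field are assumptions.
package firestore

import (
	"context"
	"fmt"

	firestoreapi "cloud.google.com/go/firestore"
)

type Source struct {
	client *firestoreapi.Client
}

// DeleteDocuments deletes the given document paths in one BulkWriter batch and
// reports per-document success or failure.
func (s *Source) DeleteDocuments(ctx context.Context, documentPaths []string) ([]any, error) {
	bulkWriter := s.client.BulkWriter(ctx)
	jobs := make([]*firestoreapi.BulkWriterJob, len(documentPaths))
	for i, path := range documentPaths {
		docRef := s.client.Doc(path)
		job, err := bulkWriter.Delete(docRef)
		if err != nil {
			return nil, fmt.Errorf("failed to add delete operation for document %q: %w", path, err)
		}
		jobs[i] = job
	}
	// Flush all queued deletes.
	bulkWriter.End()
	results := make([]any, len(documentPaths))
	for i, job := range jobs {
		docData := map[string]any{"path": documentPaths[i]}
		if _, err := job.Results(); err != nil {
			docData["success"] = false
			docData["error"] = err.Error()
		} else {
			docData["success"] = true
		}
		results[i] = docData
	}
	return results, nil
}
```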

View File

@@ -46,6 +46,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
FirestoreClient() *firestoreapi.Client
GetDocuments(context.Context, []string) ([]any, error)
}
type Config struct {
@@ -126,37 +127,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid document path at index %d: %w", i, err)
}
}
// Create document references from paths
docRefs := make([]*firestoreapi.DocumentRef, len(documentPaths))
for i, path := range documentPaths {
docRefs[i] = source.FirestoreClient().Doc(path)
}
// Get all documents
snapshots, err := source.FirestoreClient().GetAll(ctx, docRefs)
if err != nil {
return nil, fmt.Errorf("failed to get documents: %w", err)
}
// Convert snapshots to response data
results := make([]any, len(snapshots))
for i, snapshot := range snapshots {
docData := make(map[string]any)
docData["path"] = documentPaths[i]
docData["exists"] = snapshot.Exists()
if snapshot.Exists() {
docData["data"] = snapshot.Data()
docData["createTime"] = snapshot.CreateTime
docData["updateTime"] = snapshot.UpdateTime
docData["readTime"] = snapshot.ReadTime
}
results[i] = docData
}
return results, nil
return source.GetDocuments(ctx, documentPaths)
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
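Likewise, a GetDocuments sketch on the same hypothetical Source:

// GetDocuments fetches all requested documents in one call and reports, for
// each path, whether it exists along with its data and timestamps.
func (s *Source) GetDocuments(ctx context.Context, documentPaths []string) ([]any, error) {
	docRefs := make([]*firestoreapi.DocumentRef, len(documentPaths))
	for i, path := range documentPaths {
		docRefs[i] = s.client.Doc(path)
	}
	snapshots, err := s.client.GetAll(ctx, docRefs)
	if err != nil {
		return nil, fmt.Errorf("failed to get documents: %w", err)
	}
	results := make([]any, len(snapshots))
	for i, snapshot := range snapshots {
		docData := map[string]any{
			"path":   documentPaths[i],
			"exists": snapshot.Exists(),
		}
		if snapshot.Exists() {
			docData["data"] = snapshot.Data()
			docData["createTime"] = snapshot.CreateTime
			docData["updateTime"] = snapshot.UpdateTime
			docData["readTime"] = snapshot.ReadTime
		}
		results[i] = docData
	}
	return results, nil
}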

View File

@@ -44,8 +44,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
FirebaseRulesClient() *firebaserules.Service
GetProjectId() string
GetDatabaseId() string
GetRules(context.Context) (any, error)
}
type Config struct {
@@ -98,29 +97,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
// Get the latest release for Firestore
releaseName := fmt.Sprintf("projects/%s/releases/cloud.firestore/%s", source.GetProjectId(), source.GetDatabaseId())
release, err := source.FirebaseRulesClient().Projects.Releases.Get(releaseName).Context(ctx).Do()
if err != nil {
return nil, fmt.Errorf("failed to get latest Firestore release: %w", err)
}
if release.RulesetName == "" {
return nil, fmt.Errorf("no active Firestore rules were found in project '%s' and database '%s'", source.GetProjectId(), source.GetDatabaseId())
}
// Get the ruleset content
ruleset, err := source.FirebaseRulesClient().Projects.Rulesets.Get(release.RulesetName).Context(ctx).Do()
if err != nil {
return nil, fmt.Errorf("failed to get ruleset content: %w", err)
}
if ruleset.Source == nil || len(ruleset.Source.Files) == 0 {
return nil, fmt.Errorf("no rules files found in ruleset")
}
return ruleset, nil
return source.GetRules(ctx)
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
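The release/ruleset lookup removed here moves behind the source's GetRules method. A sketch under assumed type and field names:

// Sketch of a source-side GetRules; RulesSource is hypothetical.
package firestoreds

import (
	"context"
	"fmt"

	"google.golang.org/api/firebaserules/v1"
)

type RulesSource struct {
	rulesClient *firebaserules.Service
	projectID   string
	databaseID  string
}

func (s *RulesSource) GetRules(ctx context.Context) (any, error) {
	// Look up the latest Firestore release for this project/database pair.
	releaseName := fmt.Sprintf("projects/%s/releases/cloud.firestore/%s", s.projectID, s.databaseID)
	release, err := s.rulesClient.Projects.Releases.Get(releaseName).Context(ctx).Do()
	if err != nil {
		return nil, fmt.Errorf("failed to get latest Firestore release: %w", err)
	}
	if release.RulesetName == "" {
		return nil, fmt.Errorf("no active Firestore rules were found in project %q and database %q", s.projectID, s.databaseID)
	}
	// Fetch the ruleset content that the release points at.
	ruleset, err := s.rulesClient.Projects.Rulesets.Get(release.RulesetName).Context(ctx).Do()
	if err != nil {
		return nil, fmt.Errorf("failed to get ruleset content: %w", err)
	}
	if ruleset.Source == nil || len(ruleset.Source.Files) == 0 {
		return nil, fmt.Errorf("no rules files found in ruleset")
	}
	return ruleset, nil
}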

View File

@@ -46,6 +46,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
FirestoreClient() *firestoreapi.Client
ListCollections(context.Context, string) ([]any, error)
}
type Config struct {
@@ -102,47 +103,15 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
mapParams := params.AsMap()
var collectionRefs []*firestoreapi.CollectionRef
// Check if parentPath is provided
parentPath, hasParent := mapParams[parentPathKey].(string)
if hasParent && parentPath != "" {
parentPath, _ := mapParams[parentPathKey].(string)
if parentPath != "" {
// Validate parent document path
if err := util.ValidateDocumentPath(parentPath); err != nil {
return nil, fmt.Errorf("invalid parent document path: %w", err)
}
// List subcollections of the specified document
docRef := source.FirestoreClient().Doc(parentPath)
collectionRefs, err = docRef.Collections(ctx).GetAll()
if err != nil {
return nil, fmt.Errorf("failed to list subcollections of document %q: %w", parentPath, err)
}
} else {
// List root collections
collectionRefs, err = source.FirestoreClient().Collections(ctx).GetAll()
if err != nil {
return nil, fmt.Errorf("failed to list root collections: %w", err)
}
}
// Convert collection references to response data
results := make([]any, len(collectionRefs))
for i, collRef := range collectionRefs {
collData := make(map[string]any)
collData["id"] = collRef.ID
collData["path"] = collRef.Path
// If this is a subcollection, include parent information
if collRef.Parent != nil {
collData["parent"] = collRef.Parent.Path
}
results[i] = collData
}
return results, nil
return source.ListCollections(ctx, parentPath)
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
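A ListCollections sketch on the same hypothetical Firestore Source as above; an empty parentPath means root collections:

// ListCollections lists either the root collections (empty parentPath) or the
// subcollections of the given parent document.
func (s *Source) ListCollections(ctx context.Context, parentPath string) ([]any, error) {
	var collectionRefs []*firestoreapi.CollectionRef
	var err error
	if parentPath != "" {
		collectionRefs, err = s.client.Doc(parentPath).Collections(ctx).GetAll()
		if err != nil {
			return nil, fmt.Errorf("failed to list subcollections of document %q: %w", parentPath, err)
		}
	} else {
		collectionRefs, err = s.client.Collections(ctx).GetAll()
		if err != nil {
			return nil, fmt.Errorf("failed to list root collections: %w", err)
		}
	}
	results := make([]any, len(collectionRefs))
	for i, collRef := range collectionRefs {
		collData := map[string]any{"id": collRef.ID, "path": collRef.Path}
		if collRef.Parent != nil {
			// Present only for subcollections.
			collData["parent"] = collRef.Parent.Path
		}
		results[i] = collData
	}
	return results, nil
}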

View File

@@ -36,27 +36,6 @@ const (
defaultLimit = 100
)
// Firestore operators
var validOperators = map[string]bool{
"<": true,
"<=": true,
">": true,
">=": true,
"==": true,
"!=": true,
"array-contains": true,
"array-contains-any": true,
"in": true,
"not-in": true,
}
// Error messages
const (
errFilterParseFailed = "failed to parse filters: %w"
errQueryExecutionFailed = "failed to execute query: %w"
errLimitParseFailed = "failed to parse limit value '%s': %w"
)
func init() {
if !tools.Register(kind, newConfig) {
panic(fmt.Sprintf("tool kind %q already registered", kind))
@@ -74,6 +53,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
// compatibleSource defines the interface for sources that can provide a Firestore client
type compatibleSource interface {
FirestoreClient() *firestoreapi.Client
BuildQuery(string, firestoreapi.EntityFilter, []string, string, firestoreapi.Direction, int, bool) (*firestoreapi.Query, error)
ExecuteQuery(context.Context, *firestoreapi.Query, bool) (any, error)
}
// Config represents the configuration for the Firestore query tool
@@ -139,15 +120,6 @@ func (t Tool) ToConfig() tools.ToolConfig {
return t.Config
}
// SimplifiedFilter represents the simplified filter format
type SimplifiedFilter struct {
And []SimplifiedFilter `json:"and,omitempty"`
Or []SimplifiedFilter `json:"or,omitempty"`
Field string `json:"field,omitempty"`
Op string `json:"op,omitempty"`
Value interface{} `json:"value,omitempty"`
}
// OrderByConfig represents ordering configuration
type OrderByConfig struct {
Field string `json:"field"`
@@ -162,20 +134,27 @@ func (o *OrderByConfig) GetDirection() firestoreapi.Direction {
return firestoreapi.Asc
}
// QueryResult represents a document result from the query
type QueryResult struct {
ID string `json:"id"`
Path string `json:"path"`
Data map[string]any `json:"data"`
CreateTime interface{} `json:"createTime,omitempty"`
UpdateTime interface{} `json:"updateTime,omitempty"`
ReadTime interface{} `json:"readTime,omitempty"`
// SimplifiedFilter represents the simplified filter format
type SimplifiedFilter struct {
And []SimplifiedFilter `json:"and,omitempty"`
Or []SimplifiedFilter `json:"or,omitempty"`
Field string `json:"field,omitempty"`
Op string `json:"op,omitempty"`
Value interface{} `json:"value,omitempty"`
}
// QueryResponse represents the full response including optional metrics
type QueryResponse struct {
Documents []QueryResult `json:"documents"`
ExplainMetrics map[string]any `json:"explainMetrics,omitempty"`
// Firestore operators
var validOperators = map[string]bool{
"<": true,
"<=": true,
">": true,
">=": true,
"==": true,
"!=": true,
"array-contains": true,
"array-contains-any": true,
"in": true,
"not-in": true,
}
// Invoke executes the Firestore query based on the provided parameters
@@ -184,34 +163,18 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
paramsMap := params.AsMap()
// Process collection path with template substitution
collectionPath, err := parameters.PopulateTemplate("collectionPath", t.CollectionPath, paramsMap)
if err != nil {
return nil, fmt.Errorf("failed to process collection path: %w", err)
}
// Build the query
query, err := t.buildQuery(source, collectionPath, paramsMap)
if err != nil {
return nil, err
}
// Execute the query and return results
return t.executeQuery(ctx, query)
}
// buildQuery constructs the Firestore query from parameters
func (t Tool) buildQuery(source compatibleSource, collectionPath string, params map[string]any) (*firestoreapi.Query, error) {
collection := source.FirestoreClient().Collection(collectionPath)
query := collection.Query
var filter firestoreapi.EntityFilter
// Process and apply filters if template is provided
if t.Filters != "" {
// Apply template substitution to filters
filtersJSON, err := parameters.PopulateTemplateWithJSON("filters", t.Filters, params)
filtersJSON, err := parameters.PopulateTemplateWithJSON("filters", t.Filters, paramsMap)
if err != nil {
return nil, fmt.Errorf("failed to process filters template: %w", err)
}
@@ -219,48 +182,43 @@ func (t Tool) buildQuery(source compatibleSource, collectionPath string, params
// Parse the simplified filter format
var simplifiedFilter SimplifiedFilter
if err := json.Unmarshal([]byte(filtersJSON), &simplifiedFilter); err != nil {
return nil, fmt.Errorf(errFilterParseFailed, err)
return nil, fmt.Errorf("failed to parse filters: %w", err)
}
// Convert simplified filter to Firestore filter
if filter := t.convertToFirestoreFilter(source, simplifiedFilter); filter != nil {
query = query.WhereEntity(filter)
}
filter = t.convertToFirestoreFilter(source, simplifiedFilter)
}
// Process select fields
selectFields, err := t.processSelectFields(params)
if err != nil {
return nil, err
}
if len(selectFields) > 0 {
query = query.Select(selectFields...)
}
// Process and apply ordering
orderBy, err := t.getOrderBy(params)
orderBy, err := t.getOrderBy(paramsMap)
if err != nil {
return nil, err
}
if orderBy != nil {
query = query.OrderBy(orderBy.Field, orderBy.GetDirection())
// Process select fields
selectFields, err := t.processSelectFields(paramsMap)
if err != nil {
return nil, err
}
// Process and apply limit
limit, err := t.getLimit(params)
limit, err := t.getLimit(paramsMap)
if err != nil {
return nil, err
}
query = query.Limit(limit)
// Apply analyze options if enabled
if t.AnalyzeQuery {
query = query.WithRunOptions(firestoreapi.ExplainOptions{
Analyze: true,
})
// Prevent a panic when accessing orderBy in case it is nil
var orderByField string
var orderByDirection firestoreapi.Direction
if orderBy != nil {
orderByField = orderBy.Field
orderByDirection = orderBy.GetDirection()
}
return &query, nil
// Build the query
query, err := source.BuildQuery(collectionPath, filter, selectFields, orderByField, orderByDirection, limit, t.AnalyzeQuery)
if err != nil {
return nil, err
}
// Execute the query and return results
return source.ExecuteQuery(ctx, query, t.AnalyzeQuery)
}
// convertToFirestoreFilter converts simplified filter format to Firestore EntityFilter
@@ -409,7 +367,7 @@ func (t Tool) getLimit(params map[string]any) (int, error) {
if processedValue != "" {
parsedLimit, err := strconv.Atoi(processedValue)
if err != nil {
return 0, fmt.Errorf(errLimitParseFailed, processedValue, err)
return 0, fmt.Errorf("failed to parse limit value '%s': %w", processedValue, err)
}
limit = parsedLimit
}
@@ -417,78 +375,6 @@ func (t Tool) getLimit(params map[string]any) (int, error) {
return limit, nil
}
// executeQuery runs the query and formats the results
func (t Tool) executeQuery(ctx context.Context, query *firestoreapi.Query) (any, error) {
docIterator := query.Documents(ctx)
docs, err := docIterator.GetAll()
if err != nil {
return nil, fmt.Errorf(errQueryExecutionFailed, err)
}
// Convert results to structured format
results := make([]QueryResult, len(docs))
for i, doc := range docs {
results[i] = QueryResult{
ID: doc.Ref.ID,
Path: doc.Ref.Path,
Data: doc.Data(),
CreateTime: doc.CreateTime,
UpdateTime: doc.UpdateTime,
ReadTime: doc.ReadTime,
}
}
// Return with explain metrics if requested
if t.AnalyzeQuery {
explainMetrics, err := t.getExplainMetrics(docIterator)
if err == nil && explainMetrics != nil {
response := QueryResponse{
Documents: results,
ExplainMetrics: explainMetrics,
}
return response, nil
}
}
return results, nil
}
// getExplainMetrics extracts explain metrics from the query iterator
func (t Tool) getExplainMetrics(docIterator *firestoreapi.DocumentIterator) (map[string]any, error) {
explainMetrics, err := docIterator.ExplainMetrics()
if err != nil || explainMetrics == nil {
return nil, err
}
metricsData := make(map[string]any)
// Add plan summary if available
if explainMetrics.PlanSummary != nil {
planSummary := make(map[string]any)
planSummary["indexesUsed"] = explainMetrics.PlanSummary.IndexesUsed
metricsData["planSummary"] = planSummary
}
// Add execution stats if available
if explainMetrics.ExecutionStats != nil {
executionStats := make(map[string]any)
executionStats["resultsReturned"] = explainMetrics.ExecutionStats.ResultsReturned
executionStats["readOperations"] = explainMetrics.ExecutionStats.ReadOperations
if explainMetrics.ExecutionStats.ExecutionDuration != nil {
executionStats["executionDuration"] = explainMetrics.ExecutionStats.ExecutionDuration.String()
}
if explainMetrics.ExecutionStats.DebugStats != nil {
executionStats["debugStats"] = *explainMetrics.ExecutionStats.DebugStats
}
metricsData["executionStats"] = executionStats
}
return metricsData, nil
}
// ParseParams parses and validates input parameters
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
return parameters.ParseParams(t.Parameters, data, claims)
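The query tool now delegates to the source's BuildQuery and ExecuteQuery. A sketch of that pair on the hypothetical Source from the earlier Firestore sketches; the real source may shape its response differently:

// BuildQuery assembles a Firestore query from the pieces the tools now pass in.
func (s *Source) BuildQuery(collectionPath string, filter firestoreapi.EntityFilter, selectFields []string, orderByField string, orderByDirection firestoreapi.Direction, limit int, analyze bool) (*firestoreapi.Query, error) {
	query := s.client.Collection(collectionPath).Query
	if filter != nil {
		query = query.WhereEntity(filter)
	}
	if len(selectFields) > 0 {
		query = query.Select(selectFields...)
	}
	if orderByField != "" {
		query = query.OrderBy(orderByField, orderByDirection)
	}
	query = query.Limit(limit)
	if analyze {
		query = query.WithRunOptions(firestoreapi.ExplainOptions{Analyze: true})
	}
	return &query, nil
}

// ExecuteQuery runs the query and returns the matching documents, optionally
// attaching explain metrics when analyze is true.
func (s *Source) ExecuteQuery(ctx context.Context, query *firestoreapi.Query, analyze bool) (any, error) {
	docIterator := query.Documents(ctx)
	docs, err := docIterator.GetAll()
	if err != nil {
		return nil, fmt.Errorf("failed to execute query: %w", err)
	}
	results := make([]any, len(docs))
	for i, doc := range docs {
		results[i] = map[string]any{
			"id":         doc.Ref.ID,
			"path":       doc.Ref.Path,
			"data":       doc.Data(),
			"createTime": doc.CreateTime,
			"updateTime": doc.UpdateTime,
			"readTime":   doc.ReadTime,
		}
	}
	if analyze {
		// Explain metrics are only available after the iterator is exhausted.
		if metrics, err := docIterator.ExplainMetrics(); err == nil && metrics != nil {
			return map[string]any{"documents": results, "explainMetrics": metrics}, nil
		}
	}
	return results, nil
}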

View File

@@ -69,7 +69,6 @@ const (
errInvalidOperator = "unsupported operator: %s. Valid operators are: %v"
errMissingFilterValue = "no value specified for filter on field '%s'"
errOrderByParseFailed = "failed to parse orderBy: %w"
errQueryExecutionFailed = "failed to execute query: %w"
errTooManyFilters = "too many filters provided: %d (maximum: %d)"
)
@@ -90,6 +89,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
// compatibleSource defines the interface for sources that can provide a Firestore client
type compatibleSource interface {
FirestoreClient() *firestoreapi.Client
BuildQuery(string, firestoreapi.EntityFilter, []string, string, firestoreapi.Direction, int, bool) (*firestoreapi.Query, error)
ExecuteQuery(context.Context, *firestoreapi.Query, bool) (any, error)
}
// Config represents the configuration for the Firestore query collection tool
@@ -228,22 +229,6 @@ func (o *OrderByConfig) GetDirection() firestoreapi.Direction {
return firestoreapi.Asc
}
// QueryResult represents a document result from the query
type QueryResult struct {
ID string `json:"id"`
Path string `json:"path"`
Data map[string]any `json:"data"`
CreateTime interface{} `json:"createTime,omitempty"`
UpdateTime interface{} `json:"updateTime,omitempty"`
ReadTime interface{} `json:"readTime,omitempty"`
}
// QueryResponse represents the full response including optional metrics
type QueryResponse struct {
Documents []QueryResult `json:"documents"`
ExplainMetrics map[string]any `json:"explainMetrics,omitempty"`
}
// Invoke executes the Firestore query based on the provided parameters
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
@@ -257,14 +242,37 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, err
}
var filter firestoreapi.EntityFilter
// Apply filters
if len(queryParams.Filters) > 0 {
filterConditions := make([]firestoreapi.EntityFilter, 0, len(queryParams.Filters))
for _, filter := range queryParams.Filters {
filterConditions = append(filterConditions, firestoreapi.PropertyFilter{
Path: filter.Field,
Operator: filter.Op,
Value: filter.Value,
})
}
filter = firestoreapi.AndFilter{
Filters: filterConditions,
}
}
// Prevent a panic in case queryParams.OrderBy is nil
var orderByField string
var orderByDirection firestoreapi.Direction
if queryParams.OrderBy != nil {
orderByField = queryParams.OrderBy.Field
orderByDirection = queryParams.OrderBy.GetDirection()
}
// Build the query
query, err := t.buildQuery(source, queryParams)
query, err := source.BuildQuery(queryParams.CollectionPath, filter, nil, orderByField, orderByDirection, queryParams.Limit, queryParams.AnalyzeQuery)
if err != nil {
return nil, err
}
// Execute the query and return results
return t.executeQuery(ctx, query, queryParams.AnalyzeQuery)
return source.ExecuteQuery(ctx, query, queryParams.AnalyzeQuery)
}
// queryParameters holds all parsed query parameters
@@ -380,122 +388,6 @@ func (t Tool) parseOrderBy(orderByRaw interface{}) (*OrderByConfig, error) {
return &orderBy, nil
}
// buildQuery constructs the Firestore query from parameters
func (t Tool) buildQuery(source compatibleSource, params *queryParameters) (*firestoreapi.Query, error) {
collection := source.FirestoreClient().Collection(params.CollectionPath)
query := collection.Query
// Apply filters
if len(params.Filters) > 0 {
filterConditions := make([]firestoreapi.EntityFilter, 0, len(params.Filters))
for _, filter := range params.Filters {
filterConditions = append(filterConditions, firestoreapi.PropertyFilter{
Path: filter.Field,
Operator: filter.Op,
Value: filter.Value,
})
}
query = query.WhereEntity(firestoreapi.AndFilter{
Filters: filterConditions,
})
}
// Apply ordering
if params.OrderBy != nil {
query = query.OrderBy(params.OrderBy.Field, params.OrderBy.GetDirection())
}
// Apply limit
query = query.Limit(params.Limit)
// Apply analyze options
if params.AnalyzeQuery {
query = query.WithRunOptions(firestoreapi.ExplainOptions{
Analyze: true,
})
}
return &query, nil
}
// executeQuery runs the query and formats the results
func (t Tool) executeQuery(ctx context.Context, query *firestoreapi.Query, analyzeQuery bool) (any, error) {
docIterator := query.Documents(ctx)
docs, err := docIterator.GetAll()
if err != nil {
return nil, fmt.Errorf(errQueryExecutionFailed, err)
}
// Convert results to structured format
results := make([]QueryResult, len(docs))
for i, doc := range docs {
results[i] = QueryResult{
ID: doc.Ref.ID,
Path: doc.Ref.Path,
Data: doc.Data(),
CreateTime: doc.CreateTime,
UpdateTime: doc.UpdateTime,
ReadTime: doc.ReadTime,
}
}
// Return with explain metrics if requested
if analyzeQuery {
explainMetrics, err := t.getExplainMetrics(docIterator)
if err == nil && explainMetrics != nil {
response := QueryResponse{
Documents: results,
ExplainMetrics: explainMetrics,
}
return response, nil
}
}
// Return just the documents
resultsAny := make([]any, len(results))
for i, r := range results {
resultsAny[i] = r
}
return resultsAny, nil
}
// getExplainMetrics extracts explain metrics from the query iterator
func (t Tool) getExplainMetrics(docIterator *firestoreapi.DocumentIterator) (map[string]any, error) {
explainMetrics, err := docIterator.ExplainMetrics()
if err != nil || explainMetrics == nil {
return nil, err
}
metricsData := make(map[string]any)
// Add plan summary if available
if explainMetrics.PlanSummary != nil {
planSummary := make(map[string]any)
planSummary["indexesUsed"] = explainMetrics.PlanSummary.IndexesUsed
metricsData["planSummary"] = planSummary
}
// Add execution stats if available
if explainMetrics.ExecutionStats != nil {
executionStats := make(map[string]any)
executionStats["resultsReturned"] = explainMetrics.ExecutionStats.ResultsReturned
executionStats["readOperations"] = explainMetrics.ExecutionStats.ReadOperations
if explainMetrics.ExecutionStats.ExecutionDuration != nil {
executionStats["executionDuration"] = explainMetrics.ExecutionStats.ExecutionDuration.String()
}
if explainMetrics.ExecutionStats.DebugStats != nil {
executionStats["debugStats"] = *explainMetrics.ExecutionStats.DebugStats
}
metricsData["executionStats"] = executionStats
}
return metricsData, nil
}
// ParseParams parses and validates input parameters
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
return parameters.ParseParams(t.Parameters, data, claims)

View File

@@ -50,6 +50,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
FirestoreClient() *firestoreapi.Client
UpdateDocument(context.Context, string, []firestoreapi.Update, any, bool) (map[string]any, error)
}
type Config struct {
@@ -177,23 +178,10 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
}
}
// Get return document data flag
returnData := false
if val, ok := mapParams[returnDocumentDataKey].(bool); ok {
returnData = val
}
// Get the document reference
docRef := source.FirestoreClient().Doc(documentPath)
// Prepare update data
var writeResult *firestoreapi.WriteResult
var writeErr error
// Use selective field update with update mask
updates := make([]firestoreapi.Update, 0, len(updatePaths))
var documentData any
if len(updatePaths) > 0 {
// Use selective field update with update mask
updates := make([]firestoreapi.Update, 0, len(updatePaths))
// Convert document data without delete markers
dataMap, err := util.JSONToFirestoreValue(documentDataRaw, source.FirestoreClient())
@@ -220,41 +208,20 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
Value: value,
})
}
writeResult, writeErr = docRef.Update(ctx, updates)
} else {
// Update all fields in the document data (merge)
documentData, err := util.JSONToFirestoreValue(documentDataRaw, source.FirestoreClient())
documentData, err = util.JSONToFirestoreValue(documentDataRaw, source.FirestoreClient())
if err != nil {
return nil, fmt.Errorf("failed to convert document data: %w", err)
}
writeResult, writeErr = docRef.Set(ctx, documentData, firestoreapi.MergeAll)
}
if writeErr != nil {
return nil, fmt.Errorf("failed to update document: %w", writeErr)
// Get return document data flag
returnData := false
if val, ok := mapParams[returnDocumentDataKey].(bool); ok {
returnData = val
}
// Build the response
response := map[string]any{
"documentPath": docRef.Path,
"updateTime": writeResult.UpdateTime.Format("2006-01-02T15:04:05.999999999Z"),
}
// Add document data if requested
if returnData {
// Fetch the updated document to return the current state
snapshot, err := docRef.Get(ctx)
if err != nil {
return nil, fmt.Errorf("failed to retrieve updated document: %w", err)
}
// Convert the document data to simple JSON format
simplifiedData := util.FirestoreValueToJSON(snapshot.Data())
response["documentData"] = simplifiedData
}
return response, nil
return source.UpdateDocument(ctx, documentPath, updates, documentData, returnData)
}
// getFieldValue retrieves a value from a nested map using a dot-separated path

View File

@@ -17,7 +17,6 @@ package firestorevalidaterules
import (
"context"
"fmt"
"strings"
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
@@ -50,7 +49,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
FirebaseRulesClient() *firebaserules.Service
GetProjectId() string
ValidateRules(context.Context, string) (any, error)
}
type Config struct {
@@ -107,30 +106,6 @@ func (t Tool) ToConfig() tools.ToolConfig {
return t.Config
}
// Issue represents a validation issue in the rules
type Issue struct {
SourcePosition SourcePosition `json:"sourcePosition"`
Description string `json:"description"`
Severity string `json:"severity"`
}
// SourcePosition represents the location of an issue in the source
type SourcePosition struct {
FileName string `json:"fileName,omitempty"`
Line int64 `json:"line"` // 1-based
Column int64 `json:"column"` // 1-based
CurrentOffset int64 `json:"currentOffset"` // 0-based, inclusive start
EndOffset int64 `json:"endOffset"` // 0-based, exclusive end
}
// ValidationResult represents the result of rules validation
type ValidationResult struct {
Valid bool `json:"valid"`
IssueCount int `json:"issueCount"`
FormattedIssues string `json:"formattedIssues,omitempty"`
RawIssues []Issue `json:"rawIssues,omitempty"`
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
if err != nil {
@@ -144,114 +119,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if !ok || sourceParam == "" {
return nil, fmt.Errorf("invalid or missing '%s' parameter", sourceKey)
}
// Create test request
testRequest := &firebaserules.TestRulesetRequest{
Source: &firebaserules.Source{
Files: []*firebaserules.File{
{
Name: "firestore.rules",
Content: sourceParam,
},
},
},
// We don't need test cases for validation only
TestSuite: &firebaserules.TestSuite{
TestCases: []*firebaserules.TestCase{},
},
}
// Call the test API
projectName := fmt.Sprintf("projects/%s", source.GetProjectId())
response, err := source.FirebaseRulesClient().Projects.Test(projectName, testRequest).Context(ctx).Do()
if err != nil {
return nil, fmt.Errorf("failed to validate rules: %w", err)
}
// Process the response
result := t.processValidationResponse(response, sourceParam)
return result, nil
}
func (t Tool) processValidationResponse(response *firebaserules.TestRulesetResponse, source string) ValidationResult {
if len(response.Issues) == 0 {
return ValidationResult{
Valid: true,
IssueCount: 0,
FormattedIssues: "✓ No errors detected. Rules are valid.",
}
}
// Convert issues to our format
issues := make([]Issue, len(response.Issues))
for i, issue := range response.Issues {
issues[i] = Issue{
Description: issue.Description,
Severity: issue.Severity,
SourcePosition: SourcePosition{
FileName: issue.SourcePosition.FileName,
Line: issue.SourcePosition.Line,
Column: issue.SourcePosition.Column,
CurrentOffset: issue.SourcePosition.CurrentOffset,
EndOffset: issue.SourcePosition.EndOffset,
},
}
}
// Format issues
formattedIssues := t.formatRulesetIssues(issues, source)
return ValidationResult{
Valid: false,
IssueCount: len(issues),
FormattedIssues: formattedIssues,
RawIssues: issues,
}
}
// formatRulesetIssues formats validation issues into a human-readable string with code snippets
func (t Tool) formatRulesetIssues(issues []Issue, rulesSource string) string {
sourceLines := strings.Split(rulesSource, "\n")
var formattedOutput []string
formattedOutput = append(formattedOutput, fmt.Sprintf("Found %d issue(s) in rules source:\n", len(issues)))
for _, issue := range issues {
issueString := fmt.Sprintf("%s: %s [Ln %d, Col %d]",
issue.Severity,
issue.Description,
issue.SourcePosition.Line,
issue.SourcePosition.Column)
if issue.SourcePosition.Line > 0 {
lineIndex := int(issue.SourcePosition.Line - 1) // 0-based index
if lineIndex >= 0 && lineIndex < len(sourceLines) {
errorLine := sourceLines[lineIndex]
issueString += fmt.Sprintf("\n```\n%s", errorLine)
// Add carets if we have column and offset information
if issue.SourcePosition.Column > 0 &&
issue.SourcePosition.CurrentOffset >= 0 &&
issue.SourcePosition.EndOffset > issue.SourcePosition.CurrentOffset {
startColumn := int(issue.SourcePosition.Column - 1) // 0-based
errorTokenLength := int(issue.SourcePosition.EndOffset - issue.SourcePosition.CurrentOffset)
if startColumn >= 0 && errorTokenLength > 0 && startColumn <= len(errorLine) {
padding := strings.Repeat(" ", startColumn)
carets := strings.Repeat("^", errorTokenLength)
issueString += fmt.Sprintf("\n%s%s", padding, carets)
}
}
issueString += "\n```"
}
}
formattedOutput = append(formattedOutput, issueString)
}
return strings.Join(formattedOutput, "\n\n")
return source.ValidateRules(ctx, sourceParam)
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
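A ValidateRules sketch on the hypothetical RulesSource from the GetRules sketch above; it returns the raw issues rather than the formatted report the old tool built:

// ValidateRules submits the rules source to the Firebase Rules test API and
// reports any issues found.
func (s *RulesSource) ValidateRules(ctx context.Context, rulesSource string) (any, error) {
	testRequest := &firebaserules.TestRulesetRequest{
		Source: &firebaserules.Source{
			Files: []*firebaserules.File{
				{Name: "firestore.rules", Content: rulesSource},
			},
		},
		// No test cases are needed for a validation-only request.
		TestSuite: &firebaserules.TestSuite{TestCases: []*firebaserules.TestCase{}},
	}
	projectName := fmt.Sprintf("projects/%s", s.projectID)
	response, err := s.rulesClient.Projects.Test(projectName, testRequest).Context(ctx).Do()
	if err != nil {
		return nil, fmt.Errorf("failed to validate rules: %w", err)
	}
	return map[string]any{
		"valid":      len(response.Issues) == 0,
		"issueCount": len(response.Issues),
		"rawIssues":  response.Issues,
	}, nil
}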

View File

@@ -28,13 +28,13 @@ import (
// JSONToFirestoreValue converts a JSON value with type information to a Firestore-compatible value
// The input should be a map with a single key indicating the type (e.g., "stringValue", "integerValue")
// If a client is provided, referenceValue types will be converted to *firestore.DocumentRef
func JSONToFirestoreValue(value interface{}, client *firestore.Client) (interface{}, error) {
func JSONToFirestoreValue(value any, client *firestore.Client) (any, error) {
if value == nil {
return nil, nil
}
switch v := value.(type) {
case map[string]interface{}:
case map[string]any:
// Check for typed values
if len(v) == 1 {
for key, val := range v {
@@ -92,7 +92,7 @@ func JSONToFirestoreValue(value interface{}, client *firestore.Client) (interfac
return nil, fmt.Errorf("timestamp value must be a string")
case "geoPointValue":
// Convert to LatLng
if geoMap, ok := val.(map[string]interface{}); ok {
if geoMap, ok := val.(map[string]any); ok {
lat, latOk := geoMap["latitude"].(float64)
lng, lngOk := geoMap["longitude"].(float64)
if latOk && lngOk {
@@ -105,9 +105,9 @@ func JSONToFirestoreValue(value interface{}, client *firestore.Client) (interfac
return nil, fmt.Errorf("invalid geopoint value format")
case "arrayValue":
// Convert array
if arrayMap, ok := val.(map[string]interface{}); ok {
if values, ok := arrayMap["values"].([]interface{}); ok {
result := make([]interface{}, len(values))
if arrayMap, ok := val.(map[string]any); ok {
if values, ok := arrayMap["values"].([]any); ok {
result := make([]any, len(values))
for i, item := range values {
converted, err := JSONToFirestoreValue(item, client)
if err != nil {
@@ -121,9 +121,9 @@ func JSONToFirestoreValue(value interface{}, client *firestore.Client) (interfac
return nil, fmt.Errorf("invalid array value format")
case "mapValue":
// Convert map
if mapMap, ok := val.(map[string]interface{}); ok {
if fields, ok := mapMap["fields"].(map[string]interface{}); ok {
result := make(map[string]interface{})
if mapMap, ok := val.(map[string]any); ok {
if fields, ok := mapMap["fields"].(map[string]any); ok {
result := make(map[string]any)
for k, v := range fields {
converted, err := JSONToFirestoreValue(v, client)
if err != nil {
@@ -160,8 +160,8 @@ func JSONToFirestoreValue(value interface{}, client *firestore.Client) (interfac
}
// convertPlainMap converts a plain map to Firestore format
func convertPlainMap(m map[string]interface{}, client *firestore.Client) (map[string]interface{}, error) {
result := make(map[string]interface{})
func convertPlainMap(m map[string]any, client *firestore.Client) (map[string]any, error) {
result := make(map[string]any)
for k, v := range m {
converted, err := JSONToFirestoreValue(v, client)
if err != nil {
@@ -172,42 +172,6 @@ func convertPlainMap(m map[string]interface{}, client *firestore.Client) (map[st
return result, nil
}
// FirestoreValueToJSON converts a Firestore value to a simplified JSON representation
// This removes type information and returns plain values
func FirestoreValueToJSON(value interface{}) interface{} {
if value == nil {
return nil
}
switch v := value.(type) {
case time.Time:
return v.Format(time.RFC3339Nano)
case *latlng.LatLng:
return map[string]interface{}{
"latitude": v.Latitude,
"longitude": v.Longitude,
}
case []byte:
return base64.StdEncoding.EncodeToString(v)
case []interface{}:
result := make([]interface{}, len(v))
for i, item := range v {
result[i] = FirestoreValueToJSON(item)
}
return result
case map[string]interface{}:
result := make(map[string]interface{})
for k, val := range v {
result[k] = FirestoreValueToJSON(val)
}
return result
case *firestore.DocumentRef:
return v.Path
default:
return value
}
}
// isValidDocumentPath checks if a string is a valid Firestore document path
// Valid paths have an even number of segments (collection/doc/collection/doc...)
func isValidDocumentPath(path string) bool {

View File

@@ -312,40 +312,6 @@ func TestJSONToFirestoreValue_IntegerFromString(t *testing.T) {
}
}
func TestFirestoreValueToJSON_RoundTrip(t *testing.T) {
// Test round-trip conversion
original := map[string]interface{}{
"name": "Test",
"count": int64(42),
"price": 19.99,
"active": true,
"tags": []interface{}{"tag1", "tag2"},
"metadata": map[string]interface{}{
"created": time.Now(),
},
"nullField": nil,
}
// Convert to JSON representation
jsonRepresentation := FirestoreValueToJSON(original)
// Verify types are simplified
jsonMap, ok := jsonRepresentation.(map[string]interface{})
if !ok {
t.Fatalf("Expected map, got %T", jsonRepresentation)
}
// Time should be converted to string
metadata, ok := jsonMap["metadata"].(map[string]interface{})
if !ok {
t.Fatalf("metadata should be a map, got %T", jsonMap["metadata"])
}
_, ok = metadata["created"].(string)
if !ok {
t.Errorf("created should be a string, got %T", metadata["created"])
}
}
func TestJSONToFirestoreValue_InvalidFormats(t *testing.T) {
tests := []struct {
name string

View File

@@ -16,9 +16,7 @@ package http
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"slices"
@@ -54,7 +52,7 @@ type compatibleSource interface {
HttpDefaultHeaders() map[string]string
HttpBaseURL() string
HttpQueryParams() map[string]string
Client() *http.Client
RunRequest(*http.Request) (any, error)
}
type Config struct {
@@ -259,29 +257,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
for k, v := range allHeaders {
req.Header.Set(k, v)
}
// Make request and fetch response
resp, err := source.Client().Do(req)
if err != nil {
return nil, fmt.Errorf("error making HTTP request: %s", err)
}
defer resp.Body.Close()
var body []byte
body, err = io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
if resp.StatusCode < 200 || resp.StatusCode > 299 {
return nil, fmt.Errorf("unexpected status code: %d, response body: %s", resp.StatusCode, string(body))
}
var data any
if err = json.Unmarshal(body, &data); err != nil {
// if unable to unmarshal data, return result as string.
return string(body), nil
}
return data, nil
return source.RunRequest(req)
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
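The HTTP tool now hands the prepared request to the source. A sketch of a source-side RunRequest mirroring the removed logic (the Source type and package name are assumptions):

// Sketch of a source-side RunRequest for an HTTP source that owns the
// *http.Client previously used directly by the tool above.
package httpsrc

import (
	"encoding/json"
	"fmt"
	"io"
	"net/http"
)

type Source struct {
	client *http.Client
}

// RunRequest sends the prepared request, enforces a 2xx status, and decodes
// the body as JSON, falling back to the raw string when it is not JSON.
func (s *Source) RunRequest(req *http.Request) (any, error) {
	resp, err := s.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("error making HTTP request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return nil, fmt.Errorf("unexpected status code: %d, response body: %s", resp.StatusCode, string(body))
	}
	var data any
	if err := json.Unmarshal(body, &data); err != nil {
		// Not JSON; return the raw response body as a string.
		return string(body), nil
	}
	return data, nil
}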

View File

@@ -48,8 +48,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
UseClientAuthorization() bool
GetAuthTokenHeaderName() string
LookerClient() *v4.LookerSDK
LookerApiSettings() *rtl.ApiSettings
GetLookerSDK(string) (*v4.LookerSDK, error)
}
type Config struct {
@@ -159,7 +159,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
visConfig := paramsMap["vis_config"].(map[string]any)
wq.VisConfig = &visConfig
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
sdk, err := source.GetLookerSDK(string(accessToken))
if err != nil {
return nil, fmt.Errorf("error getting sdk: %w", err)
}

View File

@@ -21,7 +21,6 @@ import (
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
"github.com/googleapis/genai-toolbox/internal/util"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
@@ -48,8 +47,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
UseClientAuthorization() bool
GetAuthTokenHeaderName() string
LookerClient() *v4.LookerSDK
LookerApiSettings() *rtl.ApiSettings
GetLookerSDK(string) (*v4.LookerSDK, error)
}
type Config struct {
@@ -192,7 +191,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
req.Dimension = &dimension
}
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
sdk, err := source.GetLookerSDK(string(accessToken))
if err != nil {
return nil, fmt.Errorf("error getting sdk: %w", err)
}

View File

@@ -15,65 +15,17 @@ package lookercommon
import (
"context"
"crypto/tls"
"fmt"
"net/http"
"net/url"
"strings"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
rtl "github.com/looker-open-source/sdk-codegen/go/rtl"
"github.com/looker-open-source/sdk-codegen/go/rtl"
v4 "github.com/looker-open-source/sdk-codegen/go/sdk/v4"
"github.com/thlib/go-timezone-local/tzlocal"
)
// Make types for RoundTripper
type transportWithAuthHeader struct {
Base http.RoundTripper
AuthToken tools.AccessToken
}
func (t *transportWithAuthHeader) RoundTrip(req *http.Request) (*http.Response, error) {
req.Header.Set("x-looker-appid", "go-sdk")
req.Header.Set("Authorization", string(t.AuthToken))
return t.Base.RoundTrip(req)
}
func GetLookerSDK(useClientOAuth bool, config *rtl.ApiSettings, client *v4.LookerSDK, accessToken tools.AccessToken) (*v4.LookerSDK, error) {
if useClientOAuth {
if accessToken == "" {
return nil, fmt.Errorf("no access token supplied with request")
}
session := rtl.NewAuthSession(*config)
// Configure base transport with TLS
transport := &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: !config.VerifySsl,
},
}
// Build transport for end user token
session.Client = http.Client{
Transport: &transportWithAuthHeader{
Base: transport,
AuthToken: accessToken,
},
}
// return SDK with new Transport
return v4.NewLookerSDK(session), nil
}
if client == nil {
return nil, fmt.Errorf("client id or client secret not valid")
}
return client, nil
}
const (
DimensionsFields = "fields(dimensions(name,type,label,label_short,description,synonyms,tags,hidden,suggestable,suggestions,suggest_dimension,suggest_explore))"
FiltersFields = "fields(filters(name,type,label,label_short,description,synonyms,tags,hidden,suggestable,suggestions,suggest_dimension,suggest_explore))"
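With lookercommon.GetLookerSDK removed, each Looker source now exposes GetLookerSDK(accessToken). A sketch of how a source might implement it, based on the removed helper; the type and field names are assumptions:

// Sketch of a source-side GetLookerSDK that keeps the per-request OAuth
// behaviour the removed lookercommon.GetLookerSDK provided.
package lookersrc

import (
	"crypto/tls"
	"fmt"
	"net/http"

	"github.com/looker-open-source/sdk-codegen/go/rtl"
	v4 "github.com/looker-open-source/sdk-codegen/go/sdk/v4"
)

type LookerSource struct {
	useClientOAuth bool
	apiSettings    *rtl.ApiSettings
	client         *v4.LookerSDK
}

// authTransport injects the caller's token into every outgoing request.
type authTransport struct {
	base  http.RoundTripper
	token string
}

func (t *authTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	req.Header.Set("x-looker-appid", "go-sdk")
	req.Header.Set("Authorization", t.token)
	return t.base.RoundTrip(req)
}

// GetLookerSDK returns a per-request SDK bound to the caller's access token
// when client OAuth is enabled, and the shared client otherwise.
func (s *LookerSource) GetLookerSDK(accessToken string) (*v4.LookerSDK, error) {
	if s.useClientOAuth {
		if accessToken == "" {
			return nil, fmt.Errorf("no access token supplied with request")
		}
		session := rtl.NewAuthSession(*s.apiSettings)
		session.Client = http.Client{
			Transport: &authTransport{
				base:  &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: !s.apiSettings.VerifySsl}},
				token: accessToken,
			},
		}
		return v4.NewLookerSDK(session), nil
	}
	if s.client == nil {
		return nil, fmt.Errorf("client id or client secret not valid")
	}
	return s.client, nil
}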

View File

@@ -47,8 +47,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
UseClientAuthorization() bool
GetAuthTokenHeaderName() string
LookerClient() *v4.LookerSDK
LookerApiSettings() *rtl.ApiSettings
GetLookerSDK(string) (*v4.LookerSDK, error)
}
type Config struct {
@@ -116,7 +116,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, err
}
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
sdk, err := source.GetLookerSDK(string(accessToken))
if err != nil {
return nil, fmt.Errorf("error getting sdk: %w", err)
}

View File

@@ -47,8 +47,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
UseClientAuthorization() bool
GetAuthTokenHeaderName() string
LookerClient() *v4.LookerSDK
LookerApiSettings() *rtl.ApiSettings
GetLookerSDK(string) (*v4.LookerSDK, error)
}
type Config struct {
@@ -117,7 +117,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, err
}
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
sdk, err := source.GetLookerSDK(string(accessToken))
if err != nil {
return nil, fmt.Errorf("error getting sdk: %w", err)
}

View File

@@ -21,7 +21,6 @@ import (
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
"github.com/googleapis/genai-toolbox/internal/util"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
@@ -48,8 +47,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
UseClientAuthorization() bool
GetAuthTokenHeaderName() string
LookerClient() *v4.LookerSDK
LookerApiSettings() *rtl.ApiSettings
GetLookerSDK(string) (*v4.LookerSDK, error)
}
type Config struct {
@@ -125,7 +124,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("'devMode' must be a boolean, got %T", mapParams["devMode"])
}
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
sdk, err := source.GetLookerSDK(string(accessToken))
if err != nil {
return nil, fmt.Errorf("error getting sdk: %w", err)
}

View File

@@ -22,7 +22,6 @@ import (
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
"github.com/googleapis/genai-toolbox/internal/util"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
@@ -49,8 +48,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
UseClientAuthorization() bool
GetAuthTokenHeaderName() string
LookerClient() *v4.LookerSDK
LookerApiSettings() *rtl.ApiSettings
GetLookerSDK(string) (*v4.LookerSDK, error)
LookerSessionLength() int64
}
@@ -137,7 +136,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
contentId_ptr = nil
}
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
sdk, err := source.GetLookerSDK(string(accessToken))
if err != nil {
return nil, fmt.Errorf("error getting sdk: %w", err)
}

View File

@@ -21,7 +21,6 @@ import (
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"github.com/looker-open-source/sdk-codegen/go/rtl"
@@ -47,8 +46,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
UseClientAuthorization() bool
GetAuthTokenHeaderName() string
LookerClient() *v4.LookerSDK
LookerApiSettings() *rtl.ApiSettings
GetLookerSDK(string) (*v4.LookerSDK, error)
}
type Config struct {
@@ -120,7 +119,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("'conn' must be a string, got %T", mapParams["conn"])
}
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
sdk, err := source.GetLookerSDK(string(accessToken))
if err != nil {
return nil, fmt.Errorf("error getting sdk: %w", err)
}

View File

@@ -21,7 +21,6 @@ import (
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
"github.com/googleapis/genai-toolbox/internal/util"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
@@ -48,8 +47,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
UseClientAuthorization() bool
GetAuthTokenHeaderName() string
LookerClient() *v4.LookerSDK
LookerApiSettings() *rtl.ApiSettings
GetLookerSDK(string) (*v4.LookerSDK, error)
}
type Config struct {
@@ -119,7 +118,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("unable to get logger from ctx: %s", err)
}
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
sdk, err := source.GetLookerSDK(string(accessToken))
if err != nil {
return nil, fmt.Errorf("error getting sdk: %w", err)
}

View File

@@ -21,7 +21,6 @@ import (
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"github.com/looker-open-source/sdk-codegen/go/rtl"
@@ -47,8 +46,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
UseClientAuthorization() bool
GetAuthTokenHeaderName() string
LookerClient() *v4.LookerSDK
LookerApiSettings() *rtl.ApiSettings
GetLookerSDK(string) (*v4.LookerSDK, error)
}
type Config struct {
@@ -122,7 +121,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
db, _ := mapParams["db"].(string)
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
sdk, err := source.GetLookerSDK(string(accessToken))
if err != nil {
return nil, fmt.Errorf("error getting sdk: %w", err)
}

View File

@@ -21,7 +21,6 @@ import (
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
"github.com/googleapis/genai-toolbox/internal/util"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
@@ -48,8 +47,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
UseClientAuthorization() bool
GetAuthTokenHeaderName() string
LookerClient() *v4.LookerSDK
LookerApiSettings() *rtl.ApiSettings
GetLookerSDK(string) (*v4.LookerSDK, error)
}
type Config struct {
@@ -137,7 +136,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("'tables' must be a string, got %T", mapParams["tables"])
}
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
sdk, err := source.GetLookerSDK(string(accessToken))
if err != nil {
return nil, fmt.Errorf("error getting sdk: %w", err)
}

View File

@@ -21,7 +21,6 @@ import (
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
"github.com/googleapis/genai-toolbox/internal/util"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
@@ -48,8 +47,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
UseClientAuthorization() bool
GetAuthTokenHeaderName() string
LookerClient() *v4.LookerSDK
LookerApiSettings() *rtl.ApiSettings
GetLookerSDK(string) (*v4.LookerSDK, error)
}
type Config struct {
@@ -132,7 +131,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("'schema' must be a string, got %T", mapParams["schema"])
}
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
sdk, err := source.GetLookerSDK(string(accessToken))
if err != nil {
return nil, fmt.Errorf("error getting sdk: %w", err)
}

View File

@@ -21,7 +21,6 @@ import (
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
"github.com/googleapis/genai-toolbox/internal/util"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
@@ -48,8 +47,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
UseClientAuthorization() bool
GetAuthTokenHeaderName() string
LookerClient() *v4.LookerSDK
LookerApiSettings() *rtl.ApiSettings
GetLookerSDK(string) (*v4.LookerSDK, error)
}
type Config struct {
@@ -141,7 +140,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
limit := int64(paramsMap["limit"].(int))
offset := int64(paramsMap["offset"].(int))
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
sdk, err := source.GetLookerSDK(string(accessToken))
if err != nil {
return nil, fmt.Errorf("error getting sdk: %w", err)
}

View File

@@ -48,8 +48,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
UseClientAuthorization() bool
GetAuthTokenHeaderName() string
LookerClient() *v4.LookerSDK
LookerApiSettings() *rtl.ApiSettings
GetLookerSDK(string) (*v4.LookerSDK, error)
LookerShowHiddenFields() bool
}
@@ -124,7 +124,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("error processing model or explore: %w", err)
}
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
sdk, err := source.GetLookerSDK(string(accessToken))
if err != nil {
return nil, fmt.Errorf("error getting sdk: %w", err)
}

View File

@@ -21,7 +21,6 @@ import (
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
"github.com/googleapis/genai-toolbox/internal/util"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
@@ -48,8 +47,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
UseClientAuthorization() bool
GetAuthTokenHeaderName() string
LookerClient() *v4.LookerSDK
LookerApiSettings() *rtl.ApiSettings
GetLookerSDK(string) (*v4.LookerSDK, error)
LookerShowHiddenExplores() bool
}
@@ -126,7 +125,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("'model' must be a string, got %T", mapParams["model"])
}
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
sdk, err := source.GetLookerSDK(string(accessToken))
if err != nil {
return nil, fmt.Errorf("error getting sdk: %w", err)
}

View File

@@ -48,8 +48,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
UseClientAuthorization() bool
GetAuthTokenHeaderName() string
LookerClient() *v4.LookerSDK
LookerApiSettings() *rtl.ApiSettings
GetLookerSDK(string) (*v4.LookerSDK, error)
LookerShowHiddenFields() bool
}
@@ -125,7 +125,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
fields := lookercommon.FiltersFields
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
sdk, err := source.GetLookerSDK(string(accessToken))
if err != nil {
return nil, fmt.Errorf("error getting sdk: %w", err)
}

View File

@@ -21,7 +21,6 @@ import (
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
"github.com/googleapis/genai-toolbox/internal/util"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
@@ -48,8 +47,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
UseClientAuthorization() bool
GetAuthTokenHeaderName() string
LookerClient() *v4.LookerSDK
LookerApiSettings() *rtl.ApiSettings
GetLookerSDK(string) (*v4.LookerSDK, error)
}
type Config struct {
@@ -141,7 +140,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
limit := int64(paramsMap["limit"].(int))
offset := int64(paramsMap["offset"].(int))
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
sdk, err := source.GetLookerSDK(string(accessToken))
if err != nil {
return nil, fmt.Errorf("error getting sdk: %w", err)
}

View File

@@ -48,8 +48,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
UseClientAuthorization() bool
GetAuthTokenHeaderName() string
LookerClient() *v4.LookerSDK
LookerApiSettings() *rtl.ApiSettings
GetLookerSDK(string) (*v4.LookerSDK, error)
LookerShowHiddenFields() bool
}
@@ -125,7 +125,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
fields := lookercommon.MeasuresFields
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
sdk, err := source.GetLookerSDK(string(accessToken))
if err != nil {
return nil, fmt.Errorf("error getting sdk: %w", err)
}

View File

@@ -21,7 +21,6 @@ import (
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
"github.com/googleapis/genai-toolbox/internal/util"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
@@ -48,8 +47,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
UseClientAuthorization() bool
GetAuthTokenHeaderName() string
LookerClient() *v4.LookerSDK
LookerApiSettings() *rtl.ApiSettings
GetLookerSDK(string) (*v4.LookerSDK, error)
LookerShowHiddenModels() bool
}
@@ -124,7 +123,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
excludeHidden := !source.LookerShowHiddenModels()
includeInternal := true
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
sdk, err := source.GetLookerSDK(string(accessToken))
if err != nil {
return nil, fmt.Errorf("error getting sdk: %w", err)
}

View File

@@ -48,8 +48,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
UseClientAuthorization() bool
GetAuthTokenHeaderName() string
LookerClient() *v4.LookerSDK
LookerApiSettings() *rtl.ApiSettings
GetLookerSDK(string) (*v4.LookerSDK, error)
LookerShowHiddenFields() bool
}
@@ -125,7 +125,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
fields := lookercommon.ParametersFields
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
sdk, err := source.GetLookerSDK(string(accessToken))
if err != nil {
return nil, fmt.Errorf("error getting sdk: %w", err)
}

View File

@@ -48,8 +48,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
UseClientAuthorization() bool
GetAuthTokenHeaderName() string
LookerClient() *v4.LookerSDK
LookerApiSettings() *rtl.ApiSettings
GetLookerSDK(string) (*v4.LookerSDK, error)
}
type Config struct {
@@ -121,7 +121,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("unable to get logger from ctx: %s", err)
}
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
sdk, err := source.GetLookerSDK(string(accessToken))
if err != nil {
return nil, fmt.Errorf("error getting sdk: %w", err)
}

View File

@@ -21,7 +21,6 @@ import (
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
"github.com/googleapis/genai-toolbox/internal/util"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
@@ -48,8 +47,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
UseClientAuthorization() bool
GetAuthTokenHeaderName() string
LookerClient() *v4.LookerSDK
LookerApiSettings() *rtl.ApiSettings
GetLookerSDK(string) (*v4.LookerSDK, error)
}
type Config struct {
@@ -120,7 +119,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("unable to get logger from ctx: %s", err)
}
sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
sdk, err := source.GetLookerSDK(string(accessToken))
if err != nil {
return nil, fmt.Errorf("error getting sdk: %w", err)
}

View File

@@ -21,7 +21,6 @@ import (
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
"github.com/googleapis/genai-toolbox/internal/util"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
@@ -48,8 +47,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
UseClientAuthorization() bool
GetAuthTokenHeaderName() string
-LookerClient() *v4.LookerSDK
-LookerApiSettings() *rtl.ApiSettings
+GetLookerSDK(string) (*v4.LookerSDK, error)
}
type Config struct {
@@ -119,7 +118,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("unable to get logger from ctx: %s", err)
}
-sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
+sdk, err := source.GetLookerSDK(string(accessToken))
if err != nil {
return nil, fmt.Errorf("error getting sdk: %w", err)
}


@@ -53,8 +53,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
UseClientAuthorization() bool
GetAuthTokenHeaderName() string
-LookerClient() *v4.LookerSDK
-LookerApiSettings() *rtl.ApiSettings
+GetLookerSDK(string) (*v4.LookerSDK, error)
}
type Config struct {
@@ -136,7 +136,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("unable to get logger from ctx: %s", err)
}
-sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
+sdk, err := source.GetLookerSDK(string(accessToken))
if err != nil {
return nil, fmt.Errorf("error getting sdk: %w", err)
}


@@ -53,8 +53,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
UseClientAuthorization() bool
GetAuthTokenHeaderName() string
-LookerClient() *v4.LookerSDK
-LookerApiSettings() *rtl.ApiSettings
+GetLookerSDK(string) (*v4.LookerSDK, error)
}
type Config struct {
@@ -127,7 +127,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("unable to get logger from ctx: %s", err)
}
-sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
+sdk, err := source.GetLookerSDK(string(accessToken))
if err != nil {
return nil, fmt.Errorf("error getting sdk: %w", err)
}


@@ -53,8 +53,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
UseClientAuthorization() bool
GetAuthTokenHeaderName() string
-LookerClient() *v4.LookerSDK
-LookerApiSettings() *rtl.ApiSettings
+GetLookerSDK(string) (*v4.LookerSDK, error)
}
type Config struct {
@@ -131,7 +131,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, err
}
-sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
+sdk, err := source.GetLookerSDK(string(accessToken))
if err != nil {
return nil, fmt.Errorf("error getting sdk: %w", err)
}


@@ -23,7 +23,6 @@ import (
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/looker/lookercommon"
"github.com/googleapis/genai-toolbox/internal/util"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
@@ -50,8 +49,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
UseClientAuthorization() bool
GetAuthTokenHeaderName() string
-LookerClient() *v4.LookerSDK
-LookerApiSettings() *rtl.ApiSettings
+GetLookerSDK(string) (*v4.LookerSDK, error)
}
type Config struct {
@@ -129,7 +128,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
logger.DebugContext(ctx, "params = ", params)
-sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
+sdk, err := source.GetLookerSDK(string(accessToken))
if err != nil {
return nil, fmt.Errorf("error getting sdk: %w", err)
}


@@ -50,8 +50,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
UseClientAuthorization() bool
GetAuthTokenHeaderName() string
-LookerClient() *v4.LookerSDK
-LookerApiSettings() *rtl.ApiSettings
+GetLookerSDK(string) (*v4.LookerSDK, error)
}
type Config struct {
@@ -139,7 +139,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("error building query request: %w", err)
}
-sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
+sdk, err := source.GetLookerSDK(string(accessToken))
if err != nil {
return nil, fmt.Errorf("error getting sdk: %w", err)
}


@@ -49,8 +49,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
UseClientAuthorization() bool
GetAuthTokenHeaderName() string
-LookerClient() *v4.LookerSDK
-LookerApiSettings() *rtl.ApiSettings
+GetLookerSDK(string) (*v4.LookerSDK, error)
}
type Config struct {
@@ -123,7 +123,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, fmt.Errorf("error building WriteQuery request: %w", err)
}
-sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
+sdk, err := source.GetLookerSDK(string(accessToken))
if err != nil {
return nil, fmt.Errorf("error getting sdk: %w", err)
}


@@ -48,8 +48,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
UseClientAuthorization() bool
GetAuthTokenHeaderName() string
-LookerClient() *v4.LookerSDK
-LookerApiSettings() *rtl.ApiSettings
+GetLookerSDK(string) (*v4.LookerSDK, error)
}
type Config struct {
@@ -122,7 +122,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, fmt.Errorf("error building query request: %w", err)
}
-sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
+sdk, err := source.GetLookerSDK(string(accessToken))
if err != nil {
return nil, fmt.Errorf("error getting sdk: %w", err)
}


@@ -48,8 +48,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
UseClientAuthorization() bool
GetAuthTokenHeaderName() string
-LookerClient() *v4.LookerSDK
-LookerApiSettings() *rtl.ApiSettings
+GetLookerSDK(string) (*v4.LookerSDK, error)
}
type Config struct {
@@ -135,7 +135,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
visConfig := paramsMap["vis_config"].(map[string]any)
wq.VisConfig = &visConfig
-sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
+sdk, err := source.GetLookerSDK(string(accessToken))
if err != nil {
return nil, fmt.Errorf("error getting sdk: %w", err)
}


@@ -50,8 +50,8 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
UseClientAuthorization() bool
GetAuthTokenHeaderName() string
-LookerClient() *v4.LookerSDK
-LookerApiSettings() *rtl.ApiSettings
+GetLookerSDK(string) (*v4.LookerSDK, error)
}
type Config struct {
@@ -129,7 +129,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
dashboard_id := paramsMap["dashboard_id"].(string)
-sdk, err := lookercommon.GetLookerSDK(source.UseClientAuthorization(), source.LookerApiSettings(), source.LookerClient(), accessToken)
+sdk, err := source.GetLookerSDK(string(accessToken))
if err != nil {
return nil, fmt.Errorf("error getting sdk: %w", err)
}
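
The interface keeps `UseClientAuthorization()` and `GetAuthTokenHeaderName()` alongside the new `GetLookerSDK(string)` because the request layer still needs them to decide whether to pull a caller token at all, and from which header. The sketch below illustrates that flow under assumed names; it is not code from this change.

```go
// Illustrative only: how a caller's token could travel from the HTTP request
// to GetLookerSDK. The package, function names, and "Bearer " handling are
// assumptions, not the toolbox's actual request plumbing.
package example

import (
	"net/http"
	"strings"

	v4 "github.com/looker-open-source/sdk-codegen/go/sdk/v4"
)

// The contract shared by the Looker tools in this comparison.
type compatibleSource interface {
	UseClientAuthorization() bool
	GetAuthTokenHeaderName() string
	GetLookerSDK(string) (*v4.LookerSDK, error)
}

// accessTokenFromRequest extracts the caller's token only when the source is
// configured for per-request client authorization.
func accessTokenFromRequest(r *http.Request, src compatibleSource) string {
	if !src.UseClientAuthorization() {
		return "" // the source falls back to its shared client
	}
	raw := r.Header.Get(src.GetAuthTokenHeaderName())
	return strings.TrimPrefix(raw, "Bearer ")
}

// sdkForRequest shows why every tool's Invoke now needs only one call.
func sdkForRequest(r *http.Request, src compatibleSource) (*v4.LookerSDK, error) {
	return src.GetLookerSDK(accessTokenFromRequest(r, src))
}
```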

Some files were not shown because too many files have changed in this diff.