mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-11 08:28:11 -05:00
Compare commits
31 Commits
map
...
js-quickst
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
540cea3d1a | ||
|
|
9334368a42 | ||
|
|
2a650349cb | ||
|
|
5f7cc32127 | ||
|
|
abdab54503 | ||
|
|
e78bce32dc | ||
|
|
3727b1d053 | ||
|
|
7eff0f9ac7 | ||
|
|
023fcb9163 | ||
|
|
4468bc920b | ||
|
|
9a55b80482 | ||
|
|
3c5d2858e3 | ||
|
|
e5ac5ba9ee | ||
|
|
2bb790e4f8 | ||
|
|
2083ba5048 | ||
|
|
53afed5b76 | ||
|
|
a008fa2a5d | ||
|
|
f3599627b8 | ||
|
|
45f626a0d3 | ||
|
|
45183be932 | ||
|
|
ecfcf28d42 | ||
|
|
ddd29d4646 | ||
|
|
bdf9f3717b | ||
|
|
37bce26e4d | ||
|
|
1e90d847f9 | ||
|
|
138419b453 | ||
|
|
588a690dde | ||
|
|
2d65ff5c63 | ||
|
|
9cc63bc52b | ||
|
|
c673f27b44 | ||
|
|
fe7caa08d6 |
@@ -425,6 +425,26 @@ steps:
|
||||
"Valkey" \
|
||||
valkey \
|
||||
valkey
|
||||
|
||||
- id: "firestore"
|
||||
name: golang:1
|
||||
waitFor: ["compile-test-binary"]
|
||||
entrypoint: /bin/bash
|
||||
env:
|
||||
- "GOPATH=/gopath"
|
||||
- "FIRESTORE_PROJECT=$PROJECT_ID"
|
||||
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
|
||||
secretEnv: ["CLIENT_ID"]
|
||||
volumes:
|
||||
- name: "go"
|
||||
path: "/gopath"
|
||||
args:
|
||||
- -c
|
||||
- |
|
||||
.ci/test_with_coverage.sh \
|
||||
"Firestore" \
|
||||
firestore \
|
||||
firestore
|
||||
|
||||
|
||||
availableSecrets:
|
||||
|
||||
2
.github/blunderbuss.yml
vendored
2
.github/blunderbuss.yml
vendored
@@ -1,5 +1,4 @@
|
||||
assign_issues:
|
||||
- kurtisvg
|
||||
- Yuan325
|
||||
- duwenxin99
|
||||
- akitsch
|
||||
@@ -11,7 +10,6 @@ assign_issues_by:
|
||||
- shobsi
|
||||
- jiaxunwu
|
||||
assign_prs:
|
||||
- kurtisvg
|
||||
- Yuan325
|
||||
- duwenxin99
|
||||
- akitsch
|
||||
|
||||
@@ -52,12 +52,18 @@ import (
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/bigtable"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/couchbase"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/dgraph"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestoredeletedocuments"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestoregetdocuments"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestoregetrules"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestorelistcollections"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestorequerycollection"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/firestore/firestorevalidaterules"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/http"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mssql/mssqlexecutesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mssql/mssqlsql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlexecutesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlsql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/neo4j"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/neo4j/neo4jcypher"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgresexecutesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgressql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/redis"
|
||||
@@ -77,6 +83,7 @@ import (
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/couchbase"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/dgraph"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/firestore"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/http"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/mssql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/mysql"
|
||||
|
||||
@@ -1166,6 +1166,7 @@ func TestPrebuiltTools(t *testing.T) {
|
||||
cloudsqlpg_config, _ := prebuiltconfigs.Get("cloud-sql-postgres")
|
||||
cloudsqlmysql_config, _ := prebuiltconfigs.Get("cloud-sql-mysql")
|
||||
cloudsqlmssql_config, _ := prebuiltconfigs.Get("cloud-sql-mssql")
|
||||
firestoreconfig, _ := prebuiltconfigs.Get("firestore")
|
||||
postgresconfig, _ := prebuiltconfigs.Get("postgres")
|
||||
spanner_config, _ := prebuiltconfigs.Get("spanner")
|
||||
spannerpg_config, _ := prebuiltconfigs.Get("spanner-postgres")
|
||||
@@ -1228,6 +1229,16 @@ func TestPrebuiltTools(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "firestore prebuilt tools",
|
||||
in: firestoreconfig,
|
||||
wantToolset: server.ToolsetConfigs{
|
||||
"firestore-database-tools": tools.ToolsetConfig{
|
||||
Name: "firestore-database-tools",
|
||||
ToolNames: []string{"firestore-get-documents", "firestore-list-collections", "firestore-delete-documents", "firestore-query-collection", "firestore-get-rules", "firestore-validate-rules"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "postgres prebuilt tools",
|
||||
in: postgresconfig,
|
||||
|
||||
663
docs/en/getting-started/JS Quickstart (Local).md
Normal file
663
docs/en/getting-started/JS Quickstart (Local).md
Normal file
@@ -0,0 +1,663 @@
|
||||
---
|
||||
title: "JS Quickstart (Local)"
|
||||
type: docs
|
||||
weight: 2
|
||||
description: >
|
||||
How to get started running Toolbox locally with JavaScript, PostgreSQL, and orchestration frameworks such as [LangChain](https://js.langchain.com/docs/introduction/), [LlamaIndex](https://ts.llamaindex.ai/), or [GenkitJS](https://genkit.dev/docs/get-started/).
|
||||
---
|
||||
|
||||
## Before you begin
|
||||
|
||||
This guide assumes you have already done the following:
|
||||
|
||||
1. Installed [Node.js (v18 or higher)].
|
||||
2. Installed [PostgreSQL 16+ and the `psql` client][install-postgres].
|
||||
|
||||
### Cloud Setup (Optional)
|
||||
|
||||
If you plan to use **Google Cloud’s Vertex AI** with your agent (e.g., using Gemini or PaLM models), follow these one-time setup steps:
|
||||
|
||||
> - [Install the Google Cloud CLI]
|
||||
> - [Set up Application Default Credentials (ADC)]
|
||||
|
||||
#### Set your project and enable Vertex AI
|
||||
|
||||
```bash
|
||||
gcloud config set project YOUR_PROJECT_ID
|
||||
gcloud services enable aiplatform.googleapis.com
|
||||
```
|
||||
|
||||
[Node.js (v18 or higher)]: https://nodejs.org/
|
||||
[install-postgres]: https://www.postgresql.org/download/
|
||||
[Install the Google Cloud CLI]: https://cloud.google.com/sdk/docs/install
|
||||
[Set up Application Default Credentials (ADC)]: https://cloud.google.com/docs/authentication/set-up-adc-local-dev-environment
|
||||
|
||||
|
||||
## Step 1: Set up your database
|
||||
|
||||
In this section, we will create a database, insert some data that needs to be
|
||||
accessed by our agent, and create a database user for Toolbox to connect with.
|
||||
|
||||
1. Connect to postgres using the `psql` command:
|
||||
|
||||
```bash
|
||||
psql -h 127.0.0.1 -U postgres
|
||||
```
|
||||
|
||||
Here, `postgres` denotes the default postgres superuser.
|
||||
|
||||
{{< notice info >}}
|
||||
|
||||
#### **Having trouble connecting?**
|
||||
|
||||
* **Password Prompt:** If you are prompted for a password for the `postgres`
|
||||
user and do not know it (or a blank password doesn't work), your PostgreSQL
|
||||
installation might require a password or a different authentication method.
|
||||
* **`FATAL: role "postgres" does not exist`:** This error means the default
|
||||
`postgres` superuser role isn't available under that name on your system.
|
||||
* **`Connection refused`:** Ensure your PostgreSQL server is actually running.
|
||||
You can typically check with `sudo systemctl status postgresql` and start it
|
||||
with `sudo systemctl start postgresql` on Linux systems.
|
||||
|
||||
<br/>
|
||||
|
||||
#### **Common Solution**
|
||||
|
||||
For password issues or if the `postgres` role seems inaccessible directly, try
|
||||
switching to the `postgres` operating system user first. This user often has
|
||||
permission to connect without a password for local connections (this is called
|
||||
peer authentication).
|
||||
|
||||
```bash
|
||||
sudo -i -u postgres
|
||||
psql -h 127.0.0.1
|
||||
```
|
||||
|
||||
Once you are in the `psql` shell using this method, you can proceed with the
|
||||
database creation steps below. Afterwards, type `\q` to exit `psql`, and then
|
||||
`exit` to return to your normal user shell.
|
||||
|
||||
If desired, once connected to `psql` as the `postgres` OS user, you can set a
|
||||
password for the `postgres` *database* user using: `ALTER USER postgres WITH
|
||||
PASSWORD 'your_chosen_password';`. This would allow direct connection with `-U
|
||||
postgres` and a password next time.
|
||||
{{< /notice >}}
|
||||
|
||||
1. Create a new database and a new user:
|
||||
|
||||
{{< notice tip >}}
|
||||
For a real application, it's best to follow the principle of least permission
|
||||
and only grant the privileges your application needs.
|
||||
{{< /notice >}}
|
||||
|
||||
```sql
|
||||
CREATE USER toolbox_user WITH PASSWORD 'my-password';
|
||||
|
||||
CREATE DATABASE toolbox_db;
|
||||
GRANT ALL PRIVILEGES ON DATABASE toolbox_db TO toolbox_user;
|
||||
|
||||
ALTER DATABASE toolbox_db OWNER TO toolbox_user;
|
||||
```
|
||||
|
||||
1. End the database session:
|
||||
|
||||
```bash
|
||||
\q
|
||||
```
|
||||
|
||||
(If you used `sudo -i -u postgres` and then `psql`, remember you might also
|
||||
need to type `exit` after `\q` to leave the `postgres` user's shell
|
||||
session.)
|
||||
|
||||
1. Connect to your database with your new user:
|
||||
|
||||
```bash
|
||||
psql -h 127.0.0.1 -U toolbox_user -d toolbox_db
|
||||
```
|
||||
|
||||
1. Create a table using the following command:
|
||||
|
||||
```sql
|
||||
CREATE TABLE hotels(
|
||||
id INTEGER NOT NULL PRIMARY KEY,
|
||||
name VARCHAR NOT NULL,
|
||||
location VARCHAR NOT NULL,
|
||||
price_tier VARCHAR NOT NULL,
|
||||
checkin_date DATE NOT NULL,
|
||||
checkout_date DATE NOT NULL,
|
||||
booked BIT NOT NULL
|
||||
);
|
||||
```
|
||||
|
||||
1. Insert data into the table.
|
||||
|
||||
```sql
|
||||
INSERT INTO hotels(id, name, location, price_tier, checkin_date, checkout_date, booked)
|
||||
VALUES
|
||||
(1, 'Hilton Basel', 'Basel', 'Luxury', '2024-04-22', '2024-04-20', B'0'),
|
||||
(2, 'Marriott Zurich', 'Zurich', 'Upscale', '2024-04-14', '2024-04-21', B'0'),
|
||||
(3, 'Hyatt Regency Basel', 'Basel', 'Upper Upscale', '2024-04-02', '2024-04-20', B'0'),
|
||||
(4, 'Radisson Blu Lucerne', 'Lucerne', 'Midscale', '2024-04-24', '2024-04-05', B'0'),
|
||||
(5, 'Best Western Bern', 'Bern', 'Upper Midscale', '2024-04-23', '2024-04-01', B'0'),
|
||||
(6, 'InterContinental Geneva', 'Geneva', 'Luxury', '2024-04-23', '2024-04-28', B'0'),
|
||||
(7, 'Sheraton Zurich', 'Zurich', 'Upper Upscale', '2024-04-27', '2024-04-02', B'0'),
|
||||
(8, 'Holiday Inn Basel', 'Basel', 'Upper Midscale', '2024-04-24', '2024-04-09', B'0'),
|
||||
(9, 'Courtyard Zurich', 'Zurich', 'Upscale', '2024-04-03', '2024-04-13', B'0'),
|
||||
(10, 'Comfort Inn Bern', 'Bern', 'Midscale', '2024-04-04', '2024-04-16', B'0');
|
||||
```
|
||||
|
||||
1. End the database session:
|
||||
|
||||
```bash
|
||||
\q
|
||||
```
|
||||
|
||||
## Step 2: Install and configure Toolbox
|
||||
|
||||
In this section, we will download Toolbox, configure our tools in a
|
||||
`tools.yaml`, and then run the Toolbox server.
|
||||
|
||||
1. Download the latest version of Toolbox as a binary:
|
||||
|
||||
{{< notice tip >}}
|
||||
Select the
|
||||
[correct binary](https://github.com/googleapis/genai-toolbox/releases)
|
||||
corresponding to your OS and CPU architecture.
|
||||
{{< /notice >}}
|
||||
<!-- {x-release-please-start-version} -->
|
||||
```bash
|
||||
export OS="linux/amd64" # one of linux/amd64, darwin/arm64, darwin/amd64, or windows/amd64
|
||||
curl -O https://storage.googleapis.com/genai-toolbox/v0.9.0/$OS/toolbox
|
||||
```
|
||||
<!-- {x-release-please-end} -->
|
||||
|
||||
1. Make the binary executable:
|
||||
|
||||
```bash
|
||||
chmod +x toolbox
|
||||
```
|
||||
|
||||
1. Write the following into a `tools.yaml` file. Be sure to update any fields
|
||||
such as `user`, `password`, or `database` that you may have customized in the
|
||||
previous step.
|
||||
|
||||
{{< notice tip >}}
|
||||
In practice, use environment variable replacement with the format ${ENV_NAME}
|
||||
instead of hardcoding your secrets into the configuration file.
|
||||
{{< /notice >}}
|
||||
|
||||
```yaml
|
||||
sources:
|
||||
my-pg-source:
|
||||
kind: postgres
|
||||
host: 127.0.0.1
|
||||
port: 5432
|
||||
database: toolbox_db
|
||||
user: ${USER_NAME}
|
||||
password: ${PASSWORD}
|
||||
tools:
|
||||
search-hotels-by-name:
|
||||
kind: postgres-sql
|
||||
source: my-pg-source
|
||||
description: Search for hotels based on name.
|
||||
parameters:
|
||||
- name: name
|
||||
type: string
|
||||
description: The name of the hotel.
|
||||
statement: SELECT * FROM hotels WHERE name ILIKE '%' || $1 || '%';
|
||||
search-hotels-by-location:
|
||||
kind: postgres-sql
|
||||
source: my-pg-source
|
||||
description: Search for hotels based on location.
|
||||
parameters:
|
||||
- name: location
|
||||
type: string
|
||||
description: The location of the hotel.
|
||||
statement: SELECT * FROM hotels WHERE location ILIKE '%' || $1 || '%';
|
||||
book-hotel:
|
||||
kind: postgres-sql
|
||||
source: my-pg-source
|
||||
description: >-
|
||||
Book a hotel by its ID. If the hotel is successfully booked, returns a NULL, raises an error if not.
|
||||
parameters:
|
||||
- name: hotel_id
|
||||
type: string
|
||||
description: The ID of the hotel to book.
|
||||
statement: UPDATE hotels SET booked = B'1' WHERE id = $1;
|
||||
update-hotel:
|
||||
kind: postgres-sql
|
||||
source: my-pg-source
|
||||
description: >-
|
||||
Update a hotel's check-in and check-out dates by its ID. Returns a message
|
||||
indicating whether the hotel was successfully updated or not.
|
||||
parameters:
|
||||
- name: hotel_id
|
||||
type: string
|
||||
description: The ID of the hotel to update.
|
||||
- name: checkin_date
|
||||
type: string
|
||||
description: The new check-in date of the hotel.
|
||||
- name: checkout_date
|
||||
type: string
|
||||
description: The new check-out date of the hotel.
|
||||
statement: >-
|
||||
UPDATE hotels SET checkin_date = CAST($2 as date), checkout_date = CAST($3
|
||||
as date) WHERE id = $1;
|
||||
cancel-hotel:
|
||||
kind: postgres-sql
|
||||
source: my-pg-source
|
||||
description: Cancel a hotel by its ID.
|
||||
parameters:
|
||||
- name: hotel_id
|
||||
type: string
|
||||
description: The ID of the hotel to cancel.
|
||||
statement: UPDATE hotels SET booked = B'0' WHERE id = $1;
|
||||
toolsets:
|
||||
my-toolset:
|
||||
- search-hotels-by-name
|
||||
- search-hotels-by-location
|
||||
- book-hotel
|
||||
- update-hotel
|
||||
- cancel-hotel
|
||||
```
|
||||
|
||||
For more info on tools, check out the `Resources` section of the docs.
|
||||
|
||||
1. Run the Toolbox server, pointing to the `tools.yaml` file created earlier:
|
||||
|
||||
```bash
|
||||
./toolbox --tools-file "tools.yaml"
|
||||
```
|
||||
{{< notice note >}}
|
||||
Toolbox enables dynamic reloading by default. To disable, use the `--disable-reload` flag.
|
||||
{{< /notice >}}
|
||||
|
||||
## Step 3: Connect your agent to Toolbox
|
||||
|
||||
In this section, we will write and run an agent that will load the Tools
|
||||
from Toolbox.
|
||||
|
||||
First let's create a new folder for your project, initialize it with `npm`, and install the required dependencies.
|
||||
|
||||
1. Create a new folder for your project and navigate into it:
|
||||
|
||||
```bash
|
||||
mkdir my-agent-app
|
||||
cd my-agent-app
|
||||
```
|
||||
|
||||
1. Initialize a new Node.js project:
|
||||
|
||||
```bash
|
||||
npm init -y
|
||||
```
|
||||
|
||||
1. Create a new file in the root directory:
|
||||
|
||||
```bash
|
||||
touch index.js
|
||||
```
|
||||
|
||||
4. Next, depending on which orchestration framework you want to use, install the relevant dependencies:
|
||||
|
||||
{{< tabpane persist=header >}}
|
||||
{{< tab header="LangChain" lang="bash" >}}
|
||||
npm install langchain @genai-toolbox/sdk @langchain/google-vertexai dotenv
|
||||
{{< /tab >}}
|
||||
{{< tab header="LlamaIndex" lang="bash" >}}
|
||||
npm install @llamaindex/core @llamaindex/llms-google-genai @genai-toolbox/sdk dotenv
|
||||
{{< /tab >}}
|
||||
{{< tab header="GenkitJS" lang="bash" >}}
|
||||
npm install @toolbox-sdk/core genkit @genkit-ai/vertexai dotenv
|
||||
{{< /tab >}}
|
||||
{{< /tabpane >}}
|
||||
|
||||
5. Now copy the below code in your `index.js` file based on your orchestration framework.
|
||||
|
||||
{{< tabpane persist=header >}}
|
||||
{{< tab header="LangChain" lang="js" >}}
|
||||
|
||||
import "dotenv/config";
|
||||
import { ChatVertexAI } from "@langchain/google-vertexai";
|
||||
import { ToolboxClient } from "@toolbox-sdk/core";
|
||||
import { tool } from "@langchain/core/tools";
|
||||
import { HumanMessage, ToolMessage } from "@langchain/core/messages";
|
||||
|
||||
const prompt = `
|
||||
You're a helpful hotel assistant. You handle hotel searching, booking, and
|
||||
cancellations. When the user searches for a hotel, mention its name, id,
|
||||
location and price tier. Always mention hotel ids while performing any
|
||||
searches. This is very important for any operations. For any bookings or
|
||||
cancellations, please provide the appropriate confirmation. Be sure to
|
||||
update checkin or checkout dates if mentioned by the user.
|
||||
Don't ask for confirmations from the user.
|
||||
`;
|
||||
|
||||
const queries = [
|
||||
"Find hotels in Basel with Basel in its name.",
|
||||
"Can you book the Hilton Basel for me?",
|
||||
"Oh wait, this is too expensive. Please cancel it and book the Hyatt Regency instead.",
|
||||
"My check in dates would be from April 10, 2024 to April 19, 2024.",
|
||||
];
|
||||
|
||||
async function runApplication() {
|
||||
console.log("Starting hotel agent...");
|
||||
|
||||
const model = new ChatVertexAI({
|
||||
model: "gemini-2.0-flash-001",
|
||||
temperature: 0,
|
||||
});
|
||||
|
||||
const client = new ToolboxClient("http://127.0.0.1:5000");
|
||||
const toolboxTools = await client.loadToolset("my-toolset");
|
||||
|
||||
console.log(`Loaded ${toolboxTools.length} tools from Toolbox`);
|
||||
|
||||
const tools = toolboxTools
|
||||
.map((t) => {
|
||||
return tool(t, {
|
||||
name: t.toolName,
|
||||
description: t.description,
|
||||
schema: t.params,
|
||||
});
|
||||
})
|
||||
.filter(Boolean);
|
||||
|
||||
const modelWithTools = model.bindTools(tools);
|
||||
|
||||
let messages = [new HumanMessage(prompt)];
|
||||
|
||||
for (const query of queries) {
|
||||
console.log(`\nUser: ${query}`);
|
||||
|
||||
messages.push(new HumanMessage(query));
|
||||
|
||||
for (let step = 0; step < 5; step++) {
|
||||
const response = await modelWithTools.invoke(messages);
|
||||
|
||||
if (!response.tool_calls || response.tool_calls.length === 0) {
|
||||
console.log("Agent:", response.content);
|
||||
messages.push(response);
|
||||
break;
|
||||
}
|
||||
|
||||
console.log("Agent decided to use tools:", response.tool_calls);
|
||||
messages.push(response);
|
||||
|
||||
const toolMessages = await Promise.all(
|
||||
response.tool_calls.map(async (call) => {
|
||||
const toolToCall = tools.find((t) => t.name === call.name);sources:
|
||||
my-valkey-instance:
|
||||
kind: valkey
|
||||
address:
|
||||
- 10.128.0.2:6379
|
||||
useGCPIAM: true
|
||||
|
||||
if (!toolToCall) {
|
||||
return new ToolMessage({
|
||||
content: `Error: Tool ${call.name} not found`,
|
||||
tool_call_id: call.id,
|
||||
});
|
||||
}
|
||||
try {
|
||||
const result = await toolToCall.invoke(call.args);
|
||||
return new ToolMessage({
|
||||
content: JSON.stringify(result ?? "No result returned."),
|
||||
tool_call_id: call.id,
|
||||
});
|
||||
} catch (e) {
|
||||
return new ToolMessage({
|
||||
content: `Error: ${e.message}`,
|
||||
tool_call_id: call.id,
|
||||
});
|
||||
}
|
||||
})
|
||||
);
|
||||
|
||||
messages.push(...toolMessages);
|
||||
}
|
||||
}
|
||||
if (client.close) {
|
||||
await client.close();
|
||||
}
|
||||
}
|
||||
|
||||
runApplication()
|
||||
.catch(console.error)
|
||||
.finally(() => console.log("\nApplication finished."));
|
||||
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="LlamaIndex" lang="js" >}}
|
||||
|
||||
import "dotenv/config";
|
||||
import { LlamaIndexAgent } from "@llamaindex/core";
|
||||
import { GoogleGenAI } from "@llamaindex/llms-google-genai";
|
||||
import { ToolboxClient } from "@genai-toolbox/sdk";
|
||||
|
||||
// Sample prompt and queries
|
||||
const prompt = `
|
||||
You're a helpful hotel assistant. You handle hotel searching, booking, and cancellations.
|
||||
... (same as above) ...
|
||||
`;
|
||||
|
||||
const queries = [
|
||||
"Find hotels in Basel with Basel in its name.",
|
||||
// ...more queries...
|
||||
];
|
||||
|
||||
async function runApplication() {
|
||||
const llm = new GoogleGenAI({
|
||||
model: "gemini-2.0-flash-001",
|
||||
// Add any required config here
|
||||
});
|
||||
|
||||
const client = new ToolboxClient("http://127.0.0.1:5000");
|
||||
const tools = await client.loadToolset("my-toolset");
|
||||
|
||||
const agent = new LlamaIndexAgent({
|
||||
llm,
|
||||
tools,
|
||||
systemPrompt: prompt,
|
||||
});
|
||||
|
||||
for (const query of queries) {
|
||||
const response = await agent.run(query);
|
||||
console.log(response);
|
||||
}
|
||||
|
||||
if (client.close) await client.close();
|
||||
}
|
||||
|
||||
runApplication().catch(console.error);
|
||||
{{< /tab >}}
|
||||
|
||||
{{< tab header="GenkitJS" lang="js" >}}
|
||||
import { ToolboxClient } from "@toolbox-sdk/core";
|
||||
import { genkit } from "genkit";
|
||||
import { vertexAI } from "@genkit-ai/vertexai";
|
||||
import { z } from "zod";
|
||||
|
||||
const toolboxClient = new ToolboxClient("http://127.0.0.1:5000");
|
||||
|
||||
const ai = genkit({
|
||||
plugins: [vertexAI({ location: "us-central1", projectId: process.env.PROJECT_ID })],
|
||||
});
|
||||
|
||||
const systemPrompt = `
|
||||
You're a helpful hotel assistant. You handle hotel searching, booking and cancellations.
|
||||
When the user searches for a hotel, mention its name, ID, location and price tier.
|
||||
Always mention hotel ID while performing any operations. This is very important for any operations.
|
||||
For any bookings or cancellations, please provide the appropriate confirmation.
|
||||
Be sure to update checkin or checkout dates if mentioned by the user.
|
||||
Don't ask for confirmations from the user.
|
||||
`;
|
||||
|
||||
const queries = [
|
||||
"Find hotels in Bern with Bern in it's name.",
|
||||
"Please book the hotel Best Western Bern for me.",
|
||||
"This is too expensive. Please cancel it.",
|
||||
"Please book Comfort Inn Bern for me",
|
||||
"My check in dates for my booking would be from April 10, 2024 to April 19, 2024.",
|
||||
];
|
||||
|
||||
async function run() {
|
||||
let tools;
|
||||
try {
|
||||
tools = await toolboxClient.loadToolset("my-toolset");
|
||||
} catch {
|
||||
return;
|
||||
}
|
||||
|
||||
const toolboxTools = await toolboxClient.loadToolset("my-toolset");
|
||||
|
||||
const toolMap = {};
|
||||
for (const tool of toolboxTools) {
|
||||
let inputSchema;
|
||||
switch (tool.getName()) {
|
||||
case "search-hotels-by-name":
|
||||
inputSchema = z.object({ name: z.string() });
|
||||
break;
|
||||
case "search-hotels-by-location":
|
||||
inputSchema = z.object({ location: z.string() });
|
||||
break;
|
||||
case "book-hotel":
|
||||
inputSchema = z.object({ hotel_id: z.string() });
|
||||
break;
|
||||
case "update-hotel":
|
||||
inputSchema = z.object({
|
||||
hotel_id: z.string(),
|
||||
checkin_date: z.string(),
|
||||
checkout_date: z.string(),
|
||||
});
|
||||
break;
|
||||
case "cancel-hotel":
|
||||
inputSchema = z.object({ hotel_id: z.string() });
|
||||
break;
|
||||
default:
|
||||
inputSchema = z.object({});
|
||||
}
|
||||
|
||||
const definedTool = ai.defineTool(
|
||||
{
|
||||
name: tool.getName(),
|
||||
description: tool.getDescription(),
|
||||
inputSchema,
|
||||
},
|
||||
tool
|
||||
);
|
||||
|
||||
toolMap[tool.getName()] = definedTool;
|
||||
}
|
||||
|
||||
let conversationHistory = [
|
||||
{
|
||||
role: "system",
|
||||
content: [{ text: systemPrompt }],
|
||||
},
|
||||
];
|
||||
|
||||
for (const userQuery of queries) {
|
||||
console.log(`\n👤 User: "${userQuery}"`);
|
||||
conversationHistory.push({
|
||||
role: "user",
|
||||
content: [{ text: userQuery }],
|
||||
});
|
||||
|
||||
const response = await ai.generate({
|
||||
model: vertexAI.model("gemini-2.5-flash"),
|
||||
messages: conversationHistory,
|
||||
tools: Object.values(toolMap),
|
||||
});
|
||||
|
||||
let content = [],
|
||||
functionCalls = [];
|
||||
|
||||
if (response.toolRequests?.length) {
|
||||
functionCalls = response.toolRequests;
|
||||
content = response.content || [{ text: response.text || "" }];
|
||||
} else if (response.candidates?.length) {
|
||||
content = response.candidates[0].content;
|
||||
functionCalls = content.filter((part) => part.functionCall);
|
||||
} else {
|
||||
content = [{ text: response.text || response.output || "No response text found" }];
|
||||
}
|
||||
|
||||
conversationHistory.push({
|
||||
role: "model",
|
||||
content,
|
||||
});
|
||||
|
||||
if (functionCalls.length > 0) {
|
||||
for (const call of functionCalls) {
|
||||
const toolName =
|
||||
call.functionCall?.name || call.toolRequest?.name || call.name;
|
||||
const toolArgs =
|
||||
call.functionCall?.args || call.toolRequest?.input || call.input;
|
||||
|
||||
const tool = toolMap[toolName];
|
||||
if (!tool) continue;
|
||||
|
||||
try {
|
||||
const toolResponse = await tool.invoke(toolArgs);
|
||||
|
||||
conversationHistory.push({
|
||||
role: "function",
|
||||
content: [
|
||||
{
|
||||
functionResponse: {
|
||||
name: toolName,
|
||||
response: toolResponse,
|
||||
},
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
if (toolName.includes("search-hotels")) {
|
||||
if (Array.isArray(toolResponse) && toolResponse.length > 0) {
|
||||
const hotelList = toolResponse
|
||||
.map(
|
||||
(h) =>
|
||||
`* Hotel Name: ${h.name}, ID: ${h.hotel_id}, Location: ${h.location}, Price Tier: ${h.price_tier}`
|
||||
)
|
||||
.join("\n");
|
||||
console.log(`🤖 Hotel Agent: I found these hotels in Bern:\n\n${hotelList}`);
|
||||
} else {
|
||||
console.log("🤖 Hotel Agent: No hotels found.");
|
||||
}
|
||||
}
|
||||
} catch {}
|
||||
}
|
||||
|
||||
const finalResponse = await ai.generate({
|
||||
model: vertexAI.model("gemini-2.5-flash"),
|
||||
messages: conversationHistory,
|
||||
tools: Object.values(toolMap),
|
||||
});
|
||||
|
||||
const finalMessage =
|
||||
finalResponse.text || finalResponse.output || "No final response";
|
||||
|
||||
conversationHistory.push({
|
||||
role: "model",
|
||||
content: [{ text: finalMessage }],
|
||||
});
|
||||
|
||||
console.log(`🤖 Hotel Agent: ${finalMessage}`);
|
||||
} else {
|
||||
const message =
|
||||
response.text || response.output || content[0]?.text || "No response";
|
||||
console.log(`🤖 Hotel Agent: ${message}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
run();
|
||||
{{< /tab >}}
|
||||
{{< /tabpane >}}
|
||||
|
||||
6. Make sure your Toolbox server is running (`./toolbox --tools-file "tools.yaml"`), then run your agent [make sure you are in the root directory]:
|
||||
|
||||
```sh
|
||||
node index.js
|
||||
```
|
||||
@@ -667,7 +667,7 @@ Documentation](https://github.com/googleapis/python-genai?tab=readme-ov-file#man
|
||||
{{% /tab %}}
|
||||
{{< /tabpane >}}
|
||||
|
||||
1. Run your agent, and observe the results:
|
||||
4. Run your agent, and observe the results:
|
||||
|
||||
```sh
|
||||
python hotel_agent.py
|
||||
@@ -133,22 +133,16 @@ section.
|
||||
|
||||
## Connecting with Toolbox Client SDK
|
||||
|
||||
You can connect to Toolbox Cloud Run instances directly through the SDK
|
||||
You can connect to Toolbox Cloud Run instances directly through the SDK.
|
||||
|
||||
1. [Set up `Cloud Run Invoker` role
|
||||
access](https://cloud.google.com/run/docs/securing/managing-access#service-add-principals)
|
||||
to your Cloud Run service.
|
||||
|
||||
1. Set up [Application Default
|
||||
1. (Only for local runs) Set up [Application Default
|
||||
Credentials](https://cloud.google.com/docs/authentication/set-up-adc-local-dev-environment)
|
||||
for the principle you set up the `Cloud Run Invoker` role access to.
|
||||
|
||||
{{< notice tip >}}
|
||||
If you're working in some other environment than local, set up [environment
|
||||
specific Default
|
||||
Credentials](https://cloud.google.com/docs/authentication/provide-credentials-adc).
|
||||
{{< /notice >}}
|
||||
|
||||
1. Run the following to retrieve a non-deterministic URL for the cloud run service:
|
||||
|
||||
```bash
|
||||
@@ -160,9 +154,11 @@ You can connect to Toolbox Cloud Run instances directly through the SDK
|
||||
```python
|
||||
from toolbox_core import ToolboxClient, auth_methods
|
||||
|
||||
auth_token_provider = auth_methods.aget_google_id_token # can also use sync method
|
||||
|
||||
# Replace with the Cloud Run service URL generated in the previous step.
|
||||
URL = "https://cloud-run-url.app"
|
||||
|
||||
auth_token_provider = auth_methods.aget_google_id_token(URL) # can also use sync method
|
||||
|
||||
async with ToolboxClient(
|
||||
URL,
|
||||
client_headers={"Authorization": auth_token_provider},
|
||||
|
||||
70
docs/en/resources/sources/firestore.md
Normal file
70
docs/en/resources/sources/firestore.md
Normal file
@@ -0,0 +1,70 @@
|
||||
---
|
||||
title: "Firestore"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
Firestore is a NoSQL document database built for automatic scaling, high performance, and ease of application development. It's a fully managed, serverless database that supports mobile, web, and server development.
|
||||
|
||||
---
|
||||
|
||||
# Firestore Source
|
||||
|
||||
[Firestore][firestore-docs] is a NoSQL document database built for automatic
|
||||
scaling, high performance, and ease of application development. While the
|
||||
Firestore interface has many of the same features as traditional databases,
|
||||
as a NoSQL database it differs from them in the way it describes relationships
|
||||
between data objects.
|
||||
|
||||
If you are new to Firestore, you can [create a database and learn the
|
||||
basics][firestore-quickstart].
|
||||
|
||||
[firestore-docs]: https://cloud.google.com/firestore/docs
|
||||
[firestore-quickstart]: https://cloud.google.com/firestore/docs/quickstart-servers
|
||||
|
||||
## Requirements
|
||||
|
||||
### IAM Permissions
|
||||
|
||||
Firestore uses [Identity and Access Management (IAM)][iam-overview] to control
|
||||
user and group access to Firestore resources. Toolbox will use your [Application
|
||||
Default Credentials (ADC)][adc] to authorize and authenticate when interacting
|
||||
with [Firestore][firestore-docs].
|
||||
|
||||
In addition to [setting the ADC for your server][set-adc], you need to ensure
|
||||
the IAM identity has been given the correct IAM permissions for accessing
|
||||
Firestore. Common roles include:
|
||||
- `roles/datastore.user` - Read and write access to Firestore
|
||||
- `roles/datastore.viewer` - Read-only access to Firestore
|
||||
|
||||
See [Firestore access control][firestore-iam] for more information on
|
||||
applying IAM permissions and roles to an identity.
|
||||
|
||||
[iam-overview]: https://cloud.google.com/firestore/docs/security/iam
|
||||
[adc]: https://cloud.google.com/docs/authentication#adc
|
||||
[set-adc]: https://cloud.google.com/docs/authentication/provide-credentials-adc
|
||||
[firestore-iam]: https://cloud.google.com/firestore/docs/security/iam
|
||||
|
||||
### Database Selection
|
||||
|
||||
Firestore allows you to create multiple databases within a single project. Each
|
||||
database is isolated from the others and has its own set of documents and
|
||||
collections. If you don't specify a database in your configuration, the default
|
||||
database named `(default)` will be used.
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
sources:
|
||||
my-firestore-source:
|
||||
kind: "firestore"
|
||||
project: "my-project-id"
|
||||
# database: "my-database" # Optional, defaults to "(default)"
|
||||
```
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|-----------|:--------:|:------------:|----------------------------------------------------------------------------------------------------------|
|
||||
| kind | string | true | Must be "firestore". |
|
||||
| project | string | true | Id of the GCP project that contains the Firestore database (e.g. "my-project-id"). |
|
||||
| database | string | false | Name of the Firestore database to connect to. Defaults to "(default)" if not specified. |
|
||||
@@ -81,8 +81,9 @@ the parameter.
|
||||
|-------------|:---------------:|:------------:|-----------------------------------------------------------------------------|
|
||||
| name | string | true | Name of the parameter. |
|
||||
| type | string | true | Must be one of "string", "integer", "float", "boolean" "array" |
|
||||
| default | parameter type | false | Default value of the parameter. If provided, the parameter is not required. |
|
||||
| description | string | true | Natural language description of the parameter to describe it to the agent. |
|
||||
| default | parameter type | false | Default value of the parameter. If provided, `required` will be `false`. |
|
||||
| required | bool | false | Indicate if the parameter is required. Default to `true`. |
|
||||
|
||||
### Array Parameters
|
||||
|
||||
@@ -107,14 +108,47 @@ in the list using the items field:
|
||||
|-------------|:----------------:|:------------:|-----------------------------------------------------------------------------|
|
||||
| name | string | true | Name of the parameter. |
|
||||
| type | string | true | Must be "array" |
|
||||
| default | parameter type | false | Default value of the parameter. If provided, the parameter is not required. |
|
||||
| description | string | true | Natural language description of the parameter to describe it to the agent. |
|
||||
| default | parameter type | false | Default value of the parameter. If provided, `required` will be `false`. |
|
||||
| required | bool | false | Indicate if the parameter is required. Default to `true`. |
|
||||
| items | parameter object | true | Specify a Parameter object for the type of the values in the array. |
|
||||
|
||||
{{< notice note >}}
|
||||
Items in array should not have a default value. If provided, it will be ignored.
|
||||
Items in array should not have a `default` or `required` value. If provided, it will be ignored.
|
||||
{{< /notice >}}
|
||||
|
||||
### Map Parameters
|
||||
|
||||
The map type is a collection of key-value pairs. It can be configured in two ways:
|
||||
|
||||
- Generic Map: By default, it accepts values of any primitive type (string, number, boolean), allowing for mixed data.
|
||||
- Typed Map: By setting the valueType field, you can enforce that all values
|
||||
within the map must be of the same specified type.
|
||||
|
||||
#### Generic Map (Mixed Value Types)
|
||||
|
||||
This is the default behavior when valueType is omitted. It's useful for passing a flexible group of settings.
|
||||
|
||||
```yaml
|
||||
parameters:
|
||||
- name: execution_context
|
||||
type: map
|
||||
description: A flexible set of key-value pairs for the execution environment.
|
||||
```
|
||||
|
||||
#### Typed Map
|
||||
|
||||
Specify valueType to ensure all values in the map are of the same type. An error
|
||||
will be thrown in case of value type mismatch.
|
||||
|
||||
```yaml
|
||||
parameters:
|
||||
- name: user_scores
|
||||
type: map
|
||||
description: A map of user IDs to their scores. All scores must be integers.
|
||||
valueType: integer # This enforces the value type for all entries.
|
||||
```
|
||||
|
||||
### Authenticated Parameters
|
||||
|
||||
Authenticated parameters are automatically populated with user
|
||||
|
||||
7
docs/en/resources/tools/firestore/_index.md
Normal file
7
docs/en/resources/tools/firestore/_index.md
Normal file
@@ -0,0 +1,7 @@
|
||||
---
|
||||
title: "Firestore"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
Tools that work with Firestore Sources.
|
||||
---
|
||||
@@ -0,0 +1,38 @@
|
||||
---
|
||||
title: "firestore-delete-documents"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
A "firestore-delete-documents" tool deletes multiple documents from Firestore by their paths.
|
||||
aliases:
|
||||
- /resources/tools/firestore-delete-documents
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
A `firestore-delete-documents` tool deletes multiple documents from Firestore by their paths.
|
||||
It's compatible with the following sources:
|
||||
|
||||
- [firestore](../sources/firestore.md)
|
||||
|
||||
`firestore-delete-documents` takes one input parameter `documentPaths` which is an array of
|
||||
document paths to delete. The tool uses Firestore's BulkWriter for efficient batch deletion
|
||||
and returns the success status for each document.
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
delete_user_documents:
|
||||
kind: firestore-delete-documents
|
||||
source: my-firestore-source
|
||||
description: Use this tool to delete multiple documents from Firestore.
|
||||
```
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|-------------|:------------------------------------------:|:------------:|--------------------------------------------------------------------------------------------------|
|
||||
| kind | string | true | Must be "firestore-delete-documents". |
|
||||
| source | string | true | Name of the Firestore source to delete documents from. |
|
||||
| description | string | true | Description of the tool that is passed to the LLM. |
|
||||
38
docs/en/resources/tools/firestore/firestore-get-documents.md
Normal file
38
docs/en/resources/tools/firestore/firestore-get-documents.md
Normal file
@@ -0,0 +1,38 @@
|
||||
---
|
||||
title: "firestore-get-documents"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
A "firestore-get-documents" tool retrieves multiple documents from Firestore by their paths.
|
||||
aliases:
|
||||
- /resources/tools/firestore-get-documents
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
A `firestore-get-documents` tool retrieves multiple documents from Firestore by their paths.
|
||||
It's compatible with the following sources:
|
||||
|
||||
- [firestore](../sources/firestore.md)
|
||||
|
||||
`firestore-get-documents` takes one input parameter `documentPaths` which is an array of
|
||||
document paths, and returns the documents' data along with metadata such as existence status,
|
||||
creation time, update time, and read time.
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
get_user_documents:
|
||||
kind: firestore-get-documents
|
||||
source: my-firestore-source
|
||||
description: Use this tool to retrieve multiple documents from Firestore.
|
||||
```
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|-------------|:------------------------------------------:|:------------:|--------------------------------------------------------------------------------------------------|
|
||||
| kind | string | true | Must be "firestore-get-documents". |
|
||||
| source | string | true | Name of the Firestore source to retrieve documents from. |
|
||||
| description | string | true | Description of the tool that is passed to the LLM. |
|
||||
37
docs/en/resources/tools/firestore/firestore-get-rules.md
Normal file
37
docs/en/resources/tools/firestore/firestore-get-rules.md
Normal file
@@ -0,0 +1,37 @@
|
||||
---
|
||||
title: "firestore-get-rules"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
A "firestore-get-rules" tool retrieves the active Firestore security rules for the current project.
|
||||
aliases:
|
||||
- /resources/tools/firestore-get-rules
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
A `firestore-get-rules` tool retrieves the active [Firestore security rules](https://firebase.google.com/docs/firestore/security/get-started) for the current project.
|
||||
It's compatible with the following sources:
|
||||
|
||||
- [firestore](../sources/firestore.md)
|
||||
|
||||
`firestore-get-rules` takes no input parameters and returns the security rules content along with metadata
|
||||
such as the ruleset name, and timestamps.
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
get_firestore_rules:
|
||||
kind: firestore-get-rules
|
||||
source: my-firestore-source
|
||||
description: Use this tool to retrieve the active Firestore security rules.
|
||||
```
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|-------------|:------------------------------------------:|:------------:|--------------------------------------------------------------------------------------------------|
|
||||
| kind | string | true | Must be "firestore-get-rules". |
|
||||
| source | string | true | Name of the Firestore source to retrieve rules from. |
|
||||
| description | string | true | Description of the tool that is passed to the LLM. |
|
||||
@@ -0,0 +1,38 @@
|
||||
---
|
||||
title: "firestore-list-collections"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
A "firestore-list-collections" tool lists collections in Firestore, either at the root level or as subcollections of a document.
|
||||
aliases:
|
||||
- /resources/tools/firestore-list-collections
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
A `firestore-list-collections` tool lists [collections](https://firebase.google.com/docs/firestore/data-model#collections) in Firestore, either at the root level or as [subcollections](https://firebase.google.com/docs/firestore/data-model#subcollections) of a specific document.
|
||||
It's compatible with the following sources:
|
||||
|
||||
- [firestore](../sources/firestore.md)
|
||||
|
||||
`firestore-list-collections` takes an optional `parentPath` parameter to specify a document
|
||||
path. If provided, it lists all subcollections of that document. If not provided, it lists
|
||||
all root-level collections in the database.
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
list_firestore_collections:
|
||||
kind: firestore-list-collections
|
||||
source: my-firestore-source
|
||||
description: Use this tool to list collections in Firestore.
|
||||
```
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|-------------|:------------------------------------------:|:------------:|--------------------------------------------------------------------------------------------------|
|
||||
| kind | string | true | Must be "firestore-list-collections". |
|
||||
| source | string | true | Name of the Firestore source to list collections from. |
|
||||
| description | string | true | Description of the tool that is passed to the LLM. |
|
||||
198
docs/en/resources/tools/firestore/firestore-query-collection.md
Normal file
198
docs/en/resources/tools/firestore/firestore-query-collection.md
Normal file
@@ -0,0 +1,198 @@
|
||||
# firestore-query-collection
|
||||
|
||||
The `firestore-query-collection` tool allows you to query Firestore collections with filters, ordering, and limit capabilities.
|
||||
|
||||
## Configuration
|
||||
|
||||
To use this tool, you need to configure it in your YAML configuration file:
|
||||
|
||||
```yaml
|
||||
sources:
|
||||
my-firestore:
|
||||
kind: firestore
|
||||
config:
|
||||
project: my-gcp-project
|
||||
database: "(default)"
|
||||
|
||||
tools:
|
||||
query_collection:
|
||||
kind: firestore-query-collection
|
||||
source: my-firestore
|
||||
description: Query Firestore collections with advanced filtering
|
||||
```
|
||||
|
||||
## Parameters
|
||||
|
||||
| Parameter | Type | Required | Default | Description |
|
||||
|-----------|------|----------|---------|-------------|
|
||||
| `collectionPath` | string | Yes | - | The path to the Firestore collection to query |
|
||||
| `filters` | array | No | - | Array of filter objects (as JSON strings) to apply to the query |
|
||||
| `orderBy` | string | No | - | JSON string specifying field and direction to order results |
|
||||
| `limit` | integer | No | 100 | Maximum number of documents to return |
|
||||
| `analyzeQuery` | boolean | No | false | If true, returns query explain metrics including execution statistics |
|
||||
|
||||
### Filter Format
|
||||
|
||||
Each filter in the `filters` array should be a JSON string with the following structure:
|
||||
|
||||
```json
|
||||
{
|
||||
"field": "fieldName",
|
||||
"op": "operator",
|
||||
"value": "compareValue"
|
||||
}
|
||||
```
|
||||
|
||||
Supported operators:
|
||||
- `<` - Less than
|
||||
- `<=` - Less than or equal to
|
||||
- `>` - Greater than
|
||||
- `>=` - Greater than or equal to
|
||||
- `==` - Equal to
|
||||
- `!=` - Not equal to
|
||||
- `array-contains` - Array contains a specific value
|
||||
- `array-contains-any` - Array contains any of the specified values
|
||||
- `in` - Field value is in the specified array
|
||||
- `not-in` - Field value is not in the specified array
|
||||
|
||||
Value types supported:
|
||||
- String: `"value": "text"`
|
||||
- Number: `"value": 123` or `"value": 45.67`
|
||||
- Boolean: `"value": true` or `"value": false`
|
||||
- Array: `"value": ["item1", "item2"]` (for `in`, `not-in`, `array-contains-any` operators)
|
||||
|
||||
### OrderBy Format
|
||||
|
||||
The `orderBy` parameter should be a JSON string with the following structure:
|
||||
|
||||
```json
|
||||
{
|
||||
"field": "fieldName",
|
||||
"direction": "ASCENDING"
|
||||
}
|
||||
```
|
||||
|
||||
Direction values:
|
||||
- `ASCENDING`
|
||||
- `DESCENDING`
|
||||
|
||||
## Example Usage
|
||||
|
||||
### Query with filters
|
||||
|
||||
```json
|
||||
{
|
||||
"collectionPath": "users",
|
||||
"filters": [
|
||||
"{\"field\": \"age\", \"op\": \">\", \"value\": 18}",
|
||||
"{\"field\": \"status\", \"op\": \"==\", \"value\": \"active\"}"
|
||||
],
|
||||
"orderBy": "{\"field\": \"createdAt\", \"direction\": \"DESCENDING\"}",
|
||||
"limit": 50
|
||||
}
|
||||
```
|
||||
|
||||
### Query with array contains filter
|
||||
|
||||
```json
|
||||
{
|
||||
"collectionPath": "products",
|
||||
"filters": [
|
||||
"{\"field\": \"categories\", \"op\": \"array-contains\", \"value\": \"electronics\"}",
|
||||
"{\"field\": \"price\", \"op\": \"<\", \"value\": 1000}"
|
||||
],
|
||||
"orderBy": "{\"field\": \"price\", \"direction\": \"ASCENDING\"}",
|
||||
"limit": 20
|
||||
}
|
||||
```
|
||||
|
||||
### Query with IN operator
|
||||
|
||||
```json
|
||||
{
|
||||
"collectionPath": "orders",
|
||||
"filters": [
|
||||
"{\"field\": \"status\", \"op\": \"in\", \"value\": [\"pending\", \"processing\"]}"
|
||||
],
|
||||
"limit": 100
|
||||
}
|
||||
```
|
||||
|
||||
### Query with explain metrics
|
||||
|
||||
```json
|
||||
{
|
||||
"collectionPath": "users",
|
||||
"filters": [
|
||||
"{\"field\": \"age\", \"op\": \">=\", \"value\": 21}",
|
||||
"{\"field\": \"active\", \"op\": \"==\", \"value\": true}"
|
||||
],
|
||||
"orderBy": "{\"field\": \"lastLogin\", \"direction\": \"DESCENDING\"}",
|
||||
"limit": 25,
|
||||
"analyzeQuery": true
|
||||
}
|
||||
```
|
||||
|
||||
## Response Format
|
||||
|
||||
### Standard Response (analyzeQuery = false)
|
||||
|
||||
The tool returns an array of documents, where each document includes:
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "documentId",
|
||||
"path": "collection/documentId",
|
||||
"data": {
|
||||
// Document fields
|
||||
},
|
||||
"createTime": "2025-01-07T12:00:00Z",
|
||||
"updateTime": "2025-01-07T12:00:00Z",
|
||||
"readTime": "2025-01-07T12:00:00Z"
|
||||
}
|
||||
```
|
||||
|
||||
### Response with Query Analysis (analyzeQuery = true)
|
||||
|
||||
When `analyzeQuery` is set to true, the tool returns a single object containing documents and explain metrics:
|
||||
|
||||
```json
|
||||
{
|
||||
"documents": [
|
||||
// Array of document objects as shown above
|
||||
],
|
||||
"explainMetrics": {
|
||||
"planSummary": {
|
||||
"indexesUsed": [
|
||||
{
|
||||
"query_scope": "Collection",
|
||||
"properties": "(field ASC, __name__ ASC)"
|
||||
}
|
||||
]
|
||||
},
|
||||
"executionStats": {
|
||||
"resultsReturned": 50,
|
||||
"readOperations": 50,
|
||||
"executionDuration": "120ms",
|
||||
"debugStats": {
|
||||
"indexes_entries_scanned": "1000",
|
||||
"documents_scanned": "50",
|
||||
"billing_details": {
|
||||
"documents_billable": "50",
|
||||
"index_entries_billable": "1000",
|
||||
"min_query_cost": "0"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
The tool will return errors for:
|
||||
- Invalid collection path
|
||||
- Malformed filter JSON
|
||||
- Unsupported operators
|
||||
- Query execution failures
|
||||
- Invalid orderBy format
|
||||
115
docs/en/resources/tools/firestore/firestore-validate-rules.md
Normal file
115
docs/en/resources/tools/firestore/firestore-validate-rules.md
Normal file
@@ -0,0 +1,115 @@
|
||||
---
|
||||
title: firestore-validate-rules
|
||||
weight: 6
|
||||
date: 2025-01-07
|
||||
---
|
||||
|
||||
## Overview
|
||||
|
||||
The `firestore-validate-rules` tool validates Firestore security rules syntax and semantic correctness without deploying them. It provides detailed error reporting with source positions and code snippets.
|
||||
|
||||
## Configuration
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
firestore-validate-rules:
|
||||
kind: firestore-validate-rules
|
||||
source: <firestore-source-name>
|
||||
description: "Checks the provided Firestore Rules source for syntax and validation errors"
|
||||
```
|
||||
|
||||
## Authentication
|
||||
|
||||
This tool requires authentication if the source requires authentication.
|
||||
|
||||
## Parameters
|
||||
|
||||
| Parameter | Type | Required | Description |
|
||||
|-----------|--------|----------|-------------|
|
||||
| source | string | Yes | The Firestore Rules source code to validate |
|
||||
|
||||
## Response
|
||||
|
||||
The tool returns a `ValidationResult` object containing:
|
||||
|
||||
```json
|
||||
{
|
||||
"valid": boolean, // Whether the rules are valid
|
||||
"issueCount": number, // Number of issues found
|
||||
"formattedIssues": string, // Human-readable formatted issues
|
||||
"rawIssues": [ // Array of raw issue objects
|
||||
{
|
||||
"sourcePosition": {
|
||||
"fileName": string,
|
||||
"line": number,
|
||||
"column": number,
|
||||
"currentOffset": number,
|
||||
"endOffset": number
|
||||
},
|
||||
"description": string,
|
||||
"severity": string // e.g., "ERROR", "WARNING"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
## Example Usage
|
||||
|
||||
### Validate simple rules
|
||||
|
||||
```json
|
||||
{
|
||||
"source": "rules_version = '2';\nservice cloud.firestore {\n match /databases/{database}/documents {\n match /{document=**} {\n allow read, write: if true;\n }\n }\n}"
|
||||
}
|
||||
```
|
||||
|
||||
### Example response for valid rules
|
||||
|
||||
```json
|
||||
{
|
||||
"valid": true,
|
||||
"issueCount": 0,
|
||||
"formattedIssues": "✓ No errors detected. Rules are valid."
|
||||
}
|
||||
```
|
||||
|
||||
### Example response with errors
|
||||
|
||||
```json
|
||||
{
|
||||
"valid": false,
|
||||
"issueCount": 1,
|
||||
"formattedIssues": "Found 1 issue(s) in rules source:\n\nERROR: Unexpected token ';' [Ln 4, Col 32]\n```\n allow read, write: if true;;\n ^\n```",
|
||||
"rawIssues": [
|
||||
{
|
||||
"sourcePosition": {
|
||||
"line": 4,
|
||||
"column": 32,
|
||||
"currentOffset": 105,
|
||||
"endOffset": 106
|
||||
},
|
||||
"description": "Unexpected token ';'",
|
||||
"severity": "ERROR"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
The tool will return errors for:
|
||||
- Missing or empty `source` parameter
|
||||
- API errors when calling the Firebase Rules service
|
||||
- Network connectivity issues
|
||||
|
||||
## Use Cases
|
||||
|
||||
1. **Pre-deployment validation**: Validate rules before deploying to production
|
||||
2. **CI/CD integration**: Integrate rules validation into your build pipeline
|
||||
3. **Development workflow**: Quickly check rules syntax while developing
|
||||
4. **Error debugging**: Get detailed error locations with code snippets
|
||||
|
||||
## Related Tools
|
||||
|
||||
- [firestore-get-rules]({{< ref "firestore-get-rules" >}}): Retrieve current active rules
|
||||
- [firestore-query-collection]({{< ref "firestore-query-collection" >}}): Test rules by querying collections
|
||||
5
go.mod
5
go.mod
@@ -9,10 +9,11 @@ require (
|
||||
cloud.google.com/go/bigquery v1.69.0
|
||||
cloud.google.com/go/bigtable v1.38.0
|
||||
cloud.google.com/go/cloudsqlconn v1.17.3
|
||||
cloud.google.com/go/firestore v1.18.0
|
||||
cloud.google.com/go/spanner v1.83.0
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.53.0
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/trace v1.29.0
|
||||
github.com/couchbase/gocb/v2 v2.10.0
|
||||
github.com/couchbase/gocb/v2 v2.10.1
|
||||
github.com/couchbase/tools-common/http v1.0.9
|
||||
github.com/fsnotify/fsnotify v1.9.0
|
||||
github.com/go-chi/chi/v5 v5.2.2
|
||||
@@ -65,7 +66,7 @@ require (
|
||||
github.com/cenkalti/backoff/v5 v5.0.2 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/cncf/xds/go v0.0.0-20250326154945-ae57f3c0d45f // indirect
|
||||
github.com/couchbase/gocbcore/v10 v10.7.0 // indirect
|
||||
github.com/couchbase/gocbcore/v10 v10.7.1 // indirect
|
||||
github.com/couchbase/gocbcoreps v0.1.3 // indirect
|
||||
github.com/couchbase/goprotostellar v1.0.2 // indirect
|
||||
github.com/couchbase/tools-common/errors v1.0.0 // indirect
|
||||
|
||||
10
go.sum
10
go.sum
@@ -293,6 +293,8 @@ cloud.google.com/go/filestore v1.4.0/go.mod h1:PaG5oDfo9r224f8OYXURtAsY+Fbyq/bLY
|
||||
cloud.google.com/go/filestore v1.5.0/go.mod h1:FqBXDWBp4YLHqRnVGveOkHDf8svj9r5+mUDLupOWEDs=
|
||||
cloud.google.com/go/filestore v1.6.0/go.mod h1:di5unNuss/qfZTw2U9nhFqo8/ZDSc466dre85Kydllg=
|
||||
cloud.google.com/go/firestore v1.9.0/go.mod h1:HMkjKHNTtRyZNiMzu7YAsLr9K3X2udY2AMwDaMEQiiE=
|
||||
cloud.google.com/go/firestore v1.18.0 h1:cuydCaLS7Vl2SatAeivXyhbhDEIR8BDmtn4egDhIn2s=
|
||||
cloud.google.com/go/firestore v1.18.0/go.mod h1:5ye0v48PhseZBdcl0qbl3uttu7FIEwEYVaWm0UIEOEU=
|
||||
cloud.google.com/go/functions v1.6.0/go.mod h1:3H1UA3qiIPRWD7PeZKLvHZ9SaQhR26XIJcC0A5GbvAk=
|
||||
cloud.google.com/go/functions v1.7.0/go.mod h1:+d+QBcWM+RsrgZfV9xo6KfA1GlzJfxcfZcRPEhDDfzg=
|
||||
cloud.google.com/go/functions v1.8.0/go.mod h1:RTZ4/HsQjIqIYP9a9YPbU+QFoQsAlYgrwOXJWHn1POY=
|
||||
@@ -710,10 +712,10 @@ github.com/cncf/xds/go v0.0.0-20230105202645-06c439db220b/go.mod h1:eXthEFrGJvWH
|
||||
github.com/cncf/xds/go v0.0.0-20230607035331-e9ce68804cb4/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs=
|
||||
github.com/cncf/xds/go v0.0.0-20250326154945-ae57f3c0d45f h1:C5bqEmzEPLsHm9Mv73lSE9e9bKV23aB1vxOsmZrkl3k=
|
||||
github.com/cncf/xds/go v0.0.0-20250326154945-ae57f3c0d45f/go.mod h1:W+zGtBO5Y1IgJhy4+A9GOqVhqLpfZi+vwmdNXUehLA8=
|
||||
github.com/couchbase/gocb/v2 v2.10.0 h1:NNxZ4okToU1Ylqp6F8tE41CEJQPhb2WjufryAkeubOk=
|
||||
github.com/couchbase/gocb/v2 v2.10.0/go.mod h1:OSbMfQkP7ltbKiDZhsT2mGDhkQNmvGXxptKcxAUJQ2Y=
|
||||
github.com/couchbase/gocbcore/v10 v10.7.0 h1:lAEi0PNeEGKOu8pWrPUdtLOT2oGr1J/UTdGHVPC3r/0=
|
||||
github.com/couchbase/gocbcore/v10 v10.7.0/go.mod h1:Q8JWVenMCEOuRgrDQKApHbzzPif38HzefGgRVe9apAI=
|
||||
github.com/couchbase/gocb/v2 v2.10.1 h1:5r1jngGxw3dTZdtq6Xmjq3pdU6hOwRvynvbVIp58T64=
|
||||
github.com/couchbase/gocb/v2 v2.10.1/go.mod h1:GGEJuYjrfnPHCQLcxTcIco+Puy63PS2p8QQd8FRw66I=
|
||||
github.com/couchbase/gocbcore/v10 v10.7.1 h1:6jsNDtqyfoQ8Xg6kv99rzccc3CrHbp7FjeY+ahWXTF4=
|
||||
github.com/couchbase/gocbcore/v10 v10.7.1/go.mod h1:Q8JWVenMCEOuRgrDQKApHbzzPif38HzefGgRVe9apAI=
|
||||
github.com/couchbase/gocbcoreps v0.1.3 h1:fILaKGCjxFIeCgAUG8FGmRDSpdrRggohOMKEgO9CUpg=
|
||||
github.com/couchbase/gocbcoreps v0.1.3/go.mod h1:hBFpDNPnRno6HH5cRXExhqXYRmTsFJlFHQx7vztcXPk=
|
||||
github.com/couchbase/goprotostellar v1.0.2 h1:yoPbAL9sCtcyZ5e/DcU5PRMOEFaJrF9awXYu3VPfGls=
|
||||
|
||||
@@ -28,6 +28,7 @@ func TestLoadPrebuiltToolYAMLs(t *testing.T) {
|
||||
"cloud-sql-mssql",
|
||||
"cloud-sql-mysql",
|
||||
"cloud-sql-postgres",
|
||||
"firestore",
|
||||
"postgres",
|
||||
"spanner-postgres",
|
||||
"spanner",
|
||||
@@ -68,6 +69,7 @@ func TestGetPrebuiltTool(t *testing.T) {
|
||||
cloudsqlpg_config, _ := Get("cloud-sql-postgres")
|
||||
cloudsqlmysql_config, _ := Get("cloud-sql-mysql")
|
||||
cloudsqlmssql_config, _ := Get("cloud-sql-mssql")
|
||||
firestoreconfig, _ := Get("firestore")
|
||||
postgresconfig, _ := Get("postgres")
|
||||
spanner_config, _ := Get("spanner")
|
||||
spannerpg_config, _ := Get("spanner-postgres")
|
||||
@@ -86,6 +88,9 @@ func TestGetPrebuiltTool(t *testing.T) {
|
||||
if len(cloudsqlmssql_config) <= 0 {
|
||||
t.Fatalf("unexpected error: could not fetch cloud sql mssql prebuilt tools yaml")
|
||||
}
|
||||
if len(firestoreconfig) <= 0 {
|
||||
t.Fatalf("unexpected error: could not fetch firestore prebuilt tools yaml")
|
||||
}
|
||||
if len(postgresconfig) <= 0 {
|
||||
t.Fatalf("unexpected error: could not fetch postgres prebuilt tools yaml")
|
||||
}
|
||||
|
||||
@@ -154,9 +154,10 @@ tools:
|
||||
) USING utf8mb4) AS object_details
|
||||
FROM
|
||||
INFORMATION_SCHEMA.TABLES T
|
||||
CROSS JOIN (SELECT @table_names := ?) AS variables
|
||||
WHERE
|
||||
T.TABLE_SCHEMA NOT IN ('mysql', 'information_schema', 'performance_schema', 'sys')
|
||||
AND (NULLIF(TRIM(?), '') IS NULL OR FIND_IN_SET(T.TABLE_NAME, ?))
|
||||
AND (NULLIF(TRIM(@table_names), '') IS NULL OR FIND_IN_SET(T.TABLE_NAME, @table_names))
|
||||
AND T.TABLE_TYPE = 'BASE TABLE'
|
||||
ORDER BY
|
||||
T.TABLE_SCHEMA, T.TABLE_NAME;
|
||||
@@ -164,6 +165,7 @@ tools:
|
||||
- name: table_names
|
||||
type: string
|
||||
description: "Optional: A comma-separated list of table names. If empty, details for all tables in user-accessible schemas will be listed."
|
||||
default: ""
|
||||
toolsets:
|
||||
cloud-sql-mysql-database-tools:
|
||||
- execute_sql
|
||||
|
||||
42
internal/prebuiltconfigs/tools/firestore.yaml
Normal file
42
internal/prebuiltconfigs/tools/firestore.yaml
Normal file
@@ -0,0 +1,42 @@
|
||||
sources:
|
||||
firestore-source:
|
||||
kind: firestore
|
||||
project: ${FIRESTORE_PROJECT}
|
||||
database: ${FIRESTORE_DATABASE} # Optional, defaults to "(default)" if not specified
|
||||
|
||||
tools:
|
||||
firestore-get-documents:
|
||||
kind: firestore-get-documents
|
||||
source: firestore-source
|
||||
description: Gets multiple documents from Firestore by their paths
|
||||
firestore-list-collections:
|
||||
kind: firestore-list-collections
|
||||
source: firestore-source
|
||||
description: List Firestore collections for a given parent path
|
||||
firestore-delete-documents:
|
||||
kind: firestore-delete-documents
|
||||
source: firestore-source
|
||||
description: Delete multiple documents from Firestore
|
||||
firestore-query-collection:
|
||||
kind: firestore-query-collection
|
||||
source: firestore-source
|
||||
description: |
|
||||
Retrieves one or more Firestore documents from a collection in a database in the current project by a collection with a full document path.
|
||||
Use this if you know the exact path of a collection and the filtering clause you would like for the document.
|
||||
firestore-get-rules:
|
||||
kind: firestore-get-rules
|
||||
source: firestore-source
|
||||
description: Retrieves the active Firestore security rules for the current project
|
||||
firestore-validate-rules:
|
||||
kind: firestore-validate-rules
|
||||
source: firestore-source
|
||||
description: Checks the provided Firestore Rules source for syntax and validation errors. Provide the source code to validate.
|
||||
|
||||
toolsets:
|
||||
firestore-database-tools:
|
||||
- firestore-get-documents
|
||||
- firestore-list-collections
|
||||
- firestore-delete-documents
|
||||
- firestore-query-collection
|
||||
- firestore-get-rules
|
||||
- firestore-validate-rules
|
||||
@@ -42,7 +42,7 @@ type MockTool struct {
|
||||
manifest tools.Manifest
|
||||
}
|
||||
|
||||
func (t MockTool) Invoke(context.Context, tools.ParamValues) ([]any, error) {
|
||||
func (t MockTool) Invoke(context.Context, tools.ParamValues) (any, error) {
|
||||
mock := []any{t.Name}
|
||||
return mock, nil
|
||||
}
|
||||
|
||||
@@ -122,7 +122,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, tools map[strin
|
||||
}
|
||||
|
||||
content := make([]TextContent, 0)
|
||||
for _, d := range results {
|
||||
|
||||
sliceRes, ok := results.([]any)
|
||||
if !ok {
|
||||
sliceRes = []any{results}
|
||||
}
|
||||
|
||||
for _, d := range sliceRes {
|
||||
text := TextContent{Type: "text"}
|
||||
dM, err := json.Marshal(d)
|
||||
if err != nil {
|
||||
|
||||
@@ -122,7 +122,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, tools map[strin
|
||||
}
|
||||
|
||||
content := make([]TextContent, 0)
|
||||
for _, d := range results {
|
||||
|
||||
sliceRes, ok := results.([]any)
|
||||
if !ok {
|
||||
sliceRes = []any{results}
|
||||
}
|
||||
|
||||
for _, d := range sliceRes {
|
||||
text := TextContent{Type: "text"}
|
||||
dM, err := json.Marshal(d)
|
||||
if err != nil {
|
||||
|
||||
@@ -122,7 +122,13 @@ func toolsCallHandler(ctx context.Context, id jsonrpc.RequestId, tools map[strin
|
||||
}
|
||||
|
||||
content := make([]TextContent, 0)
|
||||
for _, d := range results {
|
||||
|
||||
sliceRes, ok := results.([]any)
|
||||
if !ok {
|
||||
sliceRes = []any{results}
|
||||
}
|
||||
|
||||
for _, d := range sliceRes {
|
||||
text := TextContent{Type: "text"}
|
||||
dM, err := json.Marshal(d)
|
||||
if err != nil {
|
||||
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/oauth2/google"
|
||||
bigqueryrestapi "google.golang.org/api/bigquery/v2"
|
||||
"google.golang.org/api/option"
|
||||
)
|
||||
|
||||
@@ -61,15 +62,17 @@ func (r Config) SourceConfigKind() string {
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
// Initializes a BigQuery Google SQL source
|
||||
client, err := initBigQueryConnection(ctx, tracer, r.Name, r.Project, r.Location)
|
||||
client, restService, err := initBigQueryConnection(ctx, tracer, r.Name, r.Project, r.Location)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s := &Source{
|
||||
Name: r.Name,
|
||||
Kind: SourceKind,
|
||||
Client: client,
|
||||
Location: r.Location,
|
||||
Name: r.Name,
|
||||
Kind: SourceKind,
|
||||
Client: client,
|
||||
RestService: restService,
|
||||
Location: r.Location,
|
||||
}
|
||||
return s, nil
|
||||
|
||||
@@ -79,10 +82,11 @@ var _ sources.Source = &Source{}
|
||||
|
||||
type Source struct {
|
||||
// BigQuery Google SQL struct with client
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Client *bigqueryapi.Client
|
||||
Location string `yaml:"location"`
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Client *bigqueryapi.Client
|
||||
RestService *bigqueryrestapi.Service
|
||||
Location string `yaml:"location"`
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
@@ -94,30 +98,42 @@ func (s *Source) BigQueryClient() *bigqueryapi.Client {
|
||||
return s.Client
|
||||
}
|
||||
|
||||
func (s *Source) BigQueryRestService() *bigqueryrestapi.Service {
|
||||
return s.RestService
|
||||
}
|
||||
|
||||
func initBigQueryConnection(
|
||||
ctx context.Context,
|
||||
tracer trace.Tracer,
|
||||
name string,
|
||||
project string,
|
||||
location string,
|
||||
) (*bigqueryapi.Client, error) {
|
||||
) (*bigqueryapi.Client, *bigqueryrestapi.Service, error) {
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
defer span.End()
|
||||
|
||||
cred, err := google.FindDefaultCredentials(ctx, bigqueryapi.Scope)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to find default Google Cloud credentials with scope %q: %w", bigqueryapi.Scope, err)
|
||||
return nil, nil, fmt.Errorf("failed to find default Google Cloud credentials with scope %q: %w", bigqueryapi.Scope, err)
|
||||
}
|
||||
|
||||
userAgent, err := util.UserAgentFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Initialize the high-level BigQuery client
|
||||
client, err := bigqueryapi.NewClient(ctx, project, option.WithUserAgent(userAgent), option.WithCredentials(cred))
|
||||
client.Location = location
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create BigQuery client for project %q: %w", project, err)
|
||||
return nil, nil, fmt.Errorf("failed to create BigQuery client for project %q: %w", project, err)
|
||||
}
|
||||
return client, nil
|
||||
client.Location = location
|
||||
|
||||
// Initialize the low-level BigQuery REST service using the same credentials
|
||||
restService, err := bigqueryrestapi.NewService(ctx, option.WithUserAgent(userAgent), option.WithCredentials(cred))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to create BigQuery v2 service: %w", err)
|
||||
}
|
||||
|
||||
return client, restService, nil
|
||||
}
|
||||
|
||||
153
internal/sources/firestore/firestore.go
Normal file
153
internal/sources/firestore/firestore.go
Normal file
@@ -0,0 +1,153 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package firestore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"cloud.google.com/go/firestore"
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"google.golang.org/api/firebaserules/v1"
|
||||
"google.golang.org/api/option"
|
||||
)
|
||||
|
||||
const SourceKind string = "firestore"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources.SourceConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
// Firestore configs
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Project string `yaml:"project" validate:"required"`
|
||||
Database string `yaml:"database"` // Optional, defaults to "(default)"
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
// Returns Firestore source kind
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
// Initializes a Firestore source
|
||||
client, err := initFirestoreConnection(ctx, tracer, r.Name, r.Project, r.Database)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Initialize Firebase Rules client
|
||||
rulesClient, err := initFirebaseRulesConnection(ctx, r.Project)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize Firebase Rules client: %w", err)
|
||||
}
|
||||
|
||||
s := &Source{
|
||||
Name: r.Name,
|
||||
Kind: SourceKind,
|
||||
Client: client,
|
||||
RulesClient: rulesClient,
|
||||
ProjectId: r.Project,
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
var _ sources.Source = &Source{}
|
||||
|
||||
type Source struct {
|
||||
// Firestore struct with client
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Client *firestore.Client
|
||||
RulesClient *firebaserules.Service
|
||||
ProjectId string `yaml:"projectId"`
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
// Returns Firestore source kind
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) FirestoreClient() *firestore.Client {
|
||||
return s.Client
|
||||
}
|
||||
|
||||
func (s *Source) FirebaseRulesClient() *firebaserules.Service {
|
||||
return s.RulesClient
|
||||
}
|
||||
|
||||
func (s *Source) GetProjectId() string {
|
||||
return s.ProjectId
|
||||
}
|
||||
|
||||
func initFirestoreConnection(
|
||||
ctx context.Context,
|
||||
tracer trace.Tracer,
|
||||
name string,
|
||||
project string,
|
||||
database string,
|
||||
) (*firestore.Client, error) {
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
defer span.End()
|
||||
|
||||
userAgent, err := util.UserAgentFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If database is not specified, use the default database
|
||||
if database == "" {
|
||||
database = "(default)"
|
||||
}
|
||||
|
||||
// Create the Firestore client
|
||||
client, err := firestore.NewClientWithDatabase(ctx, project, database, option.WithUserAgent(userAgent))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create Firestore client for project %q and database %q: %w", project, database, err)
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func initFirebaseRulesConnection(
|
||||
ctx context.Context,
|
||||
project string,
|
||||
) (*firebaserules.Service, error) {
|
||||
// Create the Firebase Rules client
|
||||
rulesClient, err := firebaserules.NewService(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create Firebase Rules client for project %q: %w", project, err)
|
||||
}
|
||||
|
||||
return rulesClient, nil
|
||||
}
|
||||
130
internal/sources/firestore/firestore_test.go
Normal file
130
internal/sources/firestore/firestore_test.go
Normal file
@@ -0,0 +1,130 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package firestore_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/firestore"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
|
||||
func TestParseFromYamlFirestore(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.SourceConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example with default database",
|
||||
in: `
|
||||
sources:
|
||||
my-firestore:
|
||||
kind: firestore
|
||||
project: my-project
|
||||
`,
|
||||
want: server.SourceConfigs{
|
||||
"my-firestore": firestore.Config{
|
||||
Name: "my-firestore",
|
||||
Kind: firestore.SourceKind,
|
||||
Project: "my-project",
|
||||
Database: "",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "with custom database",
|
||||
in: `
|
||||
sources:
|
||||
my-firestore:
|
||||
kind: firestore
|
||||
project: my-project
|
||||
database: my-database
|
||||
`,
|
||||
want: server.SourceConfigs{
|
||||
"my-firestore": firestore.Config{
|
||||
Name: "my-firestore",
|
||||
Kind: firestore.SourceKind,
|
||||
Project: "my-project",
|
||||
Database: "my-database",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailParseFromYamlFirestore(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
err string
|
||||
}{
|
||||
{
|
||||
desc: "extra field",
|
||||
in: `
|
||||
sources:
|
||||
my-firestore:
|
||||
kind: firestore
|
||||
project: my-project
|
||||
foo: bar
|
||||
`,
|
||||
err: "unable to parse source \"my-firestore\" as \"firestore\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | kind: firestore\n 3 | project: my-project",
|
||||
},
|
||||
{
|
||||
desc: "missing required field",
|
||||
in: `
|
||||
sources:
|
||||
my-firestore:
|
||||
kind: firestore
|
||||
database: my-database
|
||||
`,
|
||||
err: "unable to parse source \"my-firestore\" as \"firestore\": Key: 'Config.Project' Error:Field validation for 'Project' failed on the 'required' tag",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
errStr := err.Error()
|
||||
if errStr != tc.err {
|
||||
t.Fatalf("unexpected error: got %q, want %q", errStr, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -156,7 +156,7 @@ type Tool struct {
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
sliceParams := params.AsSlice()
|
||||
allParamValues := make([]any, len(sliceParams)+1)
|
||||
allParamValues[0] = fmt.Sprintf("%s", sliceParams[0]) // nl_question
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
bigqueryrestapi "google.golang.org/api/bigquery/v2"
|
||||
"google.golang.org/api/iterator"
|
||||
)
|
||||
|
||||
@@ -44,6 +45,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
|
||||
type compatibleSource interface {
|
||||
BigQueryClient() *bigqueryapi.Client
|
||||
BigQueryRestService() *bigqueryrestapi.Service
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
@@ -95,6 +97,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
Parameters: parameters,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Client: s.BigQueryClient(),
|
||||
RestService: s.BigQueryRestService(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
@@ -110,29 +113,58 @@ type Tool struct {
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
Client *bigqueryapi.Client
|
||||
RestService *bigqueryrestapi.Service
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
sliceParams := params.AsSlice()
|
||||
sql, ok := sliceParams[0].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unable to get cast %s", sliceParams[0])
|
||||
}
|
||||
|
||||
dryRunJob, err := dryRunQuery(ctx, t.RestService, t.Client.Project(), t.Client.Location, sql)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query validation failed during dry run: %w", err)
|
||||
}
|
||||
|
||||
statementType := dryRunJob.Statistics.Query.StatementType
|
||||
// JobStatistics.QueryStatistics.StatementType
|
||||
query := t.Client.Query(sql)
|
||||
query.Location = t.Client.Location
|
||||
|
||||
// This block handles Data Manipulation Language (DML) and Data Definition Language (DDL) statements.
|
||||
// These statements (e.g., INSERT, UPDATE, CREATE TABLE) do not return a row set.
|
||||
// Instead, we execute them as a job, wait for completion, and return a success
|
||||
// message, including the number of affected rows for DML operations.
|
||||
if statementType != "SELECT" {
|
||||
job, err := query.Run(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to start DML/DDL job: %w", err)
|
||||
}
|
||||
status, err := job.Wait(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to wait for DML/DDL job to complete: %w", err)
|
||||
}
|
||||
if err := status.Err(); err != nil {
|
||||
return nil, fmt.Errorf("DML/DDL job failed with error: %w", err)
|
||||
}
|
||||
return "Operation completed successfully.", nil
|
||||
}
|
||||
|
||||
// This block handles SELECT statements, which return a row set.
|
||||
// We iterate through the results, convert each row into a map of
|
||||
// column names to values, and return the collection of rows.
|
||||
var out []any
|
||||
it, err := query.Read(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
|
||||
var out []any
|
||||
for {
|
||||
var row map[string]bigqueryapi.Value
|
||||
err := it.Next(&row)
|
||||
err = it.Next(&row)
|
||||
if err == iterator.Done {
|
||||
break
|
||||
}
|
||||
@@ -145,7 +177,9 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, erro
|
||||
}
|
||||
out = append(out, vMap)
|
||||
}
|
||||
|
||||
if out == nil {
|
||||
return "The query returned 0 rows.", nil
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
@@ -164,3 +198,27 @@ func (t Tool) McpManifest() tools.McpManifest {
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
// dryRunQuery performs a dry run of the SQL query to validate it and get metadata.
|
||||
func dryRunQuery(ctx context.Context, restService *bigqueryrestapi.Service, projectID string, location string, sql string) (*bigqueryrestapi.Job, error) {
|
||||
useLegacySql := false
|
||||
jobToInsert := &bigqueryrestapi.Job{
|
||||
JobReference: &bigqueryrestapi.JobReference{
|
||||
ProjectId: projectID,
|
||||
Location: location,
|
||||
},
|
||||
Configuration: &bigqueryrestapi.JobConfiguration{
|
||||
DryRun: true,
|
||||
Query: &bigqueryrestapi.JobConfigurationQuery{
|
||||
Query: sql,
|
||||
UseLegacySql: &useLegacySql,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
insertResponse, err := restService.Jobs.Insert(projectID, jobToInsert).Context(ctx).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to insert dry run job: %w", err)
|
||||
}
|
||||
return insertResponse, nil
|
||||
}
|
||||
|
||||
@@ -118,7 +118,7 @@ type Tool struct {
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
mapParams := params.AsMap()
|
||||
projectId, ok := mapParams[projectKey].(string)
|
||||
if !ok {
|
||||
@@ -137,7 +137,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, erro
|
||||
return nil, fmt.Errorf("failed to get metadata for dataset %s (in project %s): %w", datasetId, t.Client.Project(), err)
|
||||
}
|
||||
|
||||
return []any{metadata}, nil
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
|
||||
@@ -120,7 +120,7 @@ type Tool struct {
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
mapParams := params.AsMap()
|
||||
projectId, ok := mapParams[projectKey].(string)
|
||||
if !ok {
|
||||
@@ -145,7 +145,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, erro
|
||||
return nil, fmt.Errorf("failed to get metadata for table %s.%s.%s: %w", projectId, datasetId, tableId, err)
|
||||
}
|
||||
|
||||
return []any{metadata}, nil
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
|
||||
@@ -118,7 +118,7 @@ type Tool struct {
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
mapParams := params.AsMap()
|
||||
projectId, ok := mapParams[projectKey].(string)
|
||||
if !ok {
|
||||
|
||||
@@ -119,7 +119,7 @@ type Tool struct {
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
mapParams := params.AsMap()
|
||||
projectId, ok := mapParams[projectKey].(string)
|
||||
if !ok {
|
||||
|
||||
@@ -124,7 +124,7 @@ type Tool struct {
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
namedArgs := make([]bigqueryapi.QueryParameter, 0, len(params))
|
||||
paramsMap := params.AsMap()
|
||||
newStatement, err := tools.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
|
||||
|
||||
@@ -161,7 +161,7 @@ func getMapParamsType(tparams tools.Parameters, params tools.ParamValues) (map[s
|
||||
return btParamTypes, nil
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
paramsMap := params.AsMap()
|
||||
newStatement, err := tools.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
|
||||
if err != nil {
|
||||
|
||||
@@ -125,7 +125,7 @@ type Tool struct {
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
namedParamsMap := params.AsMap()
|
||||
newStatement, err := tools.ResolveTemplateParams(t.TemplateParameters, t.Statement, namedParamsMap)
|
||||
if err != nil {
|
||||
|
||||
@@ -120,7 +120,7 @@ type Tool struct {
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
paramsMap := params.AsMapWithDollarPrefix()
|
||||
|
||||
resp, err := t.DgraphClient.ExecuteQuery(t.Statement, paramsMap, t.IsQuery, t.Timeout)
|
||||
@@ -132,7 +132,6 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, erro
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var out []any
|
||||
var result struct {
|
||||
Data map[string]interface{} `json:"data"`
|
||||
}
|
||||
@@ -140,9 +139,8 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, erro
|
||||
if err := json.Unmarshal(resp, &result); err != nil {
|
||||
return nil, fmt.Errorf("error parsing JSON: %v", err)
|
||||
}
|
||||
out = append(out, result.Data)
|
||||
|
||||
return out, nil
|
||||
return result.Data, nil
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claimsMap map[string]map[string]any) (tools.ParamValues, error) {
|
||||
|
||||
@@ -0,0 +1,194 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package firestoredeletedocuments
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
firestoreapi "cloud.google.com/go/firestore"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
)
|
||||
|
||||
const kind string = "firestore-delete-documents"
|
||||
const documentPathsKey string = "documentPaths"
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
FirestoreClient() *firestoreapi.Client
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &firestoreds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{firestoreds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
func (cfg Config) ToolConfigKind() string {
|
||||
return kind
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
documentPathsParameter := tools.NewArrayParameter(documentPathsKey, "Array of document paths to delete from Firestore.", tools.NewStringParameter("item", "Document path"))
|
||||
parameters := tools.Parameters{documentPathsParameter}
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
Description: cfg.Description,
|
||||
InputSchema: parameters.McpManifest(),
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Parameters: parameters,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Client: s.FirestoreClient(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
|
||||
Client *firestoreapi.Client
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
mapParams := params.AsMap()
|
||||
documentPathsRaw, ok := mapParams[documentPathsKey].([]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected an array", documentPathsKey)
|
||||
}
|
||||
|
||||
if len(documentPathsRaw) == 0 {
|
||||
return nil, fmt.Errorf("'%s' parameter cannot be empty", documentPathsKey)
|
||||
}
|
||||
|
||||
// Use ConvertAnySliceToTyped to convert the slice
|
||||
typedSlice, err := tools.ConvertAnySliceToTyped(documentPathsRaw, "string")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert document paths: %w", err)
|
||||
}
|
||||
|
||||
documentPaths, ok := typedSlice.([]string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected type conversion error for document paths")
|
||||
}
|
||||
|
||||
// Create a BulkWriter to handle multiple deletions efficiently
|
||||
bulkWriter := t.Client.BulkWriter(ctx)
|
||||
|
||||
// Keep track of jobs for each document
|
||||
jobs := make([]*firestoreapi.BulkWriterJob, len(documentPaths))
|
||||
|
||||
// Add all delete operations to the BulkWriter
|
||||
for i, path := range documentPaths {
|
||||
docRef := t.Client.Doc(path)
|
||||
job, err := bulkWriter.Delete(docRef)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to add delete operation for document %q: %w", path, err)
|
||||
}
|
||||
jobs[i] = job
|
||||
}
|
||||
|
||||
// End the BulkWriter to execute all operations
|
||||
bulkWriter.End()
|
||||
|
||||
// Collect results
|
||||
results := make([]any, len(documentPaths))
|
||||
for i, job := range jobs {
|
||||
docData := make(map[string]any)
|
||||
docData["path"] = documentPaths[i]
|
||||
|
||||
// Wait for the job to complete and get the result
|
||||
_, err := job.Results()
|
||||
if err != nil {
|
||||
docData["success"] = false
|
||||
docData["error"] = err.Error()
|
||||
} else {
|
||||
docData["success"] = true
|
||||
}
|
||||
|
||||
results[i] = docData
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
return tools.ParseParams(t.Parameters, data, claims)
|
||||
}
|
||||
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
@@ -0,0 +1,156 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package firestoredeletedocuments_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/firestore/firestoredeletedocuments"
|
||||
)
|
||||
|
||||
func TestParseFromYamlFirestoreDeleteDocuments(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
delete_docs_tool:
|
||||
kind: firestore-delete-documents
|
||||
source: my-firestore-instance
|
||||
description: Delete documents from Firestore by paths
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"delete_docs_tool": firestoredeletedocuments.Config{
|
||||
Name: "delete_docs_tool",
|
||||
Kind: "firestore-delete-documents",
|
||||
Source: "my-firestore-instance",
|
||||
Description: "Delete documents from Firestore by paths",
|
||||
AuthRequired: []string{},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "with auth requirements",
|
||||
in: `
|
||||
tools:
|
||||
secure_delete_docs:
|
||||
kind: firestore-delete-documents
|
||||
source: prod-firestore
|
||||
description: Delete documents with authentication
|
||||
authRequired:
|
||||
- google-auth-service
|
||||
- api-key-service
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"secure_delete_docs": firestoredeletedocuments.Config{
|
||||
Name: "secure_delete_docs",
|
||||
Kind: "firestore-delete-documents",
|
||||
Source: "prod-firestore",
|
||||
Description: "Delete documents with authentication",
|
||||
AuthRequired: []string{"google-auth-service", "api-key-service"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFromYamlMultipleTools(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
in := `
|
||||
tools:
|
||||
delete_user_docs:
|
||||
kind: firestore-delete-documents
|
||||
source: users-firestore
|
||||
description: Delete user documents
|
||||
authRequired:
|
||||
- user-auth
|
||||
delete_product_docs:
|
||||
kind: firestore-delete-documents
|
||||
source: products-firestore
|
||||
description: Delete product documents
|
||||
delete_order_docs:
|
||||
kind: firestore-delete-documents
|
||||
source: orders-firestore
|
||||
description: Delete order documents
|
||||
authRequired:
|
||||
- user-auth
|
||||
- admin-auth
|
||||
`
|
||||
want := server.ToolConfigs{
|
||||
"delete_user_docs": firestoredeletedocuments.Config{
|
||||
Name: "delete_user_docs",
|
||||
Kind: "firestore-delete-documents",
|
||||
Source: "users-firestore",
|
||||
Description: "Delete user documents",
|
||||
AuthRequired: []string{"user-auth"},
|
||||
},
|
||||
"delete_product_docs": firestoredeletedocuments.Config{
|
||||
Name: "delete_product_docs",
|
||||
Kind: "firestore-delete-documents",
|
||||
Source: "products-firestore",
|
||||
Description: "Delete product documents",
|
||||
AuthRequired: []string{},
|
||||
},
|
||||
"delete_order_docs": firestoredeletedocuments.Config{
|
||||
Name: "delete_order_docs",
|
||||
Kind: "firestore-delete-documents",
|
||||
Source: "orders-firestore",
|
||||
Description: "Delete order documents",
|
||||
AuthRequired: []string{"user-auth", "admin-auth"},
|
||||
},
|
||||
}
|
||||
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err = yaml.UnmarshalContext(ctx, testutils.FormatYaml(in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,186 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package firestoregetdocuments
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
firestoreapi "cloud.google.com/go/firestore"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
)
|
||||
|
||||
const kind string = "firestore-get-documents"
|
||||
const documentPathsKey string = "documentPaths"
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
FirestoreClient() *firestoreapi.Client
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &firestoreds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{firestoreds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
func (cfg Config) ToolConfigKind() string {
|
||||
return kind
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
documentPathsParameter := tools.NewArrayParameter(documentPathsKey, "Array of document paths to retrieve from Firestore.", tools.NewStringParameter("item", "Document path"))
|
||||
parameters := tools.Parameters{documentPathsParameter}
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
Description: cfg.Description,
|
||||
InputSchema: parameters.McpManifest(),
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Parameters: parameters,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Client: s.FirestoreClient(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
|
||||
Client *firestoreapi.Client
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
mapParams := params.AsMap()
|
||||
documentPathsRaw, ok := mapParams[documentPathsKey].([]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter; expected an array", documentPathsKey)
|
||||
}
|
||||
|
||||
if len(documentPathsRaw) == 0 {
|
||||
return nil, fmt.Errorf("'%s' parameter cannot be empty", documentPathsKey)
|
||||
}
|
||||
|
||||
// Use ConvertAnySliceToTyped to convert the slice
|
||||
typedSlice, err := tools.ConvertAnySliceToTyped(documentPathsRaw, "string")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to convert document paths: %w", err)
|
||||
}
|
||||
|
||||
documentPaths, ok := typedSlice.([]string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected type conversion error for document paths")
|
||||
}
|
||||
|
||||
// Create document references from paths
|
||||
docRefs := make([]*firestoreapi.DocumentRef, len(documentPaths))
|
||||
for i, path := range documentPaths {
|
||||
docRefs[i] = t.Client.Doc(path)
|
||||
}
|
||||
|
||||
// Get all documents
|
||||
snapshots, err := t.Client.GetAll(ctx, docRefs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get documents: %w", err)
|
||||
}
|
||||
|
||||
// Convert snapshots to response data
|
||||
results := make([]any, len(snapshots))
|
||||
for i, snapshot := range snapshots {
|
||||
docData := make(map[string]any)
|
||||
docData["path"] = documentPaths[i]
|
||||
docData["exists"] = snapshot.Exists()
|
||||
|
||||
if snapshot.Exists() {
|
||||
docData["data"] = snapshot.Data()
|
||||
docData["createTime"] = snapshot.CreateTime
|
||||
docData["updateTime"] = snapshot.UpdateTime
|
||||
docData["readTime"] = snapshot.ReadTime
|
||||
}
|
||||
|
||||
results[i] = docData
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
return tools.ParseParams(t.Parameters, data, claims)
|
||||
}
|
||||
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
@@ -0,0 +1,156 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package firestoregetdocuments_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/firestore/firestoregetdocuments"
|
||||
)
|
||||
|
||||
func TestParseFromYamlFirestoreGetDocuments(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
get_docs_tool:
|
||||
kind: firestore-get-documents
|
||||
source: my-firestore-instance
|
||||
description: Retrieve documents from Firestore by paths
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"get_docs_tool": firestoregetdocuments.Config{
|
||||
Name: "get_docs_tool",
|
||||
Kind: "firestore-get-documents",
|
||||
Source: "my-firestore-instance",
|
||||
Description: "Retrieve documents from Firestore by paths",
|
||||
AuthRequired: []string{},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "with auth requirements",
|
||||
in: `
|
||||
tools:
|
||||
secure_get_docs:
|
||||
kind: firestore-get-documents
|
||||
source: prod-firestore
|
||||
description: Get documents with authentication
|
||||
authRequired:
|
||||
- google-auth-service
|
||||
- api-key-service
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"secure_get_docs": firestoregetdocuments.Config{
|
||||
Name: "secure_get_docs",
|
||||
Kind: "firestore-get-documents",
|
||||
Source: "prod-firestore",
|
||||
Description: "Get documents with authentication",
|
||||
AuthRequired: []string{"google-auth-service", "api-key-service"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFromYamlMultipleTools(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
in := `
|
||||
tools:
|
||||
get_user_docs:
|
||||
kind: firestore-get-documents
|
||||
source: users-firestore
|
||||
description: Get user documents
|
||||
authRequired:
|
||||
- user-auth
|
||||
get_product_docs:
|
||||
kind: firestore-get-documents
|
||||
source: products-firestore
|
||||
description: Get product documents
|
||||
get_order_docs:
|
||||
kind: firestore-get-documents
|
||||
source: orders-firestore
|
||||
description: Get order documents
|
||||
authRequired:
|
||||
- user-auth
|
||||
- admin-auth
|
||||
`
|
||||
want := server.ToolConfigs{
|
||||
"get_user_docs": firestoregetdocuments.Config{
|
||||
Name: "get_user_docs",
|
||||
Kind: "firestore-get-documents",
|
||||
Source: "users-firestore",
|
||||
Description: "Get user documents",
|
||||
AuthRequired: []string{"user-auth"},
|
||||
},
|
||||
"get_product_docs": firestoregetdocuments.Config{
|
||||
Name: "get_product_docs",
|
||||
Kind: "firestore-get-documents",
|
||||
Source: "products-firestore",
|
||||
Description: "Get product documents",
|
||||
AuthRequired: []string{},
|
||||
},
|
||||
"get_order_docs": firestoregetdocuments.Config{
|
||||
Name: "get_order_docs",
|
||||
Kind: "firestore-get-documents",
|
||||
Source: "orders-firestore",
|
||||
Description: "Get order documents",
|
||||
AuthRequired: []string{"user-auth", "admin-auth"},
|
||||
},
|
||||
}
|
||||
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err = yaml.UnmarshalContext(ctx, testutils.FormatYaml(in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
}
|
||||
159
internal/tools/firestore/firestoregetrules/firestoregetrules.go
Normal file
159
internal/tools/firestore/firestoregetrules/firestoregetrules.go
Normal file
@@ -0,0 +1,159 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package firestoregetrules
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"google.golang.org/api/firebaserules/v1"
|
||||
)
|
||||
|
||||
const kind string = "firestore-get-rules"
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
FirebaseRulesClient() *firebaserules.Service
|
||||
GetProjectId() string
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &firestoreds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{firestoreds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
func (cfg Config) ToolConfigKind() string {
|
||||
return kind
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
// No parameters needed for this tool
|
||||
parameters := tools.Parameters{}
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
Description: cfg.Description,
|
||||
InputSchema: parameters.McpManifest(),
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Parameters: parameters,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
RulesClient: s.FirebaseRulesClient(),
|
||||
ProjectId: s.GetProjectId(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
|
||||
RulesClient *firebaserules.Service
|
||||
ProjectId string
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
// Get the latest release for Firestore
|
||||
releaseName := fmt.Sprintf("projects/%s/releases/cloud.firestore", t.ProjectId)
|
||||
release, err := t.RulesClient.Projects.Releases.Get(releaseName).Context(ctx).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get latest Firestore release: %w", err)
|
||||
}
|
||||
|
||||
if release.RulesetName == "" {
|
||||
return nil, fmt.Errorf("no active Firestore rules were found in project '%s'", t.ProjectId)
|
||||
}
|
||||
|
||||
// Get the ruleset content
|
||||
ruleset, err := t.RulesClient.Projects.Rulesets.Get(release.RulesetName).Context(ctx).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get ruleset content: %w", err)
|
||||
}
|
||||
|
||||
if ruleset.Source == nil || len(ruleset.Source.Files) == 0 {
|
||||
return nil, fmt.Errorf("no rules files found in ruleset")
|
||||
}
|
||||
|
||||
return ruleset, nil
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
return tools.ParseParams(t.Parameters, data, claims)
|
||||
}
|
||||
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
@@ -0,0 +1,156 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package firestoregetrules_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/firestore/firestoregetrules"
|
||||
)
|
||||
|
||||
func TestParseFromYamlFirestoreGetRules(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
get_rules_tool:
|
||||
kind: firestore-get-rules
|
||||
source: my-firestore-instance
|
||||
description: Retrieves the active Firestore security rules for the current project
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"get_rules_tool": firestoregetrules.Config{
|
||||
Name: "get_rules_tool",
|
||||
Kind: "firestore-get-rules",
|
||||
Source: "my-firestore-instance",
|
||||
Description: "Retrieves the active Firestore security rules for the current project",
|
||||
AuthRequired: []string{},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "with auth requirements",
|
||||
in: `
|
||||
tools:
|
||||
secure_get_rules:
|
||||
kind: firestore-get-rules
|
||||
source: prod-firestore
|
||||
description: Get Firestore security rules with authentication
|
||||
authRequired:
|
||||
- google-auth-service
|
||||
- admin-service
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"secure_get_rules": firestoregetrules.Config{
|
||||
Name: "secure_get_rules",
|
||||
Kind: "firestore-get-rules",
|
||||
Source: "prod-firestore",
|
||||
Description: "Get Firestore security rules with authentication",
|
||||
AuthRequired: []string{"google-auth-service", "admin-service"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFromYamlMultipleTools(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
in := `
|
||||
tools:
|
||||
get_dev_rules:
|
||||
kind: firestore-get-rules
|
||||
source: dev-firestore
|
||||
description: Get development Firestore rules
|
||||
authRequired:
|
||||
- dev-auth
|
||||
get_staging_rules:
|
||||
kind: firestore-get-rules
|
||||
source: staging-firestore
|
||||
description: Get staging Firestore rules
|
||||
get_prod_rules:
|
||||
kind: firestore-get-rules
|
||||
source: prod-firestore
|
||||
description: Get production Firestore rules
|
||||
authRequired:
|
||||
- prod-auth
|
||||
- admin-auth
|
||||
`
|
||||
want := server.ToolConfigs{
|
||||
"get_dev_rules": firestoregetrules.Config{
|
||||
Name: "get_dev_rules",
|
||||
Kind: "firestore-get-rules",
|
||||
Source: "dev-firestore",
|
||||
Description: "Get development Firestore rules",
|
||||
AuthRequired: []string{"dev-auth"},
|
||||
},
|
||||
"get_staging_rules": firestoregetrules.Config{
|
||||
Name: "get_staging_rules",
|
||||
Kind: "firestore-get-rules",
|
||||
Source: "staging-firestore",
|
||||
Description: "Get staging Firestore rules",
|
||||
AuthRequired: []string{},
|
||||
},
|
||||
"get_prod_rules": firestoregetrules.Config{
|
||||
Name: "get_prod_rules",
|
||||
Kind: "firestore-get-rules",
|
||||
Source: "prod-firestore",
|
||||
Description: "Get production Firestore rules",
|
||||
AuthRequired: []string{"prod-auth", "admin-auth"},
|
||||
},
|
||||
}
|
||||
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err = yaml.UnmarshalContext(ctx, testutils.FormatYaml(in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,175 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package firestorelistcollections
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
firestoreapi "cloud.google.com/go/firestore"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
)
|
||||
|
||||
const kind string = "firestore-list-collections"
|
||||
const parentPathKey string = "parentPath"
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
FirestoreClient() *firestoreapi.Client
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &firestoreds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{firestoreds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
func (cfg Config) ToolConfigKind() string {
|
||||
return kind
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
emptyString := ""
|
||||
parentPathParameter := tools.NewStringParameterWithDefault(parentPathKey, emptyString, "Parent document path to list subcollections from. If not provided, lists root collections.")
|
||||
parameters := tools.Parameters{parentPathParameter}
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
Description: cfg.Description,
|
||||
InputSchema: parameters.McpManifest(),
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Parameters: parameters,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Client: s.FirestoreClient(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
|
||||
Client *firestoreapi.Client
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
mapParams := params.AsMap()
|
||||
|
||||
var collectionRefs []*firestoreapi.CollectionRef
|
||||
var err error
|
||||
|
||||
// Check if parentPath is provided
|
||||
parentPath, hasParent := mapParams[parentPathKey].(string)
|
||||
|
||||
if hasParent && parentPath != "" {
|
||||
// List subcollections of the specified document
|
||||
docRef := t.Client.Doc(parentPath)
|
||||
collectionRefs, err = docRef.Collections(ctx).GetAll()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list subcollections of document %q: %w", parentPath, err)
|
||||
}
|
||||
} else {
|
||||
// List root collections
|
||||
collectionRefs, err = t.Client.Collections(ctx).GetAll()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list root collections: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Convert collection references to response data
|
||||
results := make([]any, len(collectionRefs))
|
||||
for i, collRef := range collectionRefs {
|
||||
collData := make(map[string]any)
|
||||
collData["id"] = collRef.ID
|
||||
collData["path"] = collRef.Path
|
||||
|
||||
// If this is a subcollection, include parent information
|
||||
if collRef.Parent != nil {
|
||||
collData["parent"] = collRef.Parent.Path
|
||||
}
|
||||
|
||||
results[i] = collData
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
return tools.ParseParams(t.Parameters, data, claims)
|
||||
}
|
||||
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
@@ -0,0 +1,156 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package firestorelistcollections_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/firestore/firestorelistcollections"
|
||||
)
|
||||
|
||||
func TestParseFromYamlFirestoreListCollections(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
list_collections_tool:
|
||||
kind: firestore-list-collections
|
||||
source: my-firestore-instance
|
||||
description: List collections in Firestore
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"list_collections_tool": firestorelistcollections.Config{
|
||||
Name: "list_collections_tool",
|
||||
Kind: "firestore-list-collections",
|
||||
Source: "my-firestore-instance",
|
||||
Description: "List collections in Firestore",
|
||||
AuthRequired: []string{},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "with auth requirements",
|
||||
in: `
|
||||
tools:
|
||||
secure_list_collections:
|
||||
kind: firestore-list-collections
|
||||
source: prod-firestore
|
||||
description: List collections with authentication
|
||||
authRequired:
|
||||
- google-auth-service
|
||||
- api-key-service
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"secure_list_collections": firestorelistcollections.Config{
|
||||
Name: "secure_list_collections",
|
||||
Kind: "firestore-list-collections",
|
||||
Source: "prod-firestore",
|
||||
Description: "List collections with authentication",
|
||||
AuthRequired: []string{"google-auth-service", "api-key-service"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFromYamlMultipleTools(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
in := `
|
||||
tools:
|
||||
list_user_collections:
|
||||
kind: firestore-list-collections
|
||||
source: users-firestore
|
||||
description: List user-related collections
|
||||
authRequired:
|
||||
- user-auth
|
||||
list_product_collections:
|
||||
kind: firestore-list-collections
|
||||
source: products-firestore
|
||||
description: List product-related collections
|
||||
list_admin_collections:
|
||||
kind: firestore-list-collections
|
||||
source: admin-firestore
|
||||
description: List administrative collections
|
||||
authRequired:
|
||||
- user-auth
|
||||
- admin-auth
|
||||
`
|
||||
want := server.ToolConfigs{
|
||||
"list_user_collections": firestorelistcollections.Config{
|
||||
Name: "list_user_collections",
|
||||
Kind: "firestore-list-collections",
|
||||
Source: "users-firestore",
|
||||
Description: "List user-related collections",
|
||||
AuthRequired: []string{"user-auth"},
|
||||
},
|
||||
"list_product_collections": firestorelistcollections.Config{
|
||||
Name: "list_product_collections",
|
||||
Kind: "firestore-list-collections",
|
||||
Source: "products-firestore",
|
||||
Description: "List product-related collections",
|
||||
AuthRequired: []string{},
|
||||
},
|
||||
"list_admin_collections": firestorelistcollections.Config{
|
||||
Name: "list_admin_collections",
|
||||
Kind: "firestore-list-collections",
|
||||
Source: "admin-firestore",
|
||||
Description: "List administrative collections",
|
||||
AuthRequired: []string{"user-auth", "admin-auth"},
|
||||
},
|
||||
}
|
||||
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err = yaml.UnmarshalContext(ctx, testutils.FormatYaml(in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,529 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package firestorequerycollection
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
firestoreapi "cloud.google.com/go/firestore"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
)
|
||||
|
||||
// Constants for tool configuration
|
||||
const (
|
||||
kind = "firestore-query-collection"
|
||||
defaultLimit = 100
|
||||
defaultAnalyze = false
|
||||
maxFilterLength = 100 // Maximum filters to prevent abuse
|
||||
)
|
||||
|
||||
// Parameter keys
|
||||
const (
|
||||
collectionPathKey = "collectionPath"
|
||||
filtersKey = "filters"
|
||||
orderByKey = "orderBy"
|
||||
limitKey = "limit"
|
||||
analyzeQueryKey = "analyzeQuery"
|
||||
)
|
||||
|
||||
// Firestore operators
|
||||
var validOperators = map[string]bool{
|
||||
"<": true,
|
||||
"<=": true,
|
||||
">": true,
|
||||
">=": true,
|
||||
"==": true,
|
||||
"!=": true,
|
||||
"array-contains": true,
|
||||
"array-contains-any": true,
|
||||
"in": true,
|
||||
"not-in": true,
|
||||
}
|
||||
|
||||
// Error messages
|
||||
const (
|
||||
errMissingCollectionPath = "invalid or missing '%s' parameter"
|
||||
errInvalidFilters = "invalid '%s' parameter; expected an array"
|
||||
errFilterNotString = "filter at index %d is not a string"
|
||||
errFilterParseFailed = "failed to parse filter at index %d: %w"
|
||||
errInvalidOperator = "unsupported operator: %s. Valid operators are: %v"
|
||||
errMissingFilterValue = "no value specified for filter on field '%s'"
|
||||
errOrderByParseFailed = "failed to parse orderBy: %w"
|
||||
errQueryExecutionFailed = "failed to execute query: %w"
|
||||
errTooManyFilters = "too many filters provided: %d (maximum: %d)"
|
||||
)
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
// compatibleSource defines the interface for sources that can provide a Firestore client
|
||||
type compatibleSource interface {
|
||||
FirestoreClient() *firestoreapi.Client
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &firestoreds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{firestoreds.SourceKind}
|
||||
|
||||
// Config represents the configuration for the Firestore query collection tool
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
// ToolConfigKind returns the kind of tool configuration
|
||||
func (cfg Config) ToolConfigKind() string {
|
||||
return kind
|
||||
}
|
||||
|
||||
// Initialize creates a new Tool instance from the configuration
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
// Create parameters
|
||||
parameters := createParameters()
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
Description: cfg.Description,
|
||||
InputSchema: parameters.McpManifest(),
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Parameters: parameters,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Client: s.FirestoreClient(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// createParameters creates the parameter definitions for the tool
|
||||
func createParameters() tools.Parameters {
|
||||
collectionPathParameter := tools.NewStringParameter(
|
||||
collectionPathKey,
|
||||
"The path to the Firestore collection to query",
|
||||
)
|
||||
|
||||
filtersDescription := `Array of filter objects to apply to the query. Each filter is a JSON string with:
|
||||
- field: The field name to filter on
|
||||
- op: The operator to use ("<", "<=", ">", ">=", "==", "!=", "array-contains", "array-contains-any", "in", "not-in")
|
||||
- value: The value to compare against (can be string, number, boolean, or array)
|
||||
Example: {"field": "age", "op": ">", "value": 18}`
|
||||
|
||||
filtersParameter := tools.NewArrayParameter(
|
||||
filtersKey,
|
||||
filtersDescription,
|
||||
tools.NewStringParameter("item", "JSON string representation of a filter object"),
|
||||
)
|
||||
|
||||
orderByParameter := tools.NewStringParameter(
|
||||
orderByKey,
|
||||
"JSON string specifying the field and direction to order by (e.g., {\"field\": \"name\", \"direction\": \"ASCENDING\"}). Leave empty if not specified",
|
||||
)
|
||||
|
||||
limitParameter := tools.NewIntParameterWithDefault(
|
||||
limitKey,
|
||||
defaultLimit,
|
||||
"The maximum number of documents to return",
|
||||
)
|
||||
|
||||
analyzeQueryParameter := tools.NewBooleanParameterWithDefault(
|
||||
analyzeQueryKey,
|
||||
defaultAnalyze,
|
||||
"If true, returns query explain metrics including execution statistics",
|
||||
)
|
||||
|
||||
return tools.Parameters{
|
||||
collectionPathParameter,
|
||||
filtersParameter,
|
||||
orderByParameter,
|
||||
limitParameter,
|
||||
analyzeQueryParameter,
|
||||
}
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
// Tool represents the Firestore query collection tool
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
|
||||
Client *firestoreapi.Client
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
// FilterConfig represents a filter for the query
|
||||
type FilterConfig struct {
|
||||
Field string `json:"field"`
|
||||
Op string `json:"op"`
|
||||
Value interface{} `json:"value"`
|
||||
}
|
||||
|
||||
// Validate checks if the filter configuration is valid
|
||||
func (f *FilterConfig) Validate() error {
|
||||
if f.Field == "" {
|
||||
return fmt.Errorf("filter field cannot be empty")
|
||||
}
|
||||
|
||||
if !validOperators[f.Op] {
|
||||
ops := make([]string, 0, len(validOperators))
|
||||
for op := range validOperators {
|
||||
ops = append(ops, op)
|
||||
}
|
||||
return fmt.Errorf(errInvalidOperator, f.Op, ops)
|
||||
}
|
||||
|
||||
if f.Value == nil {
|
||||
return fmt.Errorf(errMissingFilterValue, f.Field)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// OrderByConfig represents ordering configuration
|
||||
type OrderByConfig struct {
|
||||
Field string `json:"field"`
|
||||
Direction string `json:"direction"`
|
||||
}
|
||||
|
||||
// GetDirection returns the Firestore direction constant
|
||||
func (o *OrderByConfig) GetDirection() firestoreapi.Direction {
|
||||
if strings.EqualFold(o.Direction, "DESCENDING") {
|
||||
return firestoreapi.Desc
|
||||
}
|
||||
return firestoreapi.Asc
|
||||
}
|
||||
|
||||
// QueryResult represents a document result from the query
|
||||
type QueryResult struct {
|
||||
ID string `json:"id"`
|
||||
Path string `json:"path"`
|
||||
Data map[string]any `json:"data"`
|
||||
CreateTime interface{} `json:"createTime,omitempty"`
|
||||
UpdateTime interface{} `json:"updateTime,omitempty"`
|
||||
ReadTime interface{} `json:"readTime,omitempty"`
|
||||
}
|
||||
|
||||
// QueryResponse represents the full response including optional metrics
|
||||
type QueryResponse struct {
|
||||
Documents []QueryResult `json:"documents"`
|
||||
ExplainMetrics map[string]any `json:"explainMetrics,omitempty"`
|
||||
}
|
||||
|
||||
// Invoke executes the Firestore query based on the provided parameters
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
// Parse parameters
|
||||
queryParams, err := t.parseQueryParameters(params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Build the query
|
||||
query, err := t.buildQuery(queryParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Execute the query and return results
|
||||
return t.executeQuery(ctx, query, queryParams.AnalyzeQuery)
|
||||
}
|
||||
|
||||
// queryParameters holds all parsed query parameters
|
||||
type queryParameters struct {
|
||||
CollectionPath string
|
||||
Filters []FilterConfig
|
||||
OrderBy *OrderByConfig
|
||||
Limit int
|
||||
AnalyzeQuery bool
|
||||
}
|
||||
|
||||
// parseQueryParameters extracts and validates parameters from the input
|
||||
func (t Tool) parseQueryParameters(params tools.ParamValues) (*queryParameters, error) {
|
||||
mapParams := params.AsMap()
|
||||
|
||||
// Get collection path
|
||||
collectionPath, ok := mapParams[collectionPathKey].(string)
|
||||
if !ok || collectionPath == "" {
|
||||
return nil, fmt.Errorf(errMissingCollectionPath, collectionPathKey)
|
||||
}
|
||||
|
||||
result := &queryParameters{
|
||||
CollectionPath: collectionPath,
|
||||
Limit: defaultLimit,
|
||||
AnalyzeQuery: defaultAnalyze,
|
||||
}
|
||||
|
||||
// Parse filters
|
||||
if filtersRaw, ok := mapParams[filtersKey]; ok && filtersRaw != nil {
|
||||
filters, err := t.parseFilters(filtersRaw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result.Filters = filters
|
||||
}
|
||||
|
||||
// Parse orderBy
|
||||
if orderByRaw, ok := mapParams[orderByKey]; ok && orderByRaw != nil {
|
||||
orderBy, err := t.parseOrderBy(orderByRaw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result.OrderBy = orderBy
|
||||
}
|
||||
|
||||
// Parse limit
|
||||
if limit, ok := mapParams[limitKey].(int); ok {
|
||||
result.Limit = limit
|
||||
}
|
||||
|
||||
// Parse analyze
|
||||
if analyze, ok := mapParams[analyzeQueryKey].(bool); ok {
|
||||
result.AnalyzeQuery = analyze
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// parseFilters parses and validates filter configurations
|
||||
func (t Tool) parseFilters(filtersRaw interface{}) ([]FilterConfig, error) {
|
||||
filters, ok := filtersRaw.([]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf(errInvalidFilters, filtersKey)
|
||||
}
|
||||
|
||||
if len(filters) > maxFilterLength {
|
||||
return nil, fmt.Errorf(errTooManyFilters, len(filters), maxFilterLength)
|
||||
}
|
||||
|
||||
result := make([]FilterConfig, 0, len(filters))
|
||||
for i, filterRaw := range filters {
|
||||
filterJSON, ok := filterRaw.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf(errFilterNotString, i)
|
||||
}
|
||||
|
||||
var filter FilterConfig
|
||||
if err := json.Unmarshal([]byte(filterJSON), &filter); err != nil {
|
||||
return nil, fmt.Errorf(errFilterParseFailed, i, err)
|
||||
}
|
||||
|
||||
if err := filter.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("filter at index %d is invalid: %w", i, err)
|
||||
}
|
||||
|
||||
result = append(result, filter)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// parseOrderBy parses the orderBy configuration
|
||||
func (t Tool) parseOrderBy(orderByRaw interface{}) (*OrderByConfig, error) {
|
||||
orderByJSON, ok := orderByRaw.(string)
|
||||
if !ok || orderByJSON == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var orderBy OrderByConfig
|
||||
if err := json.Unmarshal([]byte(orderByJSON), &orderBy); err != nil {
|
||||
return nil, fmt.Errorf(errOrderByParseFailed, err)
|
||||
}
|
||||
|
||||
if orderBy.Field == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return &orderBy, nil
|
||||
}
|
||||
|
||||
// buildQuery constructs the Firestore query from parameters
|
||||
func (t Tool) buildQuery(params *queryParameters) (*firestoreapi.Query, error) {
|
||||
collection := t.Client.Collection(params.CollectionPath)
|
||||
query := collection.Query
|
||||
|
||||
// Apply filters
|
||||
if len(params.Filters) > 0 {
|
||||
filterConditions := make([]firestoreapi.EntityFilter, 0, len(params.Filters))
|
||||
for _, filter := range params.Filters {
|
||||
filterConditions = append(filterConditions, firestoreapi.PropertyFilter{
|
||||
Path: filter.Field,
|
||||
Operator: filter.Op,
|
||||
Value: filter.Value,
|
||||
})
|
||||
}
|
||||
|
||||
query = query.WhereEntity(firestoreapi.AndFilter{
|
||||
Filters: filterConditions,
|
||||
})
|
||||
}
|
||||
|
||||
// Apply ordering
|
||||
if params.OrderBy != nil {
|
||||
query = query.OrderBy(params.OrderBy.Field, params.OrderBy.GetDirection())
|
||||
}
|
||||
|
||||
// Apply limit
|
||||
query = query.Limit(params.Limit)
|
||||
|
||||
// Apply analyze options
|
||||
if params.AnalyzeQuery {
|
||||
query = query.WithRunOptions(firestoreapi.ExplainOptions{
|
||||
Analyze: true,
|
||||
})
|
||||
}
|
||||
|
||||
return &query, nil
|
||||
}
|
||||
|
||||
// executeQuery runs the query and formats the results
|
||||
func (t Tool) executeQuery(ctx context.Context, query *firestoreapi.Query, analyzeQuery bool) (any, error) {
|
||||
docIterator := query.Documents(ctx)
|
||||
docs, err := docIterator.GetAll()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(errQueryExecutionFailed, err)
|
||||
}
|
||||
|
||||
// Convert results to structured format
|
||||
results := make([]QueryResult, len(docs))
|
||||
for i, doc := range docs {
|
||||
results[i] = QueryResult{
|
||||
ID: doc.Ref.ID,
|
||||
Path: doc.Ref.Path,
|
||||
Data: doc.Data(),
|
||||
CreateTime: doc.CreateTime,
|
||||
UpdateTime: doc.UpdateTime,
|
||||
ReadTime: doc.ReadTime,
|
||||
}
|
||||
}
|
||||
|
||||
// Return with explain metrics if requested
|
||||
if analyzeQuery {
|
||||
explainMetrics, err := t.getExplainMetrics(docIterator)
|
||||
if err == nil && explainMetrics != nil {
|
||||
response := QueryResponse{
|
||||
Documents: results,
|
||||
ExplainMetrics: explainMetrics,
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Return just the documents
|
||||
resultsAny := make([]any, len(results))
|
||||
for i, r := range results {
|
||||
resultsAny[i] = r
|
||||
}
|
||||
return resultsAny, nil
|
||||
}
|
||||
|
||||
// getExplainMetrics extracts explain metrics from the query iterator
|
||||
func (t Tool) getExplainMetrics(docIterator *firestoreapi.DocumentIterator) (map[string]any, error) {
|
||||
explainMetrics, err := docIterator.ExplainMetrics()
|
||||
if err != nil || explainMetrics == nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
metricsData := make(map[string]any)
|
||||
|
||||
// Add plan summary if available
|
||||
if explainMetrics.PlanSummary != nil {
|
||||
planSummary := make(map[string]any)
|
||||
planSummary["indexesUsed"] = explainMetrics.PlanSummary.IndexesUsed
|
||||
metricsData["planSummary"] = planSummary
|
||||
}
|
||||
|
||||
// Add execution stats if available
|
||||
if explainMetrics.ExecutionStats != nil {
|
||||
executionStats := make(map[string]any)
|
||||
executionStats["resultsReturned"] = explainMetrics.ExecutionStats.ResultsReturned
|
||||
executionStats["readOperations"] = explainMetrics.ExecutionStats.ReadOperations
|
||||
|
||||
if explainMetrics.ExecutionStats.ExecutionDuration != nil {
|
||||
executionStats["executionDuration"] = explainMetrics.ExecutionStats.ExecutionDuration.String()
|
||||
}
|
||||
|
||||
if explainMetrics.ExecutionStats.DebugStats != nil {
|
||||
executionStats["debugStats"] = *explainMetrics.ExecutionStats.DebugStats
|
||||
}
|
||||
|
||||
metricsData["executionStats"] = executionStats
|
||||
}
|
||||
|
||||
return metricsData, nil
|
||||
}
|
||||
|
||||
// ParseParams parses and validates input parameters
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
return tools.ParseParams(t.Parameters, data, claims)
|
||||
}
|
||||
|
||||
// Manifest returns the tool manifest
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
// McpManifest returns the MCP manifest
|
||||
func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
// Authorized checks if the tool is authorized based on verified auth services
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
@@ -0,0 +1,156 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package firestorequerycollection_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/firestore/firestorequerycollection"
|
||||
)
|
||||
|
||||
func TestParseFromYamlFirestoreQueryCollection(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
query_users_tool:
|
||||
kind: firestore-query-collection
|
||||
source: my-firestore-instance
|
||||
description: Query users collection with filters and ordering
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"query_users_tool": firestorequerycollection.Config{
|
||||
Name: "query_users_tool",
|
||||
Kind: "firestore-query-collection",
|
||||
Source: "my-firestore-instance",
|
||||
Description: "Query users collection with filters and ordering",
|
||||
AuthRequired: []string{},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "with auth requirements",
|
||||
in: `
|
||||
tools:
|
||||
secure_query_tool:
|
||||
kind: firestore-query-collection
|
||||
source: prod-firestore
|
||||
description: Query collections with authentication
|
||||
authRequired:
|
||||
- google-auth-service
|
||||
- api-key-service
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"secure_query_tool": firestorequerycollection.Config{
|
||||
Name: "secure_query_tool",
|
||||
Kind: "firestore-query-collection",
|
||||
Source: "prod-firestore",
|
||||
Description: "Query collections with authentication",
|
||||
AuthRequired: []string{"google-auth-service", "api-key-service"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFromYamlMultipleTools(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
in := `
|
||||
tools:
|
||||
query_users:
|
||||
kind: firestore-query-collection
|
||||
source: users-firestore
|
||||
description: Query user documents with filtering
|
||||
authRequired:
|
||||
- user-auth
|
||||
query_products:
|
||||
kind: firestore-query-collection
|
||||
source: products-firestore
|
||||
description: Query product catalog
|
||||
query_orders:
|
||||
kind: firestore-query-collection
|
||||
source: orders-firestore
|
||||
description: Query customer orders with complex filters
|
||||
authRequired:
|
||||
- user-auth
|
||||
- admin-auth
|
||||
`
|
||||
want := server.ToolConfigs{
|
||||
"query_users": firestorequerycollection.Config{
|
||||
Name: "query_users",
|
||||
Kind: "firestore-query-collection",
|
||||
Source: "users-firestore",
|
||||
Description: "Query user documents with filtering",
|
||||
AuthRequired: []string{"user-auth"},
|
||||
},
|
||||
"query_products": firestorequerycollection.Config{
|
||||
Name: "query_products",
|
||||
Kind: "firestore-query-collection",
|
||||
Source: "products-firestore",
|
||||
Description: "Query product catalog",
|
||||
AuthRequired: []string{},
|
||||
},
|
||||
"query_orders": firestorequerycollection.Config{
|
||||
Name: "query_orders",
|
||||
Kind: "firestore-query-collection",
|
||||
Source: "orders-firestore",
|
||||
Description: "Query customer orders with complex filters",
|
||||
AuthRequired: []string{"user-auth", "admin-auth"},
|
||||
},
|
||||
}
|
||||
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err = yaml.UnmarshalContext(ctx, testutils.FormatYaml(in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,292 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package firestorevalidaterules
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
firestoreds "github.com/googleapis/genai-toolbox/internal/sources/firestore"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"google.golang.org/api/firebaserules/v1"
|
||||
)
|
||||
|
||||
const kind string = "firestore-validate-rules"
|
||||
|
||||
// Parameter keys
|
||||
const (
|
||||
sourceKey = "source"
|
||||
)
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
FirebaseRulesClient() *firebaserules.Service
|
||||
GetProjectId() string
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &firestoreds.Source{}
|
||||
|
||||
var compatibleSources = [...]string{firestoreds.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
func (cfg Config) ToolConfigKind() string {
|
||||
return kind
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
// Create parameters
|
||||
parameters := createParameters()
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
Description: cfg.Description,
|
||||
InputSchema: parameters.McpManifest(),
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Parameters: parameters,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
RulesClient: s.FirebaseRulesClient(),
|
||||
ProjectId: s.GetProjectId(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// createParameters creates the parameter definitions for the tool
|
||||
func createParameters() tools.Parameters {
|
||||
sourceParameter := tools.NewStringParameter(
|
||||
sourceKey,
|
||||
"The Firestore Rules source code to validate",
|
||||
)
|
||||
|
||||
return tools.Parameters{sourceParameter}
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
|
||||
RulesClient *firebaserules.Service
|
||||
ProjectId string
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
// Issue represents a validation issue in the rules
|
||||
type Issue struct {
|
||||
SourcePosition SourcePosition `json:"sourcePosition"`
|
||||
Description string `json:"description"`
|
||||
Severity string `json:"severity"`
|
||||
}
|
||||
|
||||
// SourcePosition represents the location of an issue in the source
|
||||
type SourcePosition struct {
|
||||
FileName string `json:"fileName,omitempty"`
|
||||
Line int64 `json:"line"` // 1-based
|
||||
Column int64 `json:"column"` // 1-based
|
||||
CurrentOffset int64 `json:"currentOffset"` // 0-based, inclusive start
|
||||
EndOffset int64 `json:"endOffset"` // 0-based, exclusive end
|
||||
}
|
||||
|
||||
// ValidationResult represents the result of rules validation
|
||||
type ValidationResult struct {
|
||||
Valid bool `json:"valid"`
|
||||
IssueCount int `json:"issueCount"`
|
||||
FormattedIssues string `json:"formattedIssues,omitempty"`
|
||||
RawIssues []Issue `json:"rawIssues,omitempty"`
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
mapParams := params.AsMap()
|
||||
|
||||
// Get source parameter
|
||||
source, ok := mapParams[sourceKey].(string)
|
||||
if !ok || source == "" {
|
||||
return nil, fmt.Errorf("invalid or missing '%s' parameter", sourceKey)
|
||||
}
|
||||
|
||||
// Create test request
|
||||
testRequest := &firebaserules.TestRulesetRequest{
|
||||
Source: &firebaserules.Source{
|
||||
Files: []*firebaserules.File{
|
||||
{
|
||||
Name: "firestore.rules",
|
||||
Content: source,
|
||||
},
|
||||
},
|
||||
},
|
||||
// We don't need test cases for validation only
|
||||
TestSuite: &firebaserules.TestSuite{
|
||||
TestCases: []*firebaserules.TestCase{},
|
||||
},
|
||||
}
|
||||
|
||||
// Call the test API
|
||||
projectName := fmt.Sprintf("projects/%s", t.ProjectId)
|
||||
response, err := t.RulesClient.Projects.Test(projectName, testRequest).Context(ctx).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to validate rules: %w", err)
|
||||
}
|
||||
|
||||
// Process the response
|
||||
result := t.processValidationResponse(response, source)
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (t Tool) processValidationResponse(response *firebaserules.TestRulesetResponse, source string) ValidationResult {
|
||||
if len(response.Issues) == 0 {
|
||||
return ValidationResult{
|
||||
Valid: true,
|
||||
IssueCount: 0,
|
||||
FormattedIssues: "✓ No errors detected. Rules are valid.",
|
||||
}
|
||||
}
|
||||
|
||||
// Convert issues to our format
|
||||
issues := make([]Issue, len(response.Issues))
|
||||
for i, issue := range response.Issues {
|
||||
issues[i] = Issue{
|
||||
Description: issue.Description,
|
||||
Severity: issue.Severity,
|
||||
SourcePosition: SourcePosition{
|
||||
FileName: issue.SourcePosition.FileName,
|
||||
Line: issue.SourcePosition.Line,
|
||||
Column: issue.SourcePosition.Column,
|
||||
CurrentOffset: issue.SourcePosition.CurrentOffset,
|
||||
EndOffset: issue.SourcePosition.EndOffset,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Format issues
|
||||
formattedIssues := t.formatRulesetIssues(issues, source)
|
||||
|
||||
return ValidationResult{
|
||||
Valid: false,
|
||||
IssueCount: len(issues),
|
||||
FormattedIssues: formattedIssues,
|
||||
RawIssues: issues,
|
||||
}
|
||||
}
|
||||
|
||||
// formatRulesetIssues formats validation issues into a human-readable string with code snippets
|
||||
func (t Tool) formatRulesetIssues(issues []Issue, rulesSource string) string {
|
||||
sourceLines := strings.Split(rulesSource, "\n")
|
||||
var formattedOutput []string
|
||||
|
||||
formattedOutput = append(formattedOutput, fmt.Sprintf("Found %d issue(s) in rules source:\n", len(issues)))
|
||||
|
||||
for _, issue := range issues {
|
||||
issueString := fmt.Sprintf("%s: %s [Ln %d, Col %d]",
|
||||
issue.Severity,
|
||||
issue.Description,
|
||||
issue.SourcePosition.Line,
|
||||
issue.SourcePosition.Column)
|
||||
|
||||
if issue.SourcePosition.Line > 0 {
|
||||
lineIndex := int(issue.SourcePosition.Line - 1) // 0-based index
|
||||
if lineIndex >= 0 && lineIndex < len(sourceLines) {
|
||||
errorLine := sourceLines[lineIndex]
|
||||
issueString += fmt.Sprintf("\n```\n%s", errorLine)
|
||||
|
||||
// Add carets if we have column and offset information
|
||||
if issue.SourcePosition.Column > 0 &&
|
||||
issue.SourcePosition.CurrentOffset >= 0 &&
|
||||
issue.SourcePosition.EndOffset > issue.SourcePosition.CurrentOffset {
|
||||
|
||||
startColumn := int(issue.SourcePosition.Column - 1) // 0-based
|
||||
errorTokenLength := int(issue.SourcePosition.EndOffset - issue.SourcePosition.CurrentOffset)
|
||||
|
||||
if startColumn >= 0 && errorTokenLength > 0 && startColumn <= len(errorLine) {
|
||||
padding := strings.Repeat(" ", startColumn)
|
||||
carets := strings.Repeat("^", errorTokenLength)
|
||||
issueString += fmt.Sprintf("\n%s%s", padding, carets)
|
||||
}
|
||||
}
|
||||
issueString += "\n```"
|
||||
}
|
||||
}
|
||||
|
||||
formattedOutput = append(formattedOutput, issueString)
|
||||
}
|
||||
|
||||
return strings.Join(formattedOutput, "\n\n")
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
return tools.ParseParams(t.Parameters, data, claims)
|
||||
}
|
||||
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
@@ -0,0 +1,156 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package firestorevalidaterules_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/firestore/firestorevalidaterules"
|
||||
)
|
||||
|
||||
func TestParseFromYamlFirestoreValidateRules(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
validate_rules_tool:
|
||||
kind: firestore-validate-rules
|
||||
source: my-firestore-instance
|
||||
description: Validate Firestore security rules
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"validate_rules_tool": firestorevalidaterules.Config{
|
||||
Name: "validate_rules_tool",
|
||||
Kind: "firestore-validate-rules",
|
||||
Source: "my-firestore-instance",
|
||||
Description: "Validate Firestore security rules",
|
||||
AuthRequired: []string{},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "with auth requirements",
|
||||
in: `
|
||||
tools:
|
||||
secure_validate_rules:
|
||||
kind: firestore-validate-rules
|
||||
source: prod-firestore
|
||||
description: Validate rules with authentication
|
||||
authRequired:
|
||||
- google-auth-service
|
||||
- api-key-service
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"secure_validate_rules": firestorevalidaterules.Config{
|
||||
Name: "secure_validate_rules",
|
||||
Kind: "firestore-validate-rules",
|
||||
Source: "prod-firestore",
|
||||
Description: "Validate rules with authentication",
|
||||
AuthRequired: []string{"google-auth-service", "api-key-service"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFromYamlMultipleTools(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
in := `
|
||||
tools:
|
||||
validate_dev_rules:
|
||||
kind: firestore-validate-rules
|
||||
source: dev-firestore
|
||||
description: Validate development environment rules
|
||||
authRequired:
|
||||
- dev-auth
|
||||
validate_staging_rules:
|
||||
kind: firestore-validate-rules
|
||||
source: staging-firestore
|
||||
description: Validate staging environment rules
|
||||
validate_prod_rules:
|
||||
kind: firestore-validate-rules
|
||||
source: prod-firestore
|
||||
description: Validate production environment rules
|
||||
authRequired:
|
||||
- prod-auth
|
||||
- admin-auth
|
||||
`
|
||||
want := server.ToolConfigs{
|
||||
"validate_dev_rules": firestorevalidaterules.Config{
|
||||
Name: "validate_dev_rules",
|
||||
Kind: "firestore-validate-rules",
|
||||
Source: "dev-firestore",
|
||||
Description: "Validate development environment rules",
|
||||
AuthRequired: []string{"dev-auth"},
|
||||
},
|
||||
"validate_staging_rules": firestorevalidaterules.Config{
|
||||
Name: "validate_staging_rules",
|
||||
Kind: "firestore-validate-rules",
|
||||
Source: "staging-firestore",
|
||||
Description: "Validate staging environment rules",
|
||||
AuthRequired: []string{},
|
||||
},
|
||||
"validate_prod_rules": firestorevalidaterules.Config{
|
||||
Name: "validate_prod_rules",
|
||||
Kind: "firestore-validate-rules",
|
||||
Source: "prod-firestore",
|
||||
Description: "Validate production environment rules",
|
||||
AuthRequired: []string{"prod-auth", "admin-auth"},
|
||||
},
|
||||
}
|
||||
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err = yaml.UnmarshalContext(ctx, testutils.FormatYaml(in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
}
|
||||
@@ -303,7 +303,7 @@ func getHeaders(headerParams tools.Parameters, defaultHeaders map[string]string,
|
||||
return allHeaders, nil
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
// Calculate request body
|
||||
@@ -349,15 +349,9 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, erro
|
||||
var data any
|
||||
if err = json.Unmarshal(body, &data); err != nil {
|
||||
// if unable to unmarshal data, return result as string.
|
||||
return []any{string(body)}, nil
|
||||
return string(body), nil
|
||||
}
|
||||
// if data is a list, return as is.
|
||||
dataList, ok := data.([]any)
|
||||
if ok {
|
||||
return dataList, nil
|
||||
}
|
||||
// if data is not a list (e.g. single map), return data in list.
|
||||
return []any{data}, nil
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
|
||||
@@ -116,7 +116,7 @@ type Tool struct {
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
sliceParams := params.AsSlice()
|
||||
sql, ok := sliceParams[0].(string)
|
||||
if !ok {
|
||||
|
||||
@@ -125,7 +125,7 @@ type Tool struct {
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
paramsMap := params.AsMap()
|
||||
newStatement, err := tools.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
|
||||
if err != nil {
|
||||
|
||||
@@ -116,7 +116,7 @@ type Tool struct {
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
sliceParams := params.AsSlice()
|
||||
sql, ok := sliceParams[0].(string)
|
||||
if !ok {
|
||||
|
||||
@@ -124,7 +124,7 @@ type Tool struct {
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
paramsMap := params.AsMap()
|
||||
newStatement, err := tools.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
|
||||
if err != nil {
|
||||
|
||||
@@ -12,13 +12,13 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package neo4j
|
||||
package neo4jcypher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/goccy/go-yaml"
|
||||
neo4jsc "github.com/googleapis/genai-toolbox/internal/sources/neo4j"
|
||||
"github.com/neo4j/neo4j-go-driver/v5/neo4j"
|
||||
|
||||
@@ -119,7 +119,7 @@ type Tool struct {
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
config := neo4j.ExecuteQueryWithDatabase(t.Database)
|
||||
@@ -12,17 +12,16 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package neo4j_test
|
||||
package neo4jcypher
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/neo4j"
|
||||
)
|
||||
|
||||
func TestParseFromYamlNeo4j(t *testing.T) {
|
||||
@@ -54,7 +53,7 @@ func TestParseFromYamlNeo4j(t *testing.T) {
|
||||
description: country parameter description
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": neo4j.Config{
|
||||
"example_tool": Config{
|
||||
Name: "example_tool",
|
||||
Kind: "neo4j-cypher",
|
||||
Source: "my-neo4j-instance",
|
||||
@@ -74,7 +73,7 @@ func TestParseFromYamlNeo4j(t *testing.T) {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
err = yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
@@ -32,6 +32,7 @@ const (
|
||||
typeFloat = "float"
|
||||
typeBool = "boolean"
|
||||
typeArray = "array"
|
||||
typeMap = "map"
|
||||
)
|
||||
|
||||
// ParamValues is an ordered list of ParamValue
|
||||
@@ -367,6 +368,17 @@ func parseParamFromDelayedUnmarshaler(ctx context.Context, u *util.DelayedUnmars
|
||||
a.AuthSources = nil
|
||||
}
|
||||
return a, nil
|
||||
case typeMap:
|
||||
a := &MapParameter{}
|
||||
if err := dec.DecodeContext(ctx, a); err != nil {
|
||||
return nil, fmt.Errorf("unable to parse as %q: %w", t, err)
|
||||
}
|
||||
if a.AuthSources != nil {
|
||||
logger.WarnContext(ctx, "`authSources` is deprecated, use `authServices` for parameters instead")
|
||||
a.AuthServices = append(a.AuthServices, a.AuthSources...)
|
||||
a.AuthSources = nil
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
return nil, fmt.Errorf("%q is not valid type for a parameter", t)
|
||||
}
|
||||
@@ -401,19 +413,21 @@ func (ps Parameters) McpManifest() McpToolsSchema {
|
||||
|
||||
// ParameterManifest represents parameters when served as part of a ToolManifest.
|
||||
type ParameterManifest struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
Required bool `json:"required"`
|
||||
Description string `json:"description"`
|
||||
AuthServices []string `json:"authSources"`
|
||||
Items *ParameterManifest `json:"items,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
Required bool `json:"required"`
|
||||
Description string `json:"description"`
|
||||
AuthServices []string `json:"authSources"`
|
||||
Items *ParameterManifest `json:"items,omitempty"`
|
||||
AdditionalProperties any `json:"AdditionalProperties,omitempty"`
|
||||
}
|
||||
|
||||
// ParameterMcpManifest represents properties when served as part of a ToolMcpManifest.
|
||||
type ParameterMcpManifest struct {
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Items *ParameterMcpManifest `json:"items,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Items *ParameterMcpManifest `json:"items,omitempty"`
|
||||
AdditionalProperties any `json:"AdditionalProperties,omitempty"`
|
||||
}
|
||||
|
||||
// CommonParameter are default fields that are emebdding in most Parameter implementations. Embedding this stuct will give the object Name() and Type() functions.
|
||||
@@ -1022,3 +1036,211 @@ func (p *ArrayParameter) McpManifest() ParameterMcpManifest {
|
||||
Items: &items,
|
||||
}
|
||||
}
|
||||
|
||||
// MapParameter is a parameter representing a map with string keys. If ValueType is
|
||||
// specified (e.g., "string"), values are validated against that type. If ValueType
|
||||
// is empty, it is treated as a generic map[string]any.
|
||||
type MapParameter struct {
|
||||
CommonParameter `yaml:",inline"`
|
||||
Default *map[string]any `yaml:"default,omitempty"`
|
||||
ValueType string `yaml:"valueType,omitempty"`
|
||||
}
|
||||
|
||||
// Ensure MapParameter implements the Parameter interface.
|
||||
var _ Parameter = &MapParameter{}
|
||||
|
||||
// NewMapParameter is a convenience function for initializing a MapParameter.
|
||||
func NewMapParameter(name string, desc string, valueType string) *MapParameter {
|
||||
return &MapParameter{
|
||||
CommonParameter: CommonParameter{
|
||||
Name: name,
|
||||
Type: "map",
|
||||
Desc: desc,
|
||||
},
|
||||
ValueType: valueType,
|
||||
}
|
||||
}
|
||||
|
||||
// NewMapParameterWithDefault is a convenience function for initializing a MapParameter with a default value.
|
||||
func NewMapParameterWithDefault(name string, defaultV map[string]any, desc string, valueType string) *MapParameter {
|
||||
return &MapParameter{
|
||||
CommonParameter: CommonParameter{
|
||||
Name: name,
|
||||
Type: "map",
|
||||
Desc: desc,
|
||||
},
|
||||
ValueType: valueType,
|
||||
Default: &defaultV,
|
||||
}
|
||||
}
|
||||
|
||||
// NewMapParameterWithRequired is a convenience function for initializing a MapParameter as required.
|
||||
func NewMapParameterWithRequired(name string, desc string, required bool, valueType string) *MapParameter {
|
||||
return &MapParameter{
|
||||
CommonParameter: CommonParameter{
|
||||
Name: name,
|
||||
Type: "map",
|
||||
Desc: desc,
|
||||
Required: &required,
|
||||
},
|
||||
ValueType: valueType,
|
||||
}
|
||||
}
|
||||
|
||||
// NewMapParameterWithAuth is a convenience function for initializing a MapParameter with auth services.
|
||||
func NewMapParameterWithAuth(name string, desc string, valueType string, authServices []ParamAuthService) *MapParameter {
|
||||
return &MapParameter{
|
||||
CommonParameter: CommonParameter{
|
||||
Name: name,
|
||||
Type: "map",
|
||||
Desc: desc,
|
||||
AuthServices: authServices,
|
||||
},
|
||||
ValueType: valueType,
|
||||
}
|
||||
}
|
||||
|
||||
// UnmarshalYAML handles parsing the MapParameter from YAML input.
|
||||
func (p *MapParameter) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
|
||||
var rawItem struct {
|
||||
CommonParameter `yaml:",inline"`
|
||||
Default *map[string]any `yaml:"default"`
|
||||
ValueType string `yaml:"valueType"`
|
||||
}
|
||||
if err := unmarshal(&rawItem); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Validate `ValueType` to be one of the supported basic types
|
||||
if rawItem.ValueType != "" {
|
||||
if _, err := getPrototypeParameter(rawItem.ValueType); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
p.CommonParameter = rawItem.CommonParameter
|
||||
p.Default = rawItem.Default
|
||||
p.ValueType = rawItem.ValueType
|
||||
return nil
|
||||
}
|
||||
|
||||
// getPrototypeParameter is a helper factory to create a temporary parameter
|
||||
// based on a type string for parsing and manifest generation.
|
||||
func getPrototypeParameter(typeName string) (Parameter, error) {
|
||||
switch typeName {
|
||||
case "string":
|
||||
return NewStringParameter("", ""), nil
|
||||
case "integer":
|
||||
return NewIntParameter("", ""), nil
|
||||
case "boolean":
|
||||
return NewBooleanParameter("", ""), nil
|
||||
case "float":
|
||||
return NewFloatParameter("", ""), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported valueType %q for map parameter", typeName)
|
||||
}
|
||||
}
|
||||
|
||||
// Parse validates and parses an incoming value for the map parameter.
|
||||
func (p *MapParameter) Parse(v any) (any, error) {
|
||||
m, ok := v.(map[string]any)
|
||||
if !ok {
|
||||
return nil, &ParseTypeError{p.Name, p.Type, m}
|
||||
}
|
||||
// for generic maps, convert json.Numbers to their corresponding types
|
||||
if p.ValueType == "" {
|
||||
convertedData, err := util.ConvertNumbers(m)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse integer or float values in map: %s", err)
|
||||
}
|
||||
convertedMap, ok := convertedData.(map[string]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("internal error: ConvertNumbers should return a map, but got type %T", convertedData)
|
||||
}
|
||||
return convertedMap, nil
|
||||
}
|
||||
|
||||
// Otherwise, get a prototype and parse each value in the map.
|
||||
prototype, err := getPrototypeParameter(p.ValueType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rtn := make(map[string]any, len(m))
|
||||
for key, val := range m {
|
||||
parsedVal, err := prototype.Parse(val)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse value for key %q: %w", key, err)
|
||||
}
|
||||
rtn[key] = parsedVal
|
||||
}
|
||||
return rtn, nil
|
||||
}
|
||||
|
||||
func (p *MapParameter) GetAuthServices() []ParamAuthService {
|
||||
return p.AuthServices
|
||||
}
|
||||
|
||||
func (p *MapParameter) GetDefault() any {
|
||||
if p.Default == nil {
|
||||
return nil
|
||||
}
|
||||
return *p.Default
|
||||
}
|
||||
|
||||
func (p *MapParameter) GetValueType() string {
|
||||
return p.ValueType
|
||||
}
|
||||
|
||||
// Manifest returns the manifest for the MapParameter.
|
||||
func (p *MapParameter) Manifest() ParameterManifest {
|
||||
authNames := make([]string, len(p.AuthServices))
|
||||
for i, a := range p.AuthServices {
|
||||
authNames[i] = a.Name
|
||||
}
|
||||
r := CheckParamRequired(p.GetRequired(), p.GetDefault())
|
||||
|
||||
var additionalProperties any
|
||||
if p.ValueType != "" {
|
||||
prototype, err := getPrototypeParameter(p.ValueType)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
valueSchema := prototype.Manifest()
|
||||
additionalProperties = &valueSchema
|
||||
} else {
|
||||
// If no valueType is given, allow any properties.
|
||||
additionalProperties = true
|
||||
}
|
||||
|
||||
return ParameterManifest{
|
||||
Name: p.Name,
|
||||
Type: "object",
|
||||
Required: r,
|
||||
Description: p.Desc,
|
||||
AuthServices: authNames,
|
||||
AdditionalProperties: additionalProperties,
|
||||
}
|
||||
}
|
||||
|
||||
// McpManifest returns the MCP manifest for the MapParameter.
|
||||
func (p *MapParameter) McpManifest() ParameterMcpManifest {
|
||||
var additionalProperties any
|
||||
if p.ValueType != "" {
|
||||
prototype, err := getPrototypeParameter(p.ValueType)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
valueSchema := prototype.McpManifest()
|
||||
additionalProperties = &valueSchema
|
||||
} else {
|
||||
// If no valueType is given, allow any properties.
|
||||
additionalProperties = true
|
||||
}
|
||||
|
||||
return ParameterMcpManifest{
|
||||
Type: "object",
|
||||
Description: p.Desc,
|
||||
AdditionalProperties: additionalProperties,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,10 +18,10 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"math"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
@@ -294,6 +294,63 @@ func TestParametersMarshal(t *testing.T) {
|
||||
tools.NewArrayParameterWithDefault("my_array", []any{1.0, 1.1}, "this param is an array of floats", tools.NewFloatParameter("my_float", "float item")),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "map with string values",
|
||||
in: []map[string]any{
|
||||
{
|
||||
"name": "my_map",
|
||||
"type": "map",
|
||||
"description": "this param is a map of strings",
|
||||
"valueType": "string",
|
||||
},
|
||||
},
|
||||
want: tools.Parameters{
|
||||
tools.NewMapParameter("my_map", "this param is a map of strings", "string"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "map not required",
|
||||
in: []map[string]any{
|
||||
{
|
||||
"name": "my_map",
|
||||
"type": "map",
|
||||
"description": "this param is a map of strings",
|
||||
"required": false,
|
||||
"valueType": "string",
|
||||
},
|
||||
},
|
||||
want: tools.Parameters{
|
||||
tools.NewMapParameterWithRequired("my_map", "this param is a map of strings", false, "string"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "map with default",
|
||||
in: []map[string]any{
|
||||
{
|
||||
"name": "my_map",
|
||||
"type": "map",
|
||||
"description": "this param is a map of strings",
|
||||
"default": map[string]any{"key1": "val1"},
|
||||
"valueType": "string",
|
||||
},
|
||||
},
|
||||
want: tools.Parameters{
|
||||
tools.NewMapParameterWithDefault("my_map", map[string]any{"key1": "val1"}, "this param is a map of strings", "string"),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "generic map (no valueType)",
|
||||
in: []map[string]any{
|
||||
{
|
||||
"name": "my_generic_map",
|
||||
"type": "map",
|
||||
"description": "this param is a generic map",
|
||||
},
|
||||
},
|
||||
want: tools.Parameters{
|
||||
tools.NewMapParameter("my_generic_map", "this param is a generic map", ""),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
@@ -350,13 +407,13 @@ func TestAuthParametersMarshal(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "string with authSources",
|
||||
name: "string with authServices",
|
||||
in: []map[string]any{
|
||||
{
|
||||
"name": "my_string",
|
||||
"type": "string",
|
||||
"description": "this param is a string",
|
||||
"authSources": []map[string]string{
|
||||
"authServices": []map[string]string{
|
||||
{
|
||||
"name": "my-google-auth-service",
|
||||
"field": "user_id",
|
||||
@@ -396,13 +453,13 @@ func TestAuthParametersMarshal(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "int with authSources",
|
||||
name: "int with authServices",
|
||||
in: []map[string]any{
|
||||
{
|
||||
"name": "my_integer",
|
||||
"type": "integer",
|
||||
"description": "this param is an int",
|
||||
"authSources": []map[string]string{
|
||||
"authServices": []map[string]string{
|
||||
{
|
||||
"name": "my-google-auth-service",
|
||||
"field": "user_id",
|
||||
@@ -442,13 +499,13 @@ func TestAuthParametersMarshal(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "float with authSources",
|
||||
name: "float with authServices",
|
||||
in: []map[string]any{
|
||||
{
|
||||
"name": "my_float",
|
||||
"type": "float",
|
||||
"description": "my param is a float",
|
||||
"authSources": []map[string]string{
|
||||
"authServices": []map[string]string{
|
||||
{
|
||||
"name": "my-google-auth-service",
|
||||
"field": "user_id",
|
||||
@@ -488,13 +545,13 @@ func TestAuthParametersMarshal(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "bool with authSources",
|
||||
name: "bool with authServices",
|
||||
in: []map[string]any{
|
||||
{
|
||||
"name": "my_bool",
|
||||
"type": "boolean",
|
||||
"description": "this param is a boolean",
|
||||
"authSources": []map[string]string{
|
||||
"authServices": []map[string]string{
|
||||
{
|
||||
"name": "my-google-auth-service",
|
||||
"field": "user_id",
|
||||
@@ -539,7 +596,7 @@ func TestAuthParametersMarshal(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "string array with authSources",
|
||||
name: "string array with authServices",
|
||||
in: []map[string]any{
|
||||
{
|
||||
"name": "my_array",
|
||||
@@ -550,7 +607,7 @@ func TestAuthParametersMarshal(t *testing.T) {
|
||||
"type": "string",
|
||||
"description": "string item",
|
||||
},
|
||||
"authSources": []map[string]string{
|
||||
"authServices": []map[string]string{
|
||||
{
|
||||
"name": "my-google-auth-service",
|
||||
"field": "user_id",
|
||||
@@ -594,6 +651,24 @@ func TestAuthParametersMarshal(t *testing.T) {
|
||||
tools.NewArrayParameterWithAuth("my_array", "this param is an array of floats", tools.NewFloatParameter("my_float", "float item"), authServices),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "map",
|
||||
in: []map[string]any{
|
||||
{
|
||||
"name": "my_map",
|
||||
"type": "map",
|
||||
"description": "this param is a map of strings",
|
||||
"valueType": "string",
|
||||
"authServices": []map[string]string{
|
||||
{"name": "my-google-auth-service", "field": "user_id"},
|
||||
{"name": "other-auth-service", "field": "user_id"},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: tools.Parameters{
|
||||
tools.NewMapParameterWithAuth("my_map", "this param is a map of strings", "string", authServices),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
@@ -622,6 +697,7 @@ func TestParametersParse(t *testing.T) {
|
||||
in map[string]any
|
||||
want tools.ParamValues
|
||||
}{
|
||||
// ... (primitive type tests are unchanged)
|
||||
{
|
||||
name: "string",
|
||||
params: tools.Parameters{
|
||||
@@ -780,6 +856,51 @@ func TestParametersParse(t *testing.T) {
|
||||
in: map[string]any{},
|
||||
want: tools.ParamValues{tools.ParamValue{Name: "my_bool", Value: nil}},
|
||||
},
|
||||
{
|
||||
name: "map",
|
||||
params: tools.Parameters{
|
||||
tools.NewMapParameter("my_map", "a map", "string"),
|
||||
},
|
||||
in: map[string]any{
|
||||
"my_map": map[string]any{"key1": "val1", "key2": "val2"},
|
||||
},
|
||||
want: tools.ParamValues{tools.ParamValue{Name: "my_map", Value: map[string]any{"key1": "val1", "key2": "val2"}}},
|
||||
},
|
||||
{
|
||||
name: "generic map",
|
||||
params: tools.Parameters{
|
||||
tools.NewMapParameter("my_map_generic_type", "a generic map", ""),
|
||||
},
|
||||
in: map[string]any{
|
||||
"my_map_generic_type": map[string]any{"key1": "val1", "key2": 123, "key3": true},
|
||||
},
|
||||
want: tools.ParamValues{tools.ParamValue{Name: "my_map_generic_type", Value: map[string]any{"key1": "val1", "key2": int64(123), "key3": true}}},
|
||||
},
|
||||
{
|
||||
name: "not map (value type mismatch)",
|
||||
params: tools.Parameters{
|
||||
tools.NewMapParameter("my_map", "a map", "string"),
|
||||
},
|
||||
in: map[string]any{
|
||||
"my_map": map[string]any{"key1": 123},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "map default",
|
||||
params: tools.Parameters{
|
||||
tools.NewMapParameterWithDefault("my_map_default", map[string]any{"default_key": "default_val"}, "a map", "string"),
|
||||
},
|
||||
in: map[string]any{},
|
||||
want: tools.ParamValues{tools.ParamValue{Name: "my_map_default", Value: map[string]any{"default_key": "default_val"}}},
|
||||
},
|
||||
{
|
||||
name: "map not required",
|
||||
params: tools.Parameters{
|
||||
tools.NewMapParameterWithRequired("my_map_not_required", "a map", false, "string"),
|
||||
},
|
||||
in: map[string]any{},
|
||||
want: tools.ParamValues{tools.ParamValue{Name: "my_map_not_required", Value: nil}},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
@@ -809,15 +930,10 @@ func TestParametersParse(t *testing.T) {
|
||||
if wantErr {
|
||||
t.Fatalf("expected error but Param parsed successfully: %s", gotAll)
|
||||
}
|
||||
for i, got := range gotAll {
|
||||
want := tc.want[i]
|
||||
if got != want {
|
||||
t.Fatalf("unexpected value: got %q, want %q", got, want)
|
||||
}
|
||||
gotType, wantType := reflect.TypeOf(got), reflect.TypeOf(want)
|
||||
if gotType != wantType {
|
||||
t.Fatalf("unexpected value: got %q, want %q", got, want)
|
||||
}
|
||||
|
||||
// Use cmp.Diff for robust comparison
|
||||
if diff := cmp.Diff(tc.want, gotAll); diff != "" {
|
||||
t.Fatalf("ParseParams() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -945,6 +1061,15 @@ func TestAuthParametersParse(t *testing.T) {
|
||||
},
|
||||
claimsMap: map[string]map[string]any{"my-google-auth-service": {"not_an_auth_field": "Alice"}},
|
||||
},
|
||||
{
|
||||
name: "map",
|
||||
params: tools.Parameters{
|
||||
tools.NewMapParameterWithAuth("my_map", "a map", "string", authServices),
|
||||
},
|
||||
in: map[string]any{"my_map": map[string]any{"key1": "val1"}},
|
||||
claimsMap: map[string]map[string]any{"my-google-auth-service": {"auth_field": map[string]any{"authed_key": "authed_val"}}},
|
||||
want: tools.ParamValues{tools.ParamValue{Name: "my_map", Value: map[string]any{"authed_key": "authed_val"}}},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
@@ -970,15 +1095,9 @@ func TestAuthParametersParse(t *testing.T) {
|
||||
}
|
||||
t.Fatalf("unexpected error from ParseParams: %s", err)
|
||||
}
|
||||
for i, got := range gotAll {
|
||||
want := tc.want[i]
|
||||
if got != want {
|
||||
t.Fatalf("unexpected value: got %q, want %q", got, want)
|
||||
}
|
||||
gotType, wantType := reflect.TypeOf(got), reflect.TypeOf(want)
|
||||
if gotType != wantType {
|
||||
t.Fatalf("unexpected value: got %q, want %q", got, want)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tc.want, gotAll); diff != "" {
|
||||
t.Fatalf("ParseParams() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1142,12 +1261,48 @@ func TestParamManifest(t *testing.T) {
|
||||
Items: &tools.ParameterManifest{Name: "foo-string", Type: "string", Required: false, Description: "bar", AuthServices: []string{}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "map with string values",
|
||||
in: tools.NewMapParameter("foo-map", "bar", "string"),
|
||||
want: tools.ParameterManifest{
|
||||
Name: "foo-map",
|
||||
Type: "object",
|
||||
Required: true,
|
||||
Description: "bar",
|
||||
AuthServices: []string{},
|
||||
AdditionalProperties: &tools.ParameterManifest{Name: "", Type: "string", Required: true, Description: "", AuthServices: []string{}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "map not required",
|
||||
in: tools.NewMapParameterWithRequired("foo-map", "bar", false, "string"),
|
||||
want: tools.ParameterManifest{
|
||||
Name: "foo-map",
|
||||
Type: "object",
|
||||
Required: false,
|
||||
Description: "bar",
|
||||
AuthServices: []string{},
|
||||
AdditionalProperties: &tools.ParameterManifest{Name: "", Type: "string", Required: true, Description: "", AuthServices: []string{}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "generic map (additionalProperties true)",
|
||||
in: tools.NewMapParameter("foo-map", "bar", ""),
|
||||
want: tools.ParameterManifest{
|
||||
Name: "foo-map",
|
||||
Type: "object",
|
||||
Required: true,
|
||||
Description: "bar",
|
||||
AuthServices: []string{},
|
||||
AdditionalProperties: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := tc.in.Manifest()
|
||||
if !reflect.DeepEqual(got, tc.want) {
|
||||
t.Fatalf("unexpected manifest: got %+v, want %+v", got, tc.want)
|
||||
if diff := cmp.Diff(tc.want, got); diff != "" {
|
||||
t.Fatalf("unexpected manifest (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1188,12 +1343,31 @@ func TestParamMcpManifest(t *testing.T) {
|
||||
Items: &tools.ParameterMcpManifest{Type: "string", Description: "bar"},
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
name: "map with string values",
|
||||
in: tools.NewMapParameter("foo-map", "bar", "string"),
|
||||
want: tools.ParameterMcpManifest{
|
||||
Type: "object",
|
||||
Description: "bar",
|
||||
AdditionalProperties: &tools.ParameterMcpManifest{Type: "string", Description: ""},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "generic map (additionalProperties true)",
|
||||
in: tools.NewMapParameter("foo-map", "bar", ""),
|
||||
want: tools.ParameterMcpManifest{
|
||||
Type: "object",
|
||||
Description: "bar",
|
||||
AdditionalProperties: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := tc.in.McpManifest()
|
||||
if !reflect.DeepEqual(got, tc.want) {
|
||||
t.Fatalf("unexpected manifest: got %+v, want %+v", got, tc.want)
|
||||
if diff := cmp.Diff(tc.want, got); diff != "" {
|
||||
t.Fatalf("unexpected manifest (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1206,46 +1380,46 @@ func TestMcpManifest(t *testing.T) {
|
||||
want tools.McpToolsSchema
|
||||
}{
|
||||
{
|
||||
name: "string",
|
||||
name: "all types",
|
||||
in: tools.Parameters{
|
||||
tools.NewStringParameterWithDefault("foo-string", "foo", "bar"),
|
||||
tools.NewStringParameter("foo-string2", "bar"),
|
||||
tools.NewStringParameterWithRequired("foo-string-req", "bar", true),
|
||||
tools.NewStringParameterWithRequired("foo-string-not-req", "bar", false),
|
||||
tools.NewIntParameterWithDefault("foo-int", 1, "bar"),
|
||||
tools.NewIntParameter("foo-int2", "bar"),
|
||||
tools.NewArrayParameterWithDefault("foo-array", []any{"hello", "world"}, "bar", tools.NewStringParameter("foo-string", "bar")),
|
||||
tools.NewArrayParameter("foo-array2", "bar", tools.NewStringParameter("foo-string", "bar")),
|
||||
tools.NewMapParameter("foo-map-int", "a map of ints", "integer"),
|
||||
tools.NewMapParameter("foo-map-any", "a map of any", ""),
|
||||
},
|
||||
want: tools.McpToolsSchema{
|
||||
Type: "object",
|
||||
Properties: map[string]tools.ParameterMcpManifest{
|
||||
"foo-string": tools.ParameterMcpManifest{Type: "string", Description: "bar"},
|
||||
"foo-string2": tools.ParameterMcpManifest{Type: "string", Description: "bar"},
|
||||
"foo-string-req": tools.ParameterMcpManifest{Type: "string", Description: "bar"},
|
||||
"foo-string-not-req": tools.ParameterMcpManifest{Type: "string", Description: "bar"},
|
||||
"foo-int": tools.ParameterMcpManifest{Type: "integer", Description: "bar"},
|
||||
"foo-int2": tools.ParameterMcpManifest{Type: "integer", Description: "bar"},
|
||||
"foo-array": tools.ParameterMcpManifest{
|
||||
"foo-string": {Type: "string", Description: "bar"},
|
||||
"foo-string2": {Type: "string", Description: "bar"},
|
||||
"foo-int2": {Type: "integer", Description: "bar"},
|
||||
"foo-array2": {
|
||||
Type: "array",
|
||||
Description: "bar",
|
||||
Items: &tools.ParameterMcpManifest{Type: "string", Description: "bar"},
|
||||
},
|
||||
"foo-array2": tools.ParameterMcpManifest{
|
||||
Type: "array",
|
||||
Description: "bar",
|
||||
Items: &tools.ParameterMcpManifest{Type: "string", Description: "bar"},
|
||||
"foo-map-int": {
|
||||
Type: "object",
|
||||
Description: "a map of ints",
|
||||
AdditionalProperties: &tools.ParameterMcpManifest{Type: "integer", Description: ""},
|
||||
},
|
||||
"foo-map-any": {
|
||||
Type: "object",
|
||||
Description: "a map of any",
|
||||
AdditionalProperties: true,
|
||||
},
|
||||
},
|
||||
Required: []string{"foo-string2", "foo-string-req", "foo-int2", "foo-array2"},
|
||||
Required: []string{"foo-string2", "foo-int2", "foo-array2", "foo-map-int", "foo-map-any"},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := tc.in.McpManifest()
|
||||
if !reflect.DeepEqual(got, tc.want) {
|
||||
t.Fatalf("unexpected manifest: got %+v, want %+v", got, tc.want)
|
||||
if diff := cmp.Diff(tc.want, got); diff != "" {
|
||||
t.Fatalf("unexpected manifest (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1317,6 +1491,19 @@ func TestFailParametersUnmarshal(t *testing.T) {
|
||||
},
|
||||
err: "unable to parse as \"array\": unable to parse 'items' field: unable to parse as \"string\": Key: 'CommonParameter.Name' Error:Field validation for 'Name' failed on the 'required' tag",
|
||||
},
|
||||
// --- MODIFIED MAP PARAMETER TEST ---
|
||||
{
|
||||
name: "map with invalid valueType",
|
||||
in: []map[string]any{
|
||||
{
|
||||
"name": "my_map",
|
||||
"type": "map",
|
||||
"description": "this param is a map",
|
||||
"valueType": "not-a-real-type",
|
||||
},
|
||||
},
|
||||
err: "unsupported valueType \"not-a-real-type\" for map parameter",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
@@ -1332,14 +1519,18 @@ func TestFailParametersUnmarshal(t *testing.T) {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
errStr := err.Error()
|
||||
if errStr != tc.err {
|
||||
t.Fatalf("unexpected error: got %q, want %q", errStr, tc.err)
|
||||
|
||||
if !strings.Contains(errStr, tc.err) {
|
||||
t.Fatalf("unexpected error: got %q, want to contain %q", errStr, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ... (Remaining test functions do not involve parameter definitions and need no changes)
|
||||
|
||||
func TestConvertArrayParamToString(t *testing.T) {
|
||||
|
||||
tcs := []struct {
|
||||
name string
|
||||
in []any
|
||||
@@ -1482,6 +1673,7 @@ func TestGetParams(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFailGetParams(t *testing.T) {
|
||||
|
||||
tcs := []struct {
|
||||
name string
|
||||
params tools.Parameters
|
||||
|
||||
@@ -118,7 +118,7 @@ type Tool struct {
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
sliceParams := params.AsSlice()
|
||||
sql, ok := sliceParams[0].(string)
|
||||
if !ok {
|
||||
|
||||
@@ -126,7 +126,7 @@ type Tool struct {
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
paramsMap := params.AsMap()
|
||||
newStatement, err := tools.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
|
||||
if err != nil {
|
||||
|
||||
@@ -115,7 +115,7 @@ type Tool struct {
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
cmds, err := replaceCommandsParams(t.Commands, t.Parameters, params)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error replacing commands' parameters: %s", err)
|
||||
|
||||
@@ -144,7 +144,7 @@ func processRows(iter *spanner.RowIterator) ([]any, error) {
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
sliceParams := params.AsSlice()
|
||||
sql, ok := sliceParams[0].(string)
|
||||
if !ok {
|
||||
|
||||
@@ -164,7 +164,7 @@ func processRows(iter *spanner.RowIterator) ([]any, error) {
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
paramsMap := params.AsMap()
|
||||
newStatement, err := tools.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
|
||||
if err != nil {
|
||||
|
||||
@@ -122,7 +122,7 @@ type Tool struct {
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
paramsMap := params.AsMap()
|
||||
newStatement, err := tools.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
|
||||
if err != nil {
|
||||
|
||||
@@ -64,7 +64,7 @@ type ToolConfig interface {
|
||||
}
|
||||
|
||||
type Tool interface {
|
||||
Invoke(context.Context, ParamValues) ([]any, error)
|
||||
Invoke(context.Context, ParamValues) (any, error)
|
||||
ParseParams(map[string]any, map[string]map[string]any) (ParamValues, error)
|
||||
Manifest() Manifest
|
||||
McpManifest() McpManifest
|
||||
|
||||
@@ -85,7 +85,7 @@ type Tool struct {
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
durationStr, ok := paramsMap["duration"].(string)
|
||||
@@ -100,7 +100,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, erro
|
||||
|
||||
time.Sleep(totalDuration)
|
||||
|
||||
return []any{fmt.Sprintf("Wait for %v completed successfully.", totalDuration)}, nil
|
||||
return fmt.Sprintf("Wait for %v completed successfully.", totalDuration), nil
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
|
||||
@@ -114,7 +114,7 @@ type Tool struct {
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) ([]any, error) {
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
// Replace parameters
|
||||
commands, err := replaceCommandsParams(t.Commands, t.Parameters, params)
|
||||
if err != nil {
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
@@ -36,6 +37,46 @@ func DecodeJSON(r io.Reader, v interface{}) error {
|
||||
return d.Decode(v)
|
||||
}
|
||||
|
||||
// ConvertNumbers traverses an interface and converts all json.Number
|
||||
// instances to int64 or float64.
|
||||
func ConvertNumbers(data any) (any, error) {
|
||||
switch v := data.(type) {
|
||||
// If it's a map, recursively convert the values.
|
||||
case map[string]any:
|
||||
for key, val := range v {
|
||||
convertedVal, err := ConvertNumbers(val)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
v[key] = convertedVal
|
||||
}
|
||||
return v, nil
|
||||
|
||||
// If it's a slice, recursively convert the elements.
|
||||
case []any:
|
||||
for i, val := range v {
|
||||
convertedVal, err := ConvertNumbers(val)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
v[i] = convertedVal
|
||||
}
|
||||
return v, nil
|
||||
|
||||
// If it's a json.Number, convert it to float or int
|
||||
case json.Number:
|
||||
// Check for a decimal point to decide the type.
|
||||
if strings.Contains(v.String(), ".") {
|
||||
return v.Float64()
|
||||
}
|
||||
return v.Int64()
|
||||
|
||||
// For all other types, return them as is.
|
||||
default:
|
||||
return data, nil
|
||||
}
|
||||
}
|
||||
|
||||
var _ yaml.InterfaceUnmarshalerContext = &DelayedUnmarshaler{}
|
||||
|
||||
// DelayedUnmarshaler is struct that saves the provided unmarshal function
|
||||
|
||||
@@ -135,7 +135,7 @@ func TestAlloyDBPgToolEndpoints(t *testing.T) {
|
||||
tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
|
||||
// set up data for param tool
|
||||
createParamTableStmt, insertParamTableStmt, paramToolStmt, paramToolStmt2, arrayToolStmt, paramTestParams := tests.GetPostgresSQLParamToolInfo(tableNameParam)
|
||||
createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := tests.GetPostgresSQLParamToolInfo(tableNameParam)
|
||||
teardownTable1 := tests.SetupPostgresSQLTable(t, ctx, pool, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams)
|
||||
defer teardownTable1(t)
|
||||
|
||||
@@ -145,7 +145,7 @@ func TestAlloyDBPgToolEndpoints(t *testing.T) {
|
||||
defer teardownTable2(t)
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, AlloyDBPostgresToolKind, paramToolStmt, paramToolStmt2, arrayToolStmt, authToolStmt)
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, AlloyDBPostgresToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
|
||||
toolsFile = tests.AddPgExecuteSqlConfig(t, toolsFile)
|
||||
tmplSelectCombined, tmplSelectFilterCombined := tests.GetPostgresSQLTmplToolStatement()
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, AlloyDBPostgresToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||
@@ -167,8 +167,8 @@ func TestAlloyDBPgToolEndpoints(t *testing.T) {
|
||||
tests.RunToolGetTest(t)
|
||||
|
||||
select1Want, failInvocationWant, createTableStatement := tests.GetPostgresWants()
|
||||
invokeParamWant, invokeParamWantNull, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeParamWantNull, true)
|
||||
invokeParamWant, invokeIdNullWant, nullWant, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeIdNullWant, nullWant, true, true)
|
||||
tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want)
|
||||
tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failInvocationWant)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tests.NewTemplateParameterTestConfig())
|
||||
|
||||
@@ -102,7 +102,7 @@ func TestBigQueryToolEndpoints(t *testing.T) {
|
||||
)
|
||||
|
||||
// set up data for param tool
|
||||
createParamTableStmt, insertParamTableStmt, paramToolStmt, paramToolStmt2, arrayToolStmt, paramTestParams := getBigQueryParamToolInfo(tableNameParam)
|
||||
createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := getBigQueryParamToolInfo(tableNameParam)
|
||||
teardownTable1 := setupBigQueryTable(t, ctx, client, createParamTableStmt, insertParamTableStmt, datasetName, tableNameParam, paramTestParams)
|
||||
defer teardownTable1(t)
|
||||
|
||||
@@ -112,7 +112,7 @@ func TestBigQueryToolEndpoints(t *testing.T) {
|
||||
defer teardownTable2(t)
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, BigqueryToolKind, paramToolStmt, paramToolStmt2, arrayToolStmt, authToolStmt)
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, BigqueryToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
|
||||
toolsFile = addBigQueryPrebuiltToolsConfig(t, toolsFile)
|
||||
tmplSelectCombined, tmplSelectFilterCombined := getBigQueryTmplToolStatement()
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, BigqueryToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||
@@ -137,9 +137,9 @@ func TestBigQueryToolEndpoints(t *testing.T) {
|
||||
// Partial message; the full error message is too long.
|
||||
failInvocationWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: googleapi: Error 400: Syntax error: Unexpected identifier \"SELEC\" at [1:1]`
|
||||
datasetInfoWant := "\"Location\":\"US\",\"DefaultTableExpiration\":0,\"Labels\":null,\"Access\":"
|
||||
tableInfoWant := "[{\"Name\":\"\",\"Location\":\"US\",\"Description\":\"\",\"Schema\":[{\"Name\":\"id\""
|
||||
invokeParamWant, invokeParamWantNull, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeParamWantNull, true)
|
||||
tableInfoWant := "{\"Name\":\"\",\"Location\":\"US\",\"Description\":\"\",\"Schema\":[{\"Name\":\"id\""
|
||||
invokeParamWant, invokeIdNullWant, nullWant, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeIdNullWant, nullWant, false, true)
|
||||
tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failInvocationWant)
|
||||
templateParamTestConfig := tests.NewTemplateParameterTestConfig(
|
||||
tests.WithCreateColArray(`["id INT64", "name STRING", "age INT64"]`),
|
||||
@@ -153,14 +153,15 @@ func TestBigQueryToolEndpoints(t *testing.T) {
|
||||
runBigQueryGetTableInfoToolInvokeTest(t, datasetName, tableName, tableInfoWant)
|
||||
}
|
||||
|
||||
// getBigQueryParamToolInfo returns statements and param for my-param-tool for bigquery kind
|
||||
func getBigQueryParamToolInfo(tableName string) (string, string, string, string, string, []bigqueryapi.QueryParameter) {
|
||||
// getBigQueryParamToolInfo returns statements and param for my-tool for bigquery kind
|
||||
func getBigQueryParamToolInfo(tableName string) (string, string, string, string, string, string, []bigqueryapi.QueryParameter) {
|
||||
createStatement := fmt.Sprintf(`
|
||||
CREATE TABLE IF NOT EXISTS %s (id INT64, name STRING);`, tableName)
|
||||
insertStatement := fmt.Sprintf(`
|
||||
INSERT INTO %s (id, name) VALUES (?, ?), (?, ?), (?, ?), (?, NULL);`, tableName)
|
||||
toolStatement := fmt.Sprintf(`SELECT * FROM %s WHERE id = ? OR name = ? ORDER BY id;`, tableName)
|
||||
toolStatement2 := fmt.Sprintf(`SELECT * FROM %s WHERE id = ? ORDER BY id;`, tableName)
|
||||
idToolStatement := fmt.Sprintf(`SELECT * FROM %s WHERE id = ? ORDER BY id;`, tableName)
|
||||
nameToolStatement := fmt.Sprintf(`SELECT * FROM %s WHERE name = ? ORDER BY id;`, tableName)
|
||||
arrayToolStatememt := fmt.Sprintf(`SELECT * FROM %s WHERE id IN UNNEST(@idArray) AND name IN UNNEST(@nameArray) ORDER BY id;`, tableName)
|
||||
params := []bigqueryapi.QueryParameter{
|
||||
{Value: int64(1)}, {Value: "Alice"},
|
||||
@@ -168,7 +169,7 @@ func getBigQueryParamToolInfo(tableName string) (string, string, string, string,
|
||||
{Value: int64(3)}, {Value: "Sid"},
|
||||
{Value: int64(4)},
|
||||
}
|
||||
return createStatement, insertStatement, toolStatement, toolStatement2, arrayToolStatememt, params
|
||||
return createStatement, insertStatement, toolStatement, idToolStatement, nameToolStatement, arrayToolStatememt, params
|
||||
}
|
||||
|
||||
// getBigQueryAuthToolInfo returns statements and param of my-auth-tool for bigquery kind
|
||||
@@ -380,6 +381,7 @@ func runBigQueryExecuteSqlToolInvokeTest(t *testing.T, select1Want, invokeParamW
|
||||
api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"sql":"CREATE TABLE t (id SERIAL PRIMARY KEY, name TEXT)"}`)),
|
||||
want: `"Operation completed successfully."`,
|
||||
isErr: true,
|
||||
},
|
||||
{
|
||||
@@ -390,11 +392,20 @@ func runBigQueryExecuteSqlToolInvokeTest(t *testing.T, select1Want, invokeParamW
|
||||
want: invokeParamWant,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "invoke my-exec-sql-tool with no matching rows",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"sql\":\"SELECT * FROM %s WHERE id = 999\"}", tableNameParam))),
|
||||
want: `"The query returned 0 rows."`,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "invoke my-exec-sql-tool drop table",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"sql":"DROP TABLE t"}`)),
|
||||
want: `"Operation completed successfully."`,
|
||||
isErr: true,
|
||||
},
|
||||
{
|
||||
@@ -402,7 +413,7 @@ func runBigQueryExecuteSqlToolInvokeTest(t *testing.T, select1Want, invokeParamW
|
||||
api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"sql\":\"INSERT INTO %s (id, name) VALUES (4, 'test_name')\"}", tableNameParam))),
|
||||
want: "null",
|
||||
want: `"Operation completed successfully."`,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
|
||||
@@ -79,7 +79,8 @@ func TestBigtableToolEndpoints(t *testing.T) {
|
||||
// Do not change the shape of statement without checking tests/common_test.go.
|
||||
// The structure and value of seed data has to match https://github.com/googleapis/genai-toolbox/blob/4dba0df12dc438eca3cb476ef52aa17cdf232c12/tests/common_test.go#L200-L251
|
||||
paramTestStatement := fmt.Sprintf("SELECT TO_INT64(cf['id']) as id, CAST(cf['name'] AS string) as name, FROM %s WHERE TO_INT64(cf['id']) = @id OR CAST(cf['name'] AS string) = @name;", tableName)
|
||||
paramTestStatement2 := fmt.Sprintf("SELECT TO_INT64(cf['id']) as id, CAST(cf['name'] AS string) as name, FROM %s WHERE TO_INT64(cf['id']) = @id;", tableName)
|
||||
idParamTestStatement := fmt.Sprintf("SELECT TO_INT64(cf['id']) as id, CAST(cf['name'] AS string) as name, FROM %s WHERE TO_INT64(cf['id']) = @id;", tableName)
|
||||
nameParamTestStatement := fmt.Sprintf("SELECT TO_INT64(cf['id']) as id, CAST(cf['name'] AS string) as name, FROM %s WHERE CAST(cf['name'] AS string) = @name;", tableName)
|
||||
arrayTestStatement := fmt.Sprintf(
|
||||
"SELECT TO_INT64(cf['id']) AS id, CAST(cf['name'] AS string) AS name FROM %s WHERE TO_INT64(cf['id']) IN UNNEST(@idArray) AND CAST(cf['name'] AS string) IN UNNEST(@nameArray);",
|
||||
tableName,
|
||||
@@ -98,7 +99,7 @@ func TestBigtableToolEndpoints(t *testing.T) {
|
||||
defer teardownTableTmpl(t)
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, BigtableToolKind, paramTestStatement, paramTestStatement2, arrayTestStatement, authToolStatement)
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, BigtableToolKind, paramTestStatement, idParamTestStatement, nameParamTestStatement, arrayTestStatement, authToolStatement)
|
||||
toolsFile = addTemplateParamConfig(t, toolsFile)
|
||||
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
@@ -120,9 +121,9 @@ func TestBigtableToolEndpoints(t *testing.T) {
|
||||
// Actual test parameters are set in https://github.com/googleapis/genai-toolbox/blob/52b09a67cb40ac0c5f461598b4673136699a3089/tests/tool_test.go#L250
|
||||
select1Want := "[{\"$col1\":1}]"
|
||||
failInvocationWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to prepare statement: rpc error: code = InvalidArgument desc = Syntax error: Unexpected identifier \"SELEC\" [at 1:1]"}],"isError":true}}`
|
||||
invokeParamWant, _, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
invokeParamWantNull := `[{"id":4,"name":""}]`
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeParamWantNull, true)
|
||||
invokeParamWant, _, nullWant, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
invokeIdNullWant := `[{"id":4,"name":""}]`
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeIdNullWant, nullWant, true, true)
|
||||
tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failInvocationWant)
|
||||
|
||||
templateParamTestConfig := tests.NewTemplateParameterTestConfig(
|
||||
|
||||
@@ -129,7 +129,7 @@ func TestCloudSQLMSSQLToolEndpoints(t *testing.T) {
|
||||
tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
|
||||
// set up data for param tool
|
||||
createParamTableStmt, insertParamTableStmt, paramToolStmt, paramToolStmt2, arrayToolStmt, paramTestParams := tests.GetMSSQLParamToolInfo(tableNameParam)
|
||||
createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := tests.GetMSSQLParamToolInfo(tableNameParam)
|
||||
teardownTable1 := tests.SetupMsSQLTable(t, ctx, db, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams)
|
||||
defer teardownTable1(t)
|
||||
|
||||
@@ -139,7 +139,7 @@ func TestCloudSQLMSSQLToolEndpoints(t *testing.T) {
|
||||
defer teardownTable2(t)
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, CloudSQLMSSQLToolKind, paramToolStmt, paramToolStmt2, arrayToolStmt, authToolStmt)
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, CloudSQLMSSQLToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
|
||||
toolsFile = tests.AddMSSQLExecuteSqlConfig(t, toolsFile)
|
||||
tmplSelectCombined, tmplSelectFilterCombined := tests.GetMSSQLTmplToolStatement()
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, CloudSQLMSSQLToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||
@@ -161,8 +161,8 @@ func TestCloudSQLMSSQLToolEndpoints(t *testing.T) {
|
||||
tests.RunToolGetTest(t)
|
||||
|
||||
select1Want, failInvocationWant, createTableStatement := tests.GetMSSQLWants()
|
||||
invokeParamWant, invokeParamWantNull, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeParamWantNull, false)
|
||||
invokeParamWant, invokeIdNullWant, nullWant, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeIdNullWant, nullWant, true, false)
|
||||
tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want)
|
||||
tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failInvocationWant)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tests.NewTemplateParameterTestConfig())
|
||||
|
||||
@@ -116,7 +116,7 @@ func TestCloudSQLMySQLToolEndpoints(t *testing.T) {
|
||||
tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
|
||||
// set up data for param tool
|
||||
createParamTableStmt, insertParamTableStmt, paramToolStmt, paramToolStmt2, arrayToolStmt, paramTestParams := tests.GetMySQLParamToolInfo(tableNameParam)
|
||||
createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := tests.GetMySQLParamToolInfo(tableNameParam)
|
||||
teardownTable1 := tests.SetupMySQLTable(t, ctx, pool, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams)
|
||||
defer teardownTable1(t)
|
||||
|
||||
@@ -126,7 +126,7 @@ func TestCloudSQLMySQLToolEndpoints(t *testing.T) {
|
||||
defer teardownTable2(t)
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, CloudSQLMySQLToolKind, paramToolStmt, paramToolStmt2, arrayToolStmt, authToolStmt)
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, CloudSQLMySQLToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
|
||||
toolsFile = tests.AddMySqlExecuteSqlConfig(t, toolsFile)
|
||||
tmplSelectCombined, tmplSelectFilterCombined := tests.GetMySQLTmplToolStatement()
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, CloudSQLMySQLToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||
@@ -148,8 +148,8 @@ func TestCloudSQLMySQLToolEndpoints(t *testing.T) {
|
||||
tests.RunToolGetTest(t)
|
||||
|
||||
select1Want, failInvocationWant, createTableStatement := tests.GetMySQLWants()
|
||||
invokeParamWant, invokeParamWantNull, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeParamWantNull, false)
|
||||
invokeParamWant, invokeIdNullWant, nullWant, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeIdNullWant, nullWant, true, false)
|
||||
tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want)
|
||||
tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failInvocationWant)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tests.NewTemplateParameterTestConfig())
|
||||
|
||||
@@ -120,7 +120,7 @@ func TestCloudSQLPgSimpleToolEndpoints(t *testing.T) {
|
||||
tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
|
||||
// set up data for param tool
|
||||
createParamTableStmt, insertParamTableStmt, paramToolStmt, paramToolStmt2, arrayToolStmt, paramTestParams := tests.GetPostgresSQLParamToolInfo(tableNameParam)
|
||||
createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := tests.GetPostgresSQLParamToolInfo(tableNameParam)
|
||||
teardownTable1 := tests.SetupPostgresSQLTable(t, ctx, pool, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams)
|
||||
defer teardownTable1(t)
|
||||
|
||||
@@ -130,7 +130,7 @@ func TestCloudSQLPgSimpleToolEndpoints(t *testing.T) {
|
||||
defer teardownTable2(t)
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, CloudSQLPostgresToolKind, paramToolStmt, paramToolStmt2, arrayToolStmt, authToolStmt)
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, CloudSQLPostgresToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
|
||||
toolsFile = tests.AddPgExecuteSqlConfig(t, toolsFile)
|
||||
tmplSelectCombined, tmplSelectFilterCombined := tests.GetPostgresSQLTmplToolStatement()
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, CloudSQLPostgresToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||
@@ -152,8 +152,8 @@ func TestCloudSQLPgSimpleToolEndpoints(t *testing.T) {
|
||||
tests.RunToolGetTest(t)
|
||||
|
||||
select1Want, failInvocationWant, createTableStatement := tests.GetPostgresWants()
|
||||
invokeParamWant, invokeParamWantNull, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeParamWantNull, true)
|
||||
invokeParamWant, invokeIdNullWant, nullWant, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeIdNullWant, nullWant, true, true)
|
||||
tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want)
|
||||
tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failInvocationWant)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tests.NewTemplateParameterTestConfig())
|
||||
|
||||
@@ -28,7 +28,7 @@ import (
|
||||
)
|
||||
|
||||
// GetToolsConfig returns a mock tools config file
|
||||
func GetToolsConfig(sourceConfig map[string]any, toolKind, paramToolStatement, paramToolStatement2, arrayToolStatement, authToolStatement string) map[string]any {
|
||||
func GetToolsConfig(sourceConfig map[string]any, toolKind, paramToolStatement, idParamToolStmt, nameParamToolStmt, arrayToolStatement, authToolStatement string) map[string]any {
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := map[string]any{
|
||||
"sources": map[string]any{
|
||||
@@ -47,7 +47,7 @@ func GetToolsConfig(sourceConfig map[string]any, toolKind, paramToolStatement, p
|
||||
"description": "Simple tool to test end to end functionality.",
|
||||
"statement": "SELECT 1;",
|
||||
},
|
||||
"my-param-tool": map[string]any{
|
||||
"my-tool": map[string]any{
|
||||
"kind": toolKind,
|
||||
"source": "my-instance",
|
||||
"description": "Tool to test invocation with params.",
|
||||
@@ -65,11 +65,11 @@ func GetToolsConfig(sourceConfig map[string]any, toolKind, paramToolStatement, p
|
||||
},
|
||||
},
|
||||
},
|
||||
"my-param-tool2": map[string]any{
|
||||
"my-tool-by-id": map[string]any{
|
||||
"kind": toolKind,
|
||||
"source": "my-instance",
|
||||
"description": "Tool to test invocation with params.",
|
||||
"statement": paramToolStatement2,
|
||||
"statement": idParamToolStmt,
|
||||
"parameters": []any{
|
||||
map[string]any{
|
||||
"name": "id",
|
||||
@@ -78,6 +78,20 @@ func GetToolsConfig(sourceConfig map[string]any, toolKind, paramToolStatement, p
|
||||
},
|
||||
},
|
||||
},
|
||||
"my-tool-by-name": map[string]any{
|
||||
"kind": toolKind,
|
||||
"source": "my-instance",
|
||||
"description": "Tool to test invocation with params.",
|
||||
"statement": nameParamToolStmt,
|
||||
"parameters": []any{
|
||||
map[string]any{
|
||||
"name": "name",
|
||||
"type": "string",
|
||||
"description": "user name",
|
||||
"required": false,
|
||||
},
|
||||
},
|
||||
},
|
||||
"my-array-tool": map[string]any{
|
||||
"kind": toolKind,
|
||||
"source": "my-instance",
|
||||
@@ -301,15 +315,16 @@ func AddMSSQLExecuteSqlConfig(t *testing.T, config map[string]any) map[string]an
|
||||
return config
|
||||
}
|
||||
|
||||
// GetPostgresSQLParamToolInfo returns statements and param for my-param-tool postgres-sql kind
|
||||
func GetPostgresSQLParamToolInfo(tableName string) (string, string, string, string, string, []any) {
|
||||
// GetPostgresSQLParamToolInfo returns statements and param for my-tool postgres-sql kind
|
||||
func GetPostgresSQLParamToolInfo(tableName string) (string, string, string, string, string, string, []any) {
|
||||
createStatement := fmt.Sprintf("CREATE TABLE %s (id SERIAL PRIMARY KEY, name TEXT);", tableName)
|
||||
insertStatement := fmt.Sprintf("INSERT INTO %s (name) VALUES ($1), ($2), ($3), ($4);", tableName)
|
||||
toolStatement := fmt.Sprintf("SELECT * FROM %s WHERE id = $1 OR name = $2;", tableName)
|
||||
toolStatement2 := fmt.Sprintf("SELECT * FROM %s WHERE id = $1;", tableName)
|
||||
idParamStatement := fmt.Sprintf("SELECT * FROM %s WHERE id = $1;", tableName)
|
||||
nameParamStatement := fmt.Sprintf("SELECT * FROM %s WHERE name = $1;", tableName)
|
||||
arrayToolStatement := fmt.Sprintf("SELECT * FROM %s WHERE id = ANY($1) AND name = ANY($2);", tableName)
|
||||
params := []any{"Alice", "Jane", "Sid", nil}
|
||||
return createStatement, insertStatement, toolStatement, toolStatement2, arrayToolStatement, params
|
||||
return createStatement, insertStatement, toolStatement, idParamStatement, nameParamStatement, arrayToolStatement, params
|
||||
}
|
||||
|
||||
// GetPostgresSQLAuthToolInfo returns statements and param of my-auth-tool for postgres-sql kind
|
||||
@@ -328,15 +343,16 @@ func GetPostgresSQLTmplToolStatement() (string, string) {
|
||||
return tmplSelectCombined, tmplSelectFilterCombined
|
||||
}
|
||||
|
||||
// GetMSSQLParamToolInfo returns statements and param for my-param-tool mssql-sql kind
|
||||
func GetMSSQLParamToolInfo(tableName string) (string, string, string, string, string, []any) {
|
||||
// GetMSSQLParamToolInfo returns statements and param for my-tool mssql-sql kind
|
||||
func GetMSSQLParamToolInfo(tableName string) (string, string, string, string, string, string, []any) {
|
||||
createStatement := fmt.Sprintf("CREATE TABLE %s (id INT IDENTITY(1,1) PRIMARY KEY, name VARCHAR(255));", tableName)
|
||||
insertStatement := fmt.Sprintf("INSERT INTO %s (name) VALUES (@alice), (@jane), (@sid), (@nil);", tableName)
|
||||
toolStatement := fmt.Sprintf("SELECT * FROM %s WHERE id = @id OR name = @p2;", tableName)
|
||||
toolStatement2 := fmt.Sprintf("SELECT * FROM %s WHERE id = @id;", tableName)
|
||||
idParamStatement := fmt.Sprintf("SELECT * FROM %s WHERE id = @id;", tableName)
|
||||
nameParamStatement := fmt.Sprintf("SELECT * FROM %s WHERE name = @name;", tableName)
|
||||
arrayToolStatement := fmt.Sprintf("SELECT * FROM %s WHERE id = ANY(@idArray) OR name = ANY(@p2);", tableName)
|
||||
params := []any{sql.Named("alice", "Alice"), sql.Named("jane", "Jane"), sql.Named("sid", "Sid"), sql.Named("nil", nil)}
|
||||
return createStatement, insertStatement, toolStatement, toolStatement2, arrayToolStatement, params
|
||||
return createStatement, insertStatement, toolStatement, idParamStatement, nameParamStatement, arrayToolStatement, params
|
||||
}
|
||||
|
||||
// GetMSSQLAuthToolInfo returns statements and param of my-auth-tool for mssql-sql kind
|
||||
@@ -355,15 +371,16 @@ func GetMSSQLTmplToolStatement() (string, string) {
|
||||
return tmplSelectCombined, tmplSelectFilterCombined
|
||||
}
|
||||
|
||||
// GetMySQLParamToolInfo returns statements and param for my-param-tool mysql-sql kind
|
||||
func GetMySQLParamToolInfo(tableName string) (string, string, string, string, string, []any) {
|
||||
// GetMySQLParamToolInfo returns statements and param for my-tool mysql-sql kind
|
||||
func GetMySQLParamToolInfo(tableName string) (string, string, string, string, string, string, []any) {
|
||||
createStatement := fmt.Sprintf("CREATE TABLE %s (id INT NOT NULL AUTO_INCREMENT PRIMARY KEY, name VARCHAR(255));", tableName)
|
||||
insertStatement := fmt.Sprintf("INSERT INTO %s (name) VALUES (?), (?), (?), (?);", tableName)
|
||||
toolStatement := fmt.Sprintf("SELECT * FROM %s WHERE id = ? OR name = ?;", tableName)
|
||||
toolStatement2 := fmt.Sprintf("SELECT * FROM %s WHERE id = ?;", tableName)
|
||||
idParamStatement := fmt.Sprintf("SELECT * FROM %s WHERE id = ?;", tableName)
|
||||
nameParamStatement := fmt.Sprintf("SELECT * FROM %s WHERE name = ?;", tableName)
|
||||
arrayToolStatement := fmt.Sprintf("SELECT * FROM %s WHERE id = ANY(?) AND name = ANY(?);", tableName)
|
||||
params := []any{"Alice", "Jane", "Sid", nil}
|
||||
return createStatement, insertStatement, toolStatement, toolStatement2, arrayToolStatement, params
|
||||
return createStatement, insertStatement, toolStatement, idParamStatement, nameParamStatement, arrayToolStatement, params
|
||||
}
|
||||
|
||||
// GetMySQLAuthToolInfo returns statements and param of my-auth-tool for mysql-sql kind
|
||||
@@ -382,11 +399,12 @@ func GetMySQLTmplToolStatement() (string, string) {
|
||||
return tmplSelectCombined, tmplSelectFilterCombined
|
||||
}
|
||||
|
||||
func GetNonSpannerInvokeParamWant() (string, string, string) {
|
||||
func GetNonSpannerInvokeParamWant() (string, string, string, string) {
|
||||
invokeParamWant := "[{\"id\":1,\"name\":\"Alice\"},{\"id\":3,\"name\":\"Sid\"}]"
|
||||
invokeParamWantNull := "[{\"id\":4,\"name\":null}]"
|
||||
mcpInvokeParamWant := `{"jsonrpc":"2.0","id":"my-param-tool","result":{"content":[{"type":"text","text":"{\"id\":1,\"name\":\"Alice\"}"},{"type":"text","text":"{\"id\":3,\"name\":\"Sid\"}"}]}}`
|
||||
return invokeParamWant, invokeParamWantNull, mcpInvokeParamWant
|
||||
invokeIdNullWant := "[{\"id\":4,\"name\":null}]"
|
||||
nullWant := "null"
|
||||
mcpInvokeParamWant := `{"jsonrpc":"2.0","id":"my-tool","result":{"content":[{"type":"text","text":"{\"id\":1,\"name\":\"Alice\"}"},{"type":"text","text":"{\"id\":3,\"name\":\"Sid\"}"}]}}`
|
||||
return invokeParamWant, invokeIdNullWant, nullWant, mcpInvokeParamWant
|
||||
}
|
||||
|
||||
// GetPostgresWants return the expected wants for postgres
|
||||
@@ -501,13 +519,14 @@ func SetupMySQLTable(t *testing.T, ctx context.Context, pool *sql.DB, createStat
|
||||
}
|
||||
|
||||
// GetRedisWants return the expected wants for redis
|
||||
func GetRedisValkeyWants() (string, string, string, string, string) {
|
||||
func GetRedisValkeyWants() (string, string, string, string, string, string) {
|
||||
select1Want := "[\"PONG\"]"
|
||||
failInvocationWant := `unknown command 'SELEC 1;', with args beginning with: \""}]}}`
|
||||
invokeParamWant := "[{\"id\":\"1\",\"name\":\"Alice\"},{\"id\":\"3\",\"name\":\"Sid\"}]"
|
||||
invokeParamWantNull := `[{"id":"4","name":""}]`
|
||||
mcpInvokeParamWant := `{"jsonrpc":"2.0","id":"my-param-tool","result":{"content":[{"type":"text","text":"{\"id\":\"1\",\"name\":\"Alice\"}"},{"type":"text","text":"{\"id\":\"3\",\"name\":\"Sid\"}"}]}}`
|
||||
return select1Want, failInvocationWant, invokeParamWant, invokeParamWantNull, mcpInvokeParamWant
|
||||
invokeIdNullWant := `[{"id":"4","name":""}]`
|
||||
nullWant := `["null"]`
|
||||
mcpInvokeParamWant := `{"jsonrpc":"2.0","id":"my-tool","result":{"content":[{"type":"text","text":"{\"id\":\"1\",\"name\":\"Alice\"}"},{"type":"text","text":"{\"id\":\"3\",\"name\":\"Sid\"}"}]}}`
|
||||
return select1Want, failInvocationWant, invokeParamWant, invokeIdNullWant, nullWant, mcpInvokeParamWant
|
||||
}
|
||||
|
||||
func GetRedisValkeyToolsConfig(sourceConfig map[string]any, toolKind string) map[string]any {
|
||||
@@ -528,7 +547,7 @@ func GetRedisValkeyToolsConfig(sourceConfig map[string]any, toolKind string) map
|
||||
"description": "Simple tool to test end to end functionality.",
|
||||
"commands": [][]string{{"PING"}},
|
||||
},
|
||||
"my-param-tool": map[string]any{
|
||||
"my-tool": map[string]any{
|
||||
"kind": toolKind,
|
||||
"source": "my-instance",
|
||||
"description": "Tool to test invocation with params.",
|
||||
@@ -546,7 +565,7 @@ func GetRedisValkeyToolsConfig(sourceConfig map[string]any, toolKind string) map
|
||||
},
|
||||
},
|
||||
},
|
||||
"my-param-tool2": map[string]any{
|
||||
"my-tool-by-id": map[string]any{
|
||||
"kind": toolKind,
|
||||
"source": "my-instance",
|
||||
"description": "Tool to test invocation with params.",
|
||||
@@ -559,6 +578,20 @@ func GetRedisValkeyToolsConfig(sourceConfig map[string]any, toolKind string) map
|
||||
},
|
||||
},
|
||||
},
|
||||
"my-tool-by-name": map[string]any{
|
||||
"kind": toolKind,
|
||||
"source": "my-instance",
|
||||
"description": "Tool to test invocation with params.",
|
||||
"commands": [][]string{{"GET", "null"}},
|
||||
"parameters": []any{
|
||||
map[string]any{
|
||||
"name": "name",
|
||||
"type": "string",
|
||||
"description": "user name",
|
||||
"required": false,
|
||||
},
|
||||
},
|
||||
},
|
||||
"my-array-tool": map[string]any{
|
||||
"kind": toolKind,
|
||||
"source": "my-instance",
|
||||
|
||||
@@ -103,7 +103,7 @@ func TestCouchbaseToolEndpoints(t *testing.T) {
|
||||
collectionNameTemplateParam := "template_param_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
|
||||
// Set up data for param tool
|
||||
paramToolStatement, paramToolStmt2, arrayToolStatement, paramTestParams := getCouchbaseParamToolInfo(collectionNameParam)
|
||||
paramToolStatement, idParamToolStmt, nameParamToolStmt, arrayToolStatement, paramTestParams := getCouchbaseParamToolInfo(collectionNameParam)
|
||||
teardownCollection1 := setupCouchbaseCollection(t, ctx, cluster, couchbaseBucket, couchbaseScope, collectionNameParam, paramTestParams)
|
||||
defer teardownCollection1(t)
|
||||
|
||||
@@ -118,7 +118,7 @@ func TestCouchbaseToolEndpoints(t *testing.T) {
|
||||
defer teardownCollection3(t)
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, couchbaseToolKind, paramToolStatement, paramToolStmt2, arrayToolStatement, authToolStatement)
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, couchbaseToolKind, paramToolStatement, idParamToolStmt, nameParamToolStmt, arrayToolStatement, authToolStatement)
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, couchbaseToolKind, tmplSelectCombined, tmplSelectFilterCombined, tmplSelectAll)
|
||||
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
@@ -140,8 +140,8 @@ func TestCouchbaseToolEndpoints(t *testing.T) {
|
||||
select1Want := "[{\"$1\":1}]"
|
||||
failMcpInvocationWant := "{\"jsonrpc\":\"2.0\",\"id\":\"invoke-fail-tool\",\"result\":{\"content\":[{\"type\":\"text\",\"text\":\"unable to execute query: parsing failure | {\\\"statement\\\":\\\"SELEC 1;\\\""
|
||||
|
||||
invokeParamWant, invokeParamWantNull, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeParamWantNull, true)
|
||||
invokeParamWant, invokeIdNullWant, nullWant, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeIdNullWant, nullWant, true, true)
|
||||
tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failMcpInvocationWant)
|
||||
|
||||
templateParamTestConfig := tests.NewTemplateParameterTestConfig(
|
||||
@@ -230,27 +230,27 @@ func setupCouchbaseCollection(t *testing.T, ctx context.Context, cluster *gocb.C
|
||||
}
|
||||
}
|
||||
|
||||
// getCouchbaseParamToolInfo returns statements and params for my-param-tool couchbase-sql kind
|
||||
func getCouchbaseParamToolInfo(collectionName string) (string, string, string, []map[string]any) {
|
||||
// getCouchbaseParamToolInfo returns statements and params for my-tool couchbase-sql kind
|
||||
func getCouchbaseParamToolInfo(collectionName string) (string, string, string, string, []map[string]any) {
|
||||
// N1QL uses positional or named parameters with $ prefix
|
||||
toolStatement := fmt.Sprintf("SELECT TONUMBER(meta().id) as id, "+
|
||||
"%s.* FROM %s WHERE meta().id = TOSTRING($id) OR name = $name order by meta().id",
|
||||
collectionName, collectionName)
|
||||
|
||||
toolStatement2 := fmt.Sprintf("SELECT TONUMBER(meta().id) as id, "+
|
||||
idToolStatement := fmt.Sprintf("SELECT TONUMBER(meta().id) as id, "+
|
||||
"%s.* FROM %s WHERE meta().id = TOSTRING($id) order by meta().id",
|
||||
collectionName, collectionName)
|
||||
|
||||
nameToolStatement := fmt.Sprintf("SELECT TONUMBER(meta().id) as id, "+
|
||||
"%s.* FROM %s WHERE name = $name order by meta().id",
|
||||
collectionName, collectionName)
|
||||
arrayToolStatemnt := fmt.Sprintf("SELECT TONUMBER(meta().id) as id, "+
|
||||
"%s.* FROM %s WHERE TONUMBER(meta().id) IN $idArray AND name IN $nameArray order by meta().id", collectionName, collectionName)
|
||||
|
||||
params := []map[string]any{
|
||||
{"name": "Alice"},
|
||||
{"name": "Jane"},
|
||||
{"name": "Sid"},
|
||||
{"name": nil},
|
||||
}
|
||||
return toolStatement, toolStatement2, arrayToolStatemnt, params
|
||||
return toolStatement, idToolStatement, nameToolStatement, arrayToolStatemnt, params
|
||||
}
|
||||
|
||||
// getCouchbaseAuthToolInfo returns statements and param of my-auth-tool for couchbase-sql kind
|
||||
|
||||
@@ -141,7 +141,7 @@ func TestDgraphToolEndpoints(t *testing.T) {
|
||||
name: "invoke my-simple-dql-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-simple-dql-tool/invoke",
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
want: "[{\"result\":[{\"constant\":1}]}]",
|
||||
want: "{\"result\":[{\"constant\":1}]}",
|
||||
},
|
||||
}
|
||||
for _, tc := range invokeTcs {
|
||||
|
||||
1031
tests/firestore/firestore_integration_test.go
Normal file
1031
tests/firestore/firestore_integration_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -60,8 +60,10 @@ func multiTool(w http.ResponseWriter, r *http.Request) {
|
||||
handleTool0(w, r)
|
||||
case "tool1":
|
||||
handleTool1(w, r)
|
||||
case "tool1a":
|
||||
handleTool1a(w, r)
|
||||
case "tool1id":
|
||||
handleTool1Id(w, r)
|
||||
case "tool1name":
|
||||
handleTool1Name(w, r)
|
||||
case "tool2":
|
||||
handleTool2(w, r)
|
||||
case "tool3":
|
||||
@@ -80,10 +82,7 @@ func handleTool0(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
response := []string{
|
||||
"Hello",
|
||||
"World",
|
||||
}
|
||||
response := "hello world"
|
||||
err := json.NewEncoder(w).Encode(response)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to encode JSON", http.StatusInternalServerError)
|
||||
@@ -134,7 +133,7 @@ func handleTool1(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// handler function for the test server
|
||||
func handleTool1a(w http.ResponseWriter, r *http.Request) {
|
||||
func handleTool1Id(w http.ResponseWriter, r *http.Request) {
|
||||
// expect GET method
|
||||
if r.Method != http.MethodGet {
|
||||
errorMessage := fmt.Sprintf("expected GET method but got: %s", string(r.Method))
|
||||
@@ -154,6 +153,27 @@ func handleTool1a(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
|
||||
// handler function for the test server
|
||||
func handleTool1Name(w http.ResponseWriter, r *http.Request) {
|
||||
// expect GET method
|
||||
if r.Method != http.MethodGet {
|
||||
errorMessage := fmt.Sprintf("expected GET method but got: %s", string(r.Method))
|
||||
http.Error(w, errorMessage, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
name := r.URL.Query().Get("name")
|
||||
if name == "" {
|
||||
response := "null"
|
||||
_, err := w.Write([]byte(response))
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to write response", http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
|
||||
// handler function for the test server
|
||||
func handleTool2(w http.ResponseWriter, r *http.Request) {
|
||||
// expect GET method
|
||||
@@ -164,7 +184,7 @@ func handleTool2(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
email := r.URL.Query().Get("email")
|
||||
if email != "" {
|
||||
response := `{"name":"Alice"}`
|
||||
response := `[{"name":"Alice"}]`
|
||||
_, err := w.Write([]byte(response))
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to write response", http.StatusInternalServerError)
|
||||
@@ -246,10 +266,7 @@ func handleTool3(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Return a JSON array as the response
|
||||
response := []any{
|
||||
"Hello", "World",
|
||||
}
|
||||
response := "hello world"
|
||||
err = json.NewEncoder(w).Encode(response)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to encode JSON", http.StatusInternalServerError)
|
||||
@@ -284,10 +301,11 @@ func TestHttpToolEndpoints(t *testing.T) {
|
||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||
}
|
||||
|
||||
select1Want := `["Hello","World"]`
|
||||
invokeParamWant, invokeParamWantNull, _ := tests.GetNonSpannerInvokeParamWant()
|
||||
select1Want := `"hello world"`
|
||||
invokeParamWant, invokeIdNullWant, _, _ := tests.GetNonSpannerInvokeParamWant()
|
||||
nullWant := "null"
|
||||
tests.RunToolGetTest(t)
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeParamWantNull, false)
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeIdNullWant, nullWant, true, false)
|
||||
runAdvancedHTTPInvokeTest(t)
|
||||
}
|
||||
|
||||
@@ -307,7 +325,7 @@ func runAdvancedHTTPInvokeTest(t *testing.T) {
|
||||
api: "http://127.0.0.1:5000/api/tool/my-advanced-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"animalArray": ["rabbit", "ostrich", "whale"], "id": 3, "path": "tool3", "country": "US", "X-Other-Header": "test"}`)),
|
||||
want: `["Hello","World"]`,
|
||||
want: `"hello world"`,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
@@ -391,7 +409,7 @@ func getHTTPToolsConfig(sourceConfig map[string]any, toolKind string) map[string
|
||||
"requestBody": "{}",
|
||||
"description": "Simple tool to test end to end functionality.",
|
||||
},
|
||||
"my-param-tool": map[string]any{
|
||||
"my-tool": map[string]any{
|
||||
"kind": toolKind,
|
||||
"source": "my-instance",
|
||||
"method": "GET",
|
||||
@@ -407,16 +425,26 @@ func getHTTPToolsConfig(sourceConfig map[string]any, toolKind string) map[string
|
||||
"bodyParams": []tools.Parameter{tools.NewStringParameter("name", "user name")},
|
||||
"headers": map[string]string{"Content-Type": "application/json"},
|
||||
},
|
||||
"my-param-tool2": map[string]any{
|
||||
"my-tool-by-id": map[string]any{
|
||||
"kind": toolKind,
|
||||
"source": "my-instance",
|
||||
"method": "GET",
|
||||
"path": "/tool1a",
|
||||
"path": "/tool1id",
|
||||
"description": "some description",
|
||||
"queryParams": []tools.Parameter{
|
||||
tools.NewIntParameter("id", "user ID")},
|
||||
"headers": map[string]string{"Content-Type": "application/json"},
|
||||
},
|
||||
"my-tool-by-name": map[string]any{
|
||||
"kind": toolKind,
|
||||
"source": "my-instance",
|
||||
"method": "GET",
|
||||
"path": "/tool1name",
|
||||
"description": "some description",
|
||||
"queryParams": []tools.Parameter{
|
||||
tools.NewStringParameterWithRequired("name", "user name", false)},
|
||||
"headers": map[string]string{"Content-Type": "application/json"},
|
||||
},
|
||||
"my-auth-tool": map[string]any{
|
||||
"kind": toolKind,
|
||||
"source": "my-instance",
|
||||
|
||||
@@ -102,7 +102,7 @@ func TestMSSQLToolEndpoints(t *testing.T) {
|
||||
tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
|
||||
// set up data for param tool
|
||||
createParamTableStmt, insertParamTableStmt, paramToolStmt, paramToolStmt2, arrayToolStmt, paramTestParams := tests.GetMSSQLParamToolInfo(tableNameParam)
|
||||
createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := tests.GetMSSQLParamToolInfo(tableNameParam)
|
||||
teardownTable1 := tests.SetupMsSQLTable(t, ctx, pool, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams)
|
||||
defer teardownTable1(t)
|
||||
|
||||
@@ -112,7 +112,7 @@ func TestMSSQLToolEndpoints(t *testing.T) {
|
||||
defer teardownTable2(t)
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, MSSQLToolKind, paramToolStmt, paramToolStmt2, arrayToolStmt, authToolStmt)
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, MSSQLToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
|
||||
toolsFile = tests.AddMSSQLExecuteSqlConfig(t, toolsFile)
|
||||
tmplSelectCombined, tmplSelectFilterCombined := tests.GetMSSQLTmplToolStatement()
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, MSSQLToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||
@@ -134,8 +134,8 @@ func TestMSSQLToolEndpoints(t *testing.T) {
|
||||
tests.RunToolGetTest(t)
|
||||
|
||||
select1Want, failInvocationWant, createTableStatement := tests.GetMSSQLWants()
|
||||
invokeParamWant, invokeParamWantNull, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeParamWantNull, false)
|
||||
invokeParamWant, invokeIdNullWant, nullWant, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeIdNullWant, nullWant, true, false)
|
||||
tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want)
|
||||
tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failInvocationWant)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tests.NewTemplateParameterTestConfig())
|
||||
|
||||
@@ -93,7 +93,7 @@ func TestMySQLToolEndpoints(t *testing.T) {
|
||||
tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
|
||||
// set up data for param tool
|
||||
createParamTableStmt, insertParamTableStmt, paramToolStmt, paramToolStmt2, arrayToolStmt, paramTestParams := tests.GetMySQLParamToolInfo(tableNameParam)
|
||||
createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := tests.GetMySQLParamToolInfo(tableNameParam)
|
||||
teardownTable1 := tests.SetupMySQLTable(t, ctx, pool, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams)
|
||||
defer teardownTable1(t)
|
||||
|
||||
@@ -103,7 +103,7 @@ func TestMySQLToolEndpoints(t *testing.T) {
|
||||
defer teardownTable2(t)
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, MySQLToolKind, paramToolStmt, paramToolStmt2, arrayToolStmt, authToolStmt)
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, MySQLToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
|
||||
toolsFile = tests.AddMySqlExecuteSqlConfig(t, toolsFile)
|
||||
tmplSelectCombined, tmplSelectFilterCombined := tests.GetMySQLTmplToolStatement()
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, MySQLToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||
@@ -125,8 +125,8 @@ func TestMySQLToolEndpoints(t *testing.T) {
|
||||
tests.RunToolGetTest(t)
|
||||
|
||||
select1Want, failInvocationWant, createTableStatement := tests.GetMySQLWants()
|
||||
invokeParamWant, invokeParamWantNull, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeParamWantNull, false)
|
||||
invokeParamWant, invokeIdNullWant, nullWant, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeIdNullWant, nullWant, true, false)
|
||||
tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want)
|
||||
tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failInvocationWant)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tests.NewTemplateParameterTestConfig())
|
||||
|
||||
@@ -101,7 +101,7 @@ func TestNeo4jToolEndpoints(t *testing.T) {
|
||||
want map[string]any
|
||||
}{
|
||||
{
|
||||
name: "get my-simple-tool",
|
||||
name: "get my-simple-cypher-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-simple-cypher-tool/",
|
||||
want: map[string]any{
|
||||
"my-simple-cypher-tool": map[string]any{
|
||||
|
||||
@@ -99,7 +99,7 @@ func TestPostgres(t *testing.T) {
|
||||
tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
|
||||
// set up data for param tool
|
||||
createParamTableStmt, insertParamTableStmt, paramToolStmt, paramToolStmt2, arrayToolStmt, paramTestParams := tests.GetPostgresSQLParamToolInfo(tableNameParam)
|
||||
createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := tests.GetPostgresSQLParamToolInfo(tableNameParam)
|
||||
teardownTable1 := tests.SetupPostgresSQLTable(t, ctx, pool, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams)
|
||||
defer teardownTable1(t)
|
||||
|
||||
@@ -109,7 +109,7 @@ func TestPostgres(t *testing.T) {
|
||||
defer teardownTable2(t)
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, PostgresToolKind, paramToolStmt, paramToolStmt2, arrayToolStmt, authToolStmt)
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, PostgresToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
|
||||
toolsFile = tests.AddPgExecuteSqlConfig(t, toolsFile)
|
||||
tmplSelectCombined, tmplSelectFilterCombined := tests.GetPostgresSQLTmplToolStatement()
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, PostgresToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||
@@ -131,8 +131,8 @@ func TestPostgres(t *testing.T) {
|
||||
tests.RunToolGetTest(t)
|
||||
|
||||
select1Want, failInvocationWant, createTableStatement := tests.GetPostgresWants()
|
||||
invokeParamWant, invokeParamWantNull, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeParamWantNull, true)
|
||||
invokeParamWant, invokeIdNullWant, nullWant, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeIdNullWant, nullWant, true, true)
|
||||
tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want)
|
||||
tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failInvocationWant)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tests.NewTemplateParameterTestConfig())
|
||||
|
||||
@@ -99,18 +99,19 @@ func TestRedisToolEndpoints(t *testing.T) {
|
||||
|
||||
tests.RunToolGetTest(t)
|
||||
|
||||
select1Want, failInvocationWant, invokeParamWant, invokeParamWantNull, mcpInvokeParamWant := tests.GetRedisValkeyWants()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeParamWantNull, true)
|
||||
select1Want, failInvocationWant, invokeParamWant, invokeIdNullWant, nullWant, mcpInvokeParamWant := tests.GetRedisValkeyWants()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeIdNullWant, nullWant, true, true)
|
||||
tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failInvocationWant)
|
||||
}
|
||||
|
||||
func setupRedisDB(t *testing.T, ctx context.Context, client *redis.Client) func(*testing.T) {
|
||||
keys := []string{"row1", "row2", "row3", "row4"}
|
||||
keys := []string{"row1", "row2", "row3", "row4", "null"}
|
||||
commands := [][]any{
|
||||
{"HSET", keys[0], "id", 1, "name", "Alice"},
|
||||
{"HSET", keys[1], "id", 2, "name", "Jane"},
|
||||
{"HSET", keys[2], "id", 3, "name", "Sid"},
|
||||
{"HSET", keys[3], "id", 4, "name", nil},
|
||||
{"SET", keys[4], "null"},
|
||||
{"HSET", tests.ServiceAccountEmail, "name", "Alice"},
|
||||
}
|
||||
for _, c := range commands {
|
||||
|
||||
@@ -108,7 +108,7 @@ func TestSpannerToolEndpoints(t *testing.T) {
|
||||
tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
|
||||
// set up data for param tool
|
||||
createParamTableStmt, insertParamTableStmt, paramToolStmt, paramToolStmt2, arrayToolStmt, paramTestParams := getSpannerParamToolInfo(tableNameParam)
|
||||
createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := getSpannerParamToolInfo(tableNameParam)
|
||||
dbString := fmt.Sprintf(
|
||||
"projects/%s/instances/%s/databases/%s",
|
||||
SpannerProject,
|
||||
@@ -129,7 +129,7 @@ func TestSpannerToolEndpoints(t *testing.T) {
|
||||
defer teardownTableTmpl(t)
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, SpannerToolKind, paramToolStmt, paramToolStmt2, arrayToolStmt, authToolStmt)
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, SpannerToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
|
||||
toolsFile = addSpannerExecuteSqlConfig(t, toolsFile)
|
||||
toolsFile = addSpannerReadOnlyConfig(t, toolsFile)
|
||||
toolsFile = addTemplateParamConfig(t, toolsFile)
|
||||
@@ -153,11 +153,12 @@ func TestSpannerToolEndpoints(t *testing.T) {
|
||||
select1Want := "[{\"\":\"1\"}]"
|
||||
accessSchemaWant := "[{\"schema_name\":\"INFORMATION_SCHEMA\"}]"
|
||||
invokeParamWant := "[{\"id\":\"1\",\"name\":\"Alice\"},{\"id\":\"3\",\"name\":\"Sid\"}]"
|
||||
invokeParamWantNull := `[{"id":"4","name":null}]`
|
||||
mcpInvokeParamWant := `{"jsonrpc":"2.0","id":"my-param-tool","result":{"content":[{"type":"text","text":"{\"id\":\"1\",\"name\":\"Alice\"}"},{"type":"text","text":"{\"id\":\"3\",\"name\":\"Sid\"}"}]}}`
|
||||
invokeIdNullWant := `[{"id":"4","name":null}]`
|
||||
mcpInvokeParamWant := `{"jsonrpc":"2.0","id":"my-tool","result":{"content":[{"type":"text","text":"{\"id\":\"1\",\"name\":\"Alice\"}"},{"type":"text","text":"{\"id\":\"3\",\"name\":\"Sid\"}"}]}}`
|
||||
nullWant := "null"
|
||||
failInvocationWant := `"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute client: unable to parse row: spanner: code = \"InvalidArgument\", desc = \"Syntax error: Unexpected identifier \\\\\\\"SELEC\\\\\\\" [at 1:1]\\\\nSELEC 1;\\\\n^\"`
|
||||
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeParamWantNull, true)
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeIdNullWant, nullWant, true, true)
|
||||
tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failInvocationWant)
|
||||
runSpannerSchemaToolInvokeTest(t, accessSchemaWant)
|
||||
runSpannerExecuteSqlToolInvokeTest(t, select1Want, invokeParamWant, tableNameParam, tableNameAuth)
|
||||
@@ -170,15 +171,16 @@ func TestSpannerToolEndpoints(t *testing.T) {
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, templateParamTestConfig)
|
||||
}
|
||||
|
||||
// getSpannerToolInfo returns statements and param for my-param-tool for spanner-sql kind
|
||||
func getSpannerParamToolInfo(tableName string) (string, string, string, string, string, map[string]any) {
|
||||
// getSpannerToolInfo returns statements and param for my-tool for spanner-sql kind
|
||||
func getSpannerParamToolInfo(tableName string) (string, string, string, string, string, string, map[string]any) {
|
||||
createStatement := fmt.Sprintf("CREATE TABLE %s (id INT64, name STRING(MAX)) PRIMARY KEY (id)", tableName)
|
||||
insertStatement := fmt.Sprintf("INSERT INTO %s (id, name) VALUES (1, @name1), (2, @name2), (3, @name3), (4, @name4)", tableName)
|
||||
toolStatement := fmt.Sprintf("SELECT * FROM %s WHERE id = @id OR name = @name", tableName)
|
||||
toolStatement2 := fmt.Sprintf("SELECT * FROM %s WHERE id = @id", tableName)
|
||||
idToolStatement := fmt.Sprintf("SELECT * FROM %s WHERE id = @id", tableName)
|
||||
nameToolStatement := fmt.Sprintf("SELECT * FROM %s WHERE name = @name", tableName)
|
||||
arrayToolStatement := fmt.Sprintf("SELECT * FROM %s WHERE id IN UNNEST(@idArray) AND name IN UNNEST(@nameArray)", tableName)
|
||||
params := map[string]any{"name1": "Alice", "name2": "Jane", "name3": "Sid", "name4": nil}
|
||||
return createStatement, insertStatement, toolStatement, toolStatement2, arrayToolStatement, params
|
||||
return createStatement, insertStatement, toolStatement, idToolStatement, nameToolStatement, arrayToolStatement, params
|
||||
}
|
||||
|
||||
// getSpannerAuthToolInfo returns statements and param of my-auth-tool for spanner-sql kind
|
||||
|
||||
@@ -81,14 +81,15 @@ func setupSQLiteTestDB(t *testing.T, ctx context.Context, db *sql.DB, createStat
|
||||
}
|
||||
}
|
||||
|
||||
func getSQLiteParamToolInfo(tableName string) (string, string, string, string, string, []any) {
|
||||
func getSQLiteParamToolInfo(tableName string) (string, string, string, string, string, string, []any) {
|
||||
createStatement := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (id INTEGER PRIMARY KEY, name TEXT);", tableName)
|
||||
insertStatement := fmt.Sprintf("INSERT INTO %s (name) VALUES (?), (?), (?), (?);", tableName)
|
||||
toolStatement := fmt.Sprintf("SELECT * FROM %s WHERE id = ? OR name = ?;", tableName)
|
||||
toolStatement2 := fmt.Sprintf("SELECT * FROM %s WHERE id = ?;", tableName)
|
||||
idToolStatement := fmt.Sprintf("SELECT * FROM %s WHERE id = ?;", tableName)
|
||||
nameToolStatement := fmt.Sprintf("SELECT * FROM %s WHERE name = ?;", tableName)
|
||||
arrayToolStatement := fmt.Sprintf("SELECT * FROM %s WHERE id = ANY({{.idArray}}) AND name = ANY({{.nameArray}});", tableName)
|
||||
params := []any{"Alice", "Jane", "Sid", nil}
|
||||
return createStatement, insertStatement, toolStatement, toolStatement2, arrayToolStatement, params
|
||||
return createStatement, insertStatement, toolStatement, idToolStatement, nameToolStatement, arrayToolStatement, params
|
||||
}
|
||||
|
||||
func getSQLiteAuthToolInfo(tableName string) (string, string, string, []any) {
|
||||
@@ -126,7 +127,7 @@ func TestSQLiteToolEndpoint(t *testing.T) {
|
||||
tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
|
||||
// set up data for param tool
|
||||
createParamTableStmt, insertParamTableStmt, paramToolStmt, paramToolStmt2, arrayToolStmt, paramTestParams := getSQLiteParamToolInfo(tableNameParam)
|
||||
createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := getSQLiteParamToolInfo(tableNameParam)
|
||||
setupSQLiteTestDB(t, ctx, db, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams)
|
||||
|
||||
// set up data for auth tool
|
||||
@@ -134,7 +135,7 @@ func TestSQLiteToolEndpoint(t *testing.T) {
|
||||
setupSQLiteTestDB(t, ctx, db, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, SQLiteToolKind, paramToolStmt, paramToolStmt2, arrayToolStmt, authToolStmt)
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, SQLiteToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
|
||||
tmplSelectCombined, tmplSelectFilterCombined := getSQLiteTmplToolStatement()
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, SQLiteToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||
|
||||
@@ -156,8 +157,8 @@ func TestSQLiteToolEndpoint(t *testing.T) {
|
||||
|
||||
select1Want := "[{\"1\":1}]"
|
||||
failInvocationWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: SQL logic error: near \"SELEC\": syntax error (1)"}],"isError":true}}`
|
||||
invokeParamWant, invokeParamWantNull, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeParamWantNull, false)
|
||||
invokeParamWant, invokeIdNullWant, nullWant, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeIdNullWant, nullWant, true, false)
|
||||
tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failInvocationWant)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tests.NewTemplateParameterTestConfig())
|
||||
}
|
||||
|
||||
@@ -76,7 +76,7 @@ func RunToolGetTest(t *testing.T) {
|
||||
}
|
||||
|
||||
// RunToolInvoke runs the tool invoke endpoint
|
||||
func RunToolInvokeTest(t *testing.T, select1Want, invokeParamWant, invokeParamWantNull string, supportsArray bool) {
|
||||
func RunToolInvokeTest(t *testing.T, select1Want, invokeParamWant, invokeIdNullWant, nullString string, supportNullParam, supportsArray bool) {
|
||||
// Get ID token
|
||||
idToken, err := GetGoogleIdToken(ClientId)
|
||||
if err != nil {
|
||||
@@ -101,31 +101,39 @@ func RunToolInvokeTest(t *testing.T, select1Want, invokeParamWant, invokeParamWa
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "invoke my-param-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-param-tool/invoke",
|
||||
name: "invoke my-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"id": 3, "name": "Alice"}`)),
|
||||
want: invokeParamWant,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "invoke my-param-tool2 with nil response",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-param-tool2/invoke",
|
||||
name: "invoke my-tool-by-id with nil response",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-tool-by-id/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"id": 4}`)),
|
||||
want: invokeParamWantNull,
|
||||
want: invokeIdNullWant,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invoke my-param-tool without parameters",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-param-tool/invoke",
|
||||
name: "invoke my-tool-by-name with nil response",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-tool-by-name/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
want: nullString,
|
||||
isErr: !supportNullParam,
|
||||
},
|
||||
{
|
||||
name: "Invoke my-tool without parameters",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
isErr: true,
|
||||
},
|
||||
{
|
||||
name: "Invoke my-param-tool with insufficient parameters",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-param-tool/invoke",
|
||||
name: "Invoke my-tool with insufficient parameters",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"id": 1}`)),
|
||||
isErr: true,
|
||||
@@ -629,17 +637,17 @@ func RunMCPToolCallMethod(t *testing.T, invokeParamWant, failInvocationWant stri
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "MCP Invoke my-param-tool",
|
||||
name: "MCP Invoke my-tool",
|
||||
api: "http://127.0.0.1:5000/mcp",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: jsonrpc.JSONRPCRequest{
|
||||
Jsonrpc: "2.0",
|
||||
Id: "my-param-tool",
|
||||
Id: "my-tool",
|
||||
Request: jsonrpc.Request{
|
||||
Method: "tools/call",
|
||||
},
|
||||
Params: map[string]any{
|
||||
"name": "my-param-tool",
|
||||
"name": "my-tool",
|
||||
"arguments": map[string]any{
|
||||
"id": int(3),
|
||||
"name": "Alice",
|
||||
@@ -666,7 +674,7 @@ func RunMCPToolCallMethod(t *testing.T, invokeParamWant, failInvocationWant stri
|
||||
want: `{"jsonrpc":"2.0","id":"invalid-tool","error":{"code":-32602,"message":"invalid tool name: tool with name \"foo\" does not exist"}}`,
|
||||
},
|
||||
{
|
||||
name: "MCP Invoke my-param-tool without parameters",
|
||||
name: "MCP Invoke my-tool without parameters",
|
||||
api: "http://127.0.0.1:5000/mcp",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: jsonrpc.JSONRPCRequest{
|
||||
@@ -676,14 +684,14 @@ func RunMCPToolCallMethod(t *testing.T, invokeParamWant, failInvocationWant stri
|
||||
Method: "tools/call",
|
||||
},
|
||||
Params: map[string]any{
|
||||
"name": "my-param-tool",
|
||||
"name": "my-tool",
|
||||
"arguments": map[string]any{},
|
||||
},
|
||||
},
|
||||
want: `{"jsonrpc":"2.0","id":"invoke-without-parameter","error":{"code":-32602,"message":"provided parameters were invalid: parameter \"id\" is required"}}`,
|
||||
},
|
||||
{
|
||||
name: "MCP Invoke my-param-tool with insufficient parameters",
|
||||
name: "MCP Invoke my-tool with insufficient parameters",
|
||||
api: "http://127.0.0.1:5000/mcp",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: jsonrpc.JSONRPCRequest{
|
||||
@@ -693,7 +701,7 @@ func RunMCPToolCallMethod(t *testing.T, invokeParamWant, failInvocationWant stri
|
||||
Method: "tools/call",
|
||||
},
|
||||
Params: map[string]any{
|
||||
"name": "my-param-tool",
|
||||
"name": "my-tool",
|
||||
"arguments": map[string]any{"id": 1},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -102,18 +102,19 @@ func TestValkeyToolEndpoints(t *testing.T) {
|
||||
|
||||
tests.RunToolGetTest(t)
|
||||
|
||||
select1Want, failInvocationWant, invokeParamWant, invokeParamWantNull, mcpInvokeParamWant := tests.GetRedisValkeyWants()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeParamWantNull, true)
|
||||
select1Want, failInvocationWant, invokeParamWant, invokeIdNullWant, nullWant, mcpInvokeParamWant := tests.GetRedisValkeyWants()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeIdNullWant, nullWant, true, true)
|
||||
tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failInvocationWant)
|
||||
}
|
||||
|
||||
func setupValkeyDB(t *testing.T, ctx context.Context, client valkey.Client) func(*testing.T) {
|
||||
keys := []string{"row1", "row2", "row3", "row4"}
|
||||
keys := []string{"row1", "row2", "row3", "row4", "null"}
|
||||
commands := [][]string{
|
||||
{"HSET", keys[0], "name", "Alice", "id", "1"},
|
||||
{"HSET", keys[1], "name", "Jane", "id", "2"},
|
||||
{"HSET", keys[2], "name", "Sid", "id", "3"},
|
||||
{"HSET", keys[3], "name", "", "id", "4"},
|
||||
{"SET", keys[4], "null"},
|
||||
{"HSET", tests.ServiceAccountEmail, "name", "Alice"},
|
||||
}
|
||||
builtCmds := make(valkey.Commands, len(commands))
|
||||
|
||||
Reference in New Issue
Block a user