feat(cassandra): add Cassandra Source and Tool (#1012)

[Cassandra](https://cassandra.apache.org/_/cassandra-basics.html) is a
NoSQL distributed database. By design, NoSQL databases are lightweight,
open-source, non-relational, and largely distributed. Counted among
their strengths are horizontal scalability, distributed architectures,
and a flexible approach to schema definition.

Cassandra go driver link -
https://pkg.go.dev/github.com/apache/cassandra-gocql-driver/v2

This PR 
- adds a new source for cassandra
- adds a new tool _cassandra-cql_ with support for executing predefined
parameterized CQL queries on cassandra
- adds unit and integration tests for the tool and the source
- adds documentation for the cassandra source and cassandra-cql tool

---------

Co-authored-by: Wenxin Du <117315983+duwenxin99@users.noreply.github.com>
Co-authored-by: duwenxin <duwenxin@google.com>
This commit is contained in:
Pranava B
2025-09-24 02:18:43 +05:30
committed by GitHub
parent 6387dd3efa
commit 6e420534ee
18 changed files with 1206 additions and 9 deletions

View File

@@ -662,6 +662,26 @@ steps:
- |
./yugabytedb.test -test.v
- id: "cassandra"
name: golang:1
waitFor: ["compile-test-binary"]
entrypoint: /bin/bash
env:
- "GOPATH=/gopath"
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
secretEnv: ["CLIENT_ID", "CASSANDRA_USER", "CASSANDRA_PASS", "CASSANDRA_HOST"]
volumes:
- name: "go"
path: "/gopath"
args:
- -c
- |
.ci/test_with_coverage.sh \
"Cassandra" \
cassandra \
cassandra
availableSecrets:
secretManager:
- versionName: projects/$PROJECT_ID/secrets/cloud_sql_pg_user/versions/latest
@@ -746,6 +766,12 @@ availableSecrets:
env: YUGABYTEDB_USER
- versionName: projects/$PROJECT_ID/secrets/yugabytedb_pass/versions/latest
env: YUGABYTEDB_PASS
- versionName: projects/$PROJECT_ID/secrets/cassandra_user/versions/latest
env: CASSANDRA_USER
- versionName: projects/$PROJECT_ID/secrets/cassandra_pass/versions/latest
env: CASSANDRA_PASS
- versionName: projects/$PROJECT_ID/secrets/cassandra_host/versions/latest
env: CASSANDRA_HOST
options:
logging: CLOUD_LOGGING_ONLY

View File

@@ -64,6 +64,7 @@ import (
_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerysearchcatalog"
_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerysql"
_ "github.com/googleapis/genai-toolbox/internal/tools/bigtable"
_ "github.com/googleapis/genai-toolbox/internal/tools/cassandra/cassandracql"
_ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouseexecutesql"
_ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouselistdatabases"
_ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouselisttables"
@@ -159,6 +160,7 @@ import (
_ "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg"
_ "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
_ "github.com/googleapis/genai-toolbox/internal/sources/bigtable"
_ "github.com/googleapis/genai-toolbox/internal/sources/cassandra"
_ "github.com/googleapis/genai-toolbox/internal/sources/clickhouse"
_ "github.com/googleapis/genai-toolbox/internal/sources/cloudmonitoring"
_ "github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"

View File

@@ -0,0 +1,57 @@
---
title: "Cassandra"
type: docs
weight: 1
description: >
Cassandra is a NoSQL distributed database known for its horizontal scalability, distributed architecture, and flexible schema definition.
---
## About
[Cassandra][cassandra-docs] is a NoSQL distributed database. By design, NoSQL databases are lightweight, open-source, non-relational, and largely distributed. Counted among their strengths are horizontal scalability, distributed architectures, and a flexible approach to schema definition.
[cassandra-docs]: https://cassandra.apache.org/
## Available Tools
- [`cassandra-cql`](../tools/cassandra/cassandra-cql.md)
Run parameterized CQL queries in Cassandra.
## Example
```yaml
sources:
my-cassandra-source:
kind: cassandra
hosts:
- 127.0.0.1
keyspace: my_keyspace
protoVersion: 4
username: ${USER_NAME}
password: ${PASSWORD}
caPath: /path/to/ca.crt # Optional: path to CA certificate
certPath: /path/to/client.crt # Optional: path to client certificate
keyPath: /path/to/client.key # Optional: path to client key
enableHostVerification: true # Optional: enable host verification
```
{{< notice tip >}}
Use environment variable replacement with the format ${ENV_NAME}
instead of hardcoding your secrets into the configuration file.
{{< /notice >}}
## Reference
| **field** | **type** | **required** | **description** |
|------------------------|:---------:|:------------:|-------------------------------------------------------------------------------------------------------|
| kind | string | true | Must be "cassandra". |
| hosts | string[] | true | List of IP addresses to connect to (e.g., ["192.168.1.1:9042", "192.168.1.2:9042","192.168.1.3:9042"]). The default port is 9042 if not specified. |
| keyspace | string | true | Name of the Cassandra keyspace to connect to (e.g., "my_keyspace"). |
| protoVersion | integer | false | Protocol version for the Cassandra connection (e.g., 4). |
| username | string | false | Name of the Cassandra user to connect as (e.g., "my-cassandra-user"). |
| password | string | false | Password of the Cassandra user (e.g., "my-password"). |
| caPath | string | false | Path to the CA certificate for SSL/TLS (e.g., "/path/to/ca.crt"). |
| certPath | string | false | Path to the client certificate for SSL/TLS (e.g., "/path/to/client.crt"). |
| keyPath | string | false | Path to the client key for SSL/TLS (e.g., "/path/to/client.key"). |
| enableHostVerification | boolean | false | Enable host verification for SSL/TLS (e.g., true). By default, host verification is disabled. |

View File

@@ -0,0 +1,7 @@
---
title: "Cassandra"
type: docs
weight: 1
description: >
Tools that work with Cassandra Sources.
---

View File

@@ -0,0 +1,96 @@
---
title: "cassandra-cql"
type: docs
weight: 1
description: >
A "cassandra-cql" tool executes a pre-defined CQL statement against a Cassandra
database.
aliases:
- /resources/tools/cassandra-cql
---
## About
A `cassandra-cql` tool executes a pre-defined CQL statement against a Cassandra
database. It's compatible with any of the following sources:
- [cassandra](../sources/cassandra.md)
The specified CQL statement is executed as a [prepared statement][cassandra-prepare],
and expects parameters in the CQL query to be in the form of placeholders `?`.
[cassandra-prepare]: https://docs.datastax.com/en/developer/go-driver/4.8/cql-prepared-statements/
## Example
> **Note:** This tool uses parameterized queries to prevent CQL injections.
> Query parameters can be used as substitutes for arbitrary expressions.
> Parameters cannot be used as substitutes for keyspaces, table names, column names,
> or other parts of the query.
```yaml
tools:
search_users_by_email:
kind: cassandra-cql
source: my-cassandra-cluster
statement: |
SELECT user_id, email, first_name, last_name, created_at
FROM users
WHERE email = ?
description: |
Use this tool to retrieve specific user information by their email address.
Takes an email address and returns user details including user ID, email,
first name, last name, and account creation timestamp.
Do NOT use this tool with a user ID or other identifiers.
Example:
{{
"email": "user@example.com",
}}
parameters:
- name: email
type: string
description: User's email address
```
### Example with Template Parameters
> **Note:** This tool allows direct modifications to the CQL statement,
> including keyspaces, table names, and column names. **This makes it more
> vulnerable to CQL injections**. Using basic parameters only (see above) is
> recommended for performance and safety reasons. For more details, please check
> [templateParameters](../#template-parameters).
```yaml
tools:
list_keyspace_table:
kind: cassandra-cql
source: my-cassandra-cluster
statement: |
SELECT * FROM {{.keyspace}}.{{.tableName}};
description: |
Use this tool to list all information from a specific table in a keyspace.
Example:
{{
"keyspace": "my_keyspace",
"tableName": "users",
}}
templateParameters:
- name: keyspace
type: string
description: Keyspace containing the table
- name: tableName
type: string
description: Table to select from
```
## Reference
| **field** | **type** | **required** | **description** |
|--------------------|:------------------------------------------------:|:------------:|--------------------------------------------------------------------------------------------------------------------------------------------|
| kind | string | true | Must be "cassandra-cql". |
| source | string | true | Name of the source the CQL should execute on. |
| description | string | true | Description of the tool that is passed to the LLM. |
| statement | string | true | CQL statement to execute. |
| authRequired | []string | false | List of authentication requirements for the source. |
| parameters | [parameters](../#specifying-parameters) | false | List of [parameters](../#specifying-parameters) that will be inserted into the CQL statement. |
| templateParameters | [templateParameters](../#template-parameters) | false | List of [templateParameters](../#template-parameters) that will be inserted into the CQL statement before executing prepared statement. |

3
go.mod
View File

@@ -26,6 +26,7 @@ require (
github.com/go-playground/validator/v10 v10.27.0
github.com/go-sql-driver/mysql v1.9.3
github.com/goccy/go-yaml v1.18.0
github.com/gocql/gocql v1.7.0
github.com/google/go-cmp v0.7.0
github.com/google/uuid v1.6.0
github.com/jackc/pgx/v5 v5.7.6
@@ -115,6 +116,7 @@ require (
github.com/gorilla/websocket v1.5.3 // indirect
github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.1 // indirect
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed // indirect
github.com/hashicorp/go-uuid v1.0.3 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
@@ -176,6 +178,7 @@ require (
google.golang.org/genproto/googleapis/rpc v0.0.0-20250818200422-3122310a409c // indirect
google.golang.org/grpc v1.75.0 // indirect
google.golang.org/protobuf v1.36.8 // indirect
gopkg.in/inf.v0 v0.9.1 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
modernc.org/libc v1.66.3 // indirect
modernc.org/mathutil v1.7.1 // indirect

10
go.sum
View File

@@ -737,6 +737,10 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.38.4/go.mod h1:Z+Gd23v97pX9zK97+tX4p
github.com/aws/smithy-go v1.23.0 h1:8n6I3gXzWJB2DxBDnfxgBaSX6oe0d/t10qGz7OKqMCE=
github.com/aws/smithy-go v1.23.0/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI=
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 h1:mXoPYz/Ul5HYEDvkta6I8/rnYM5gSdSV2tJ6XbZuEtY=
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCSz6Q9T7+igc/hlvDOUdtWKryOrtFyIVABv/p7k=
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY=
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4=
github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
github.com/boombuler/barcode v1.0.1/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
@@ -897,6 +901,8 @@ github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw=
github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
github.com/gocql/gocql v1.7.0 h1:O+7U7/1gSN7QTEAaMEsJc1Oq2QHXvCWoF3DFK9HDHus=
github.com/gocql/gocql v1.7.0/go.mod h1:vnlvXyFZeLBF0Wy+RS8hrOdbn0UWsWtdg07XJnFxZ+4=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
@@ -1047,6 +1053,8 @@ github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.1/go.mod h1:Zanoh4+gvIgluNqcfMVT
github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8=
github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed h1:5upAirOpQc1Q53c0bnx2ufif5kANL7bfZWcc6VJWJd8=
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed/go.mod h1:tMWxXQ9wFIaZeTI9F+hmhFiGpFmhOHzyShyFUhRm0H4=
github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
github.com/iancoleman/strcase v0.2.0/go.mod h1:iwCmte+B7n89clKwxIoIXy/HfoL7AsD47ZCWhYzw7ho=
@@ -2042,6 +2050,8 @@ gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc=
gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw=
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=

View File

@@ -0,0 +1,134 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cassandra
import (
"context"
"fmt"
"github.com/goccy/go-yaml"
"github.com/gocql/gocql"
"github.com/googleapis/genai-toolbox/internal/sources"
"go.opentelemetry.io/otel/trace"
)
const SourceKind string = "cassandra"
func init() {
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources.SourceConfig, error) {
actual := Config{Name: name}
if err := decoder.DecodeContext(ctx, &actual); err != nil {
return nil, err
}
return actual, nil
}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Hosts []string `yaml:"hosts" validate:"required"`
Keyspace string `yaml:"keyspace"`
ProtoVersion int `yaml:"protoVersion"`
Username string `yaml:"username"`
Password string `yaml:"password"`
CAPath string `yaml:"caPath"`
CertPath string `yaml:"certPath"`
KeyPath string `yaml:"keyPath"`
EnableHostVerification bool `yaml:"enableHostVerification"`
}
// Initialize implements sources.SourceConfig.
func (c Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
session, err := initCassandraSession(ctx, tracer, c)
if err != nil {
return nil, fmt.Errorf("unable to create session: %v", err)
}
s := &Source{
Name: c.Name,
Kind: SourceKind,
Session: session,
}
return s, nil
}
// SourceConfigKind implements sources.SourceConfig.
func (c Config) SourceConfigKind() string {
return SourceKind
}
var _ sources.SourceConfig = Config{}
type Source struct {
Name string `yaml:"name"`
Kind string `yaml:"kind"`
Session *gocql.Session
}
// CassandraSession implements cassandra.compatibleSource.
func (s *Source) CassandraSession() *gocql.Session {
return s.Session
}
// SourceKind implements sources.Source.
func (s Source) SourceKind() string {
return SourceKind
}
var _ sources.Source = &Source{}
func initCassandraSession(ctx context.Context, tracer trace.Tracer, c Config) (*gocql.Session, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, c.Name)
defer span.End()
// Validate authentication configuration
if c.Password != "" && c.Username == "" {
return nil, fmt.Errorf("invalid Cassandra configuration: password provided without a username")
}
cluster := gocql.NewCluster(c.Hosts...)
cluster.ProtoVersion = c.ProtoVersion
cluster.Keyspace = c.Keyspace
// Configure authentication if username is provided
if c.Username != "" {
cluster.Authenticator = gocql.PasswordAuthenticator{
Username: c.Username,
Password: c.Password,
}
}
// Configure SSL options if any are specified
if c.CAPath != "" || c.CertPath != "" || c.KeyPath != "" || c.EnableHostVerification {
cluster.SslOpts = &gocql.SslOptions{
CaPath: c.CAPath,
CertPath: c.CertPath,
KeyPath: c.KeyPath,
EnableHostVerification: c.EnableHostVerification,
}
}
// Create session
session, err := cluster.CreateSession()
if err != nil {
return nil, fmt.Errorf("failed to create Cassandra session: %w", err)
}
return session, nil
}

View File

@@ -0,0 +1,158 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cassandra_test
import (
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources/cassandra"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
func TestParseFromYamlCassandra(t *testing.T) {
tcs := []struct {
desc string
in string
want server.SourceConfigs
}{
{
desc: "basic example (without optional fields)",
in: `
sources:
my-cassandra-instance:
kind: cassandra
hosts:
- "my-host1"
- "my-host2"
`,
want: server.SourceConfigs{
"my-cassandra-instance": cassandra.Config{
Name: "my-cassandra-instance",
Kind: cassandra.SourceKind,
Hosts: []string{"my-host1", "my-host2"},
Username: "",
Password: "",
ProtoVersion: 0,
CAPath: "",
CertPath: "",
KeyPath: "",
Keyspace: "",
EnableHostVerification: false,
},
},
},
{
desc: "with optional fields",
in: `
sources:
my-cassandra-instance:
kind: cassandra
hosts:
- "my-host1"
- "my-host2"
username: "user"
password: "pass"
keyspace: "example_keyspace"
protoVersion: 4
caPath: "path/to/ca.crt"
certPath: "path/to/cert"
keyPath: "path/to/key"
enableHostVerification: true
`,
want: server.SourceConfigs{
"my-cassandra-instance": cassandra.Config{
Name: "my-cassandra-instance",
Kind: cassandra.SourceKind,
Hosts: []string{"my-host1", "my-host2"},
Username: "user",
Password: "pass",
Keyspace: "example_keyspace",
ProtoVersion: 4,
CAPath: "path/to/ca.crt",
CertPath: "path/to/cert",
KeyPath: "path/to/key",
EnableHostVerification: true,
},
},
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
}
})
}
}
func TestFailParseFromYaml(t *testing.T) {
tcs := []struct {
desc string
in string
err string
}{
{
desc: "extra field",
in: `
sources:
my-cassandra-instance:
kind: cassandra
host:
- "my-host"
foo: bar
`,
err: "unable to parse source \"my-cassandra-instance\" as \"cassandra\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | host:\n 3 | - my-host\n 4 | kind: cassandra",
},
{
desc: "missing required field",
in: `
sources:
my-cassandra-instance:
kind: cassandra
`,
err: "unable to parse source \"my-cassandra-instance\" as \"cassandra\": Key: 'Config.Hosts' Error:Field validation for 'Hosts' failed on the 'required' tag",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err == nil {
t.Fatalf("expect parsing to fail")
}
errStr := err.Error()
if errStr != tc.err {
t.Fatalf("unexpected error: got %q, want %q", errStr, tc.err)
}
})
}
}

View File

@@ -0,0 +1,182 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cassandracql
import (
"context"
"fmt"
yaml "github.com/goccy/go-yaml"
"github.com/gocql/gocql"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/cassandra"
"github.com/googleapis/genai-toolbox/internal/tools"
)
const kind string = "cassandra-cql"
func init() {
if !tools.Register(kind, newConfig) {
panic(fmt.Sprintf("tool kind %q already registered", kind))
}
}
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
actual := Config{Name: name}
if err := decoder.DecodeContext(ctx, &actual); err != nil {
return nil, err
}
return actual, nil
}
type compatibleSource interface {
CassandraSession() *gocql.Session
}
var _ compatibleSource = &cassandra.Source{}
var compatibleSources = [...]string{cassandra.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Source string `yaml:"source" validate:"required"`
Description string `yaml:"description" validate:"required"`
Statement string `yaml:"statement" validate:"required"`
AuthRequired []string `yaml:"authRequired"`
Parameters tools.Parameters `yaml:"parameters"`
TemplateParameters tools.Parameters `yaml:"templateParameters"`
}
// Initialize implements tools.ToolConfig.
func (c Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[c.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", c.Source)
}
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
allParameters, paramManifest, paramMcpManifest, err := tools.ProcessParameters(c.TemplateParameters, c.Parameters)
if err != nil {
return nil, err
}
mcpManifest := tools.McpManifest{
Name: c.Name,
Description: c.Description,
InputSchema: paramMcpManifest,
}
t := Tool{
Name: c.Name,
Kind: kind,
Parameters: c.Parameters,
TemplateParameters: c.TemplateParameters,
AllParams: allParameters,
Statement: c.Statement,
AuthRequired: c.AuthRequired,
Session: s.CassandraSession(),
manifest: tools.Manifest{Description: c.Description, Parameters: paramManifest, AuthRequired: c.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}
// ToolConfigKind implements tools.ToolConfig.
func (c Config) ToolConfigKind() string {
return kind
}
var _ tools.ToolConfig = Config{}
type Tool struct {
Name string `yaml:"name"`
Kind string `yaml:"kind"`
AuthRequired []string `yaml:"authRequired"`
Parameters tools.Parameters `yaml:"parameters"`
TemplateParameters tools.Parameters `yaml:"templateParameters"`
AllParams tools.Parameters `yaml:"allParams"`
Session *gocql.Session
Statement string
manifest tools.Manifest
mcpManifest tools.McpManifest
}
// RequiresClientAuthorization implements tools.Tool.
func (t Tool) RequiresClientAuthorization() bool {
return false
}
// Authorized implements tools.Tool.
func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
// Invoke implements tools.Tool.
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
paramsMap := params.AsMap()
newStatement, err := tools.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
if err != nil {
return nil, fmt.Errorf("unable to extract template params %w", err)
}
newParams, err := tools.GetParams(t.Parameters, paramsMap)
if err != nil {
return nil, fmt.Errorf("unable to extract standard params %w", err)
}
sliceParams := newParams.AsSlice()
iter := t.Session.Query(newStatement, sliceParams...).WithContext(ctx).Iter()
// Create a slice to store the out
var out []map[string]interface{}
// Scan results into a map and append to the slice
for {
row := make(map[string]interface{}) // Create a new map for each row
if !iter.MapScan(row) {
break // No more rows
}
out = append(out, row)
}
if err := iter.Close(); err != nil {
return nil, fmt.Errorf("unable to parse rows: %w", err)
}
return out, nil
}
// Manifest implements tools.Tool.
func (t Tool) Manifest() tools.Manifest {
return t.manifest
}
// McpManifest implements tools.Tool.
func (t Tool) McpManifest() tools.McpManifest {
return t.mcpManifest
}
// ParseParams implements tools.Tool.
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
return tools.ParseParams(t.AllParams, data, claims)
}
var _ tools.Tool = Tool{}

View File

@@ -0,0 +1,171 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cassandracql_test
import (
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/testutils"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/cassandra/cassandracql"
)
func TestParseFromYamlCassandra(t *testing.T) {
ctx, err := testutils.ContextWithNewLogger()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
tcs := []struct {
desc string
in string
want server.ToolConfigs
}{
{
desc: "basic example",
in: `
tools:
example_tool:
kind: cassandra-cql
source: my-cassandra-instance
description: some description
statement: |
SELECT * FROM CQL_STATEMENT;
authRequired:
- my-google-auth-service
- other-auth-service
parameters:
- name: country
type: string
description: some description
authServices:
- name: my-google-auth-service
field: user_id
- name: other-auth-service
field: user_id
`,
want: server.ToolConfigs{
"example_tool": cassandracql.Config{
Name: "example_tool",
Kind: "cassandra-cql",
Source: "my-cassandra-instance",
Description: "some description",
Statement: "SELECT * FROM CQL_STATEMENT;\n",
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
Parameters: []tools.Parameter{
tools.NewStringParameterWithAuth("country", "some description",
[]tools.ParamAuthService{{Name: "my-google-auth-service", Field: "user_id"},
{Name: "other-auth-service", Field: "user_id"}}),
},
},
},
},
{
desc: "with template parameters",
in: `
tools:
example_tool:
kind: cassandra-cql
source: my-cassandra-instance
description: some description
statement: |
SELECT * FROM CQL_STATEMENT;
authRequired:
- my-google-auth-service
- other-auth-service
parameters:
- name: country
type: string
description: some description
authServices:
- name: my-google-auth-service
field: user_id
- name: other-auth-service
field: user_id
templateParameters:
- name: tableName
type: string
description: some description.
- name: fieldArray
type: array
description: The columns to return for the query.
items:
name: column
type: string
description: A column name that will be returned from the query.
`,
want: server.ToolConfigs{
"example_tool": cassandracql.Config{
Name: "example_tool",
Kind: "cassandra-cql",
Source: "my-cassandra-instance",
Description: "some description",
Statement: "SELECT * FROM CQL_STATEMENT;\n",
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
Parameters: []tools.Parameter{
tools.NewStringParameterWithAuth("country", "some description",
[]tools.ParamAuthService{{Name: "my-google-auth-service", Field: "user_id"},
{Name: "other-auth-service", Field: "user_id"}}),
},
TemplateParameters: []tools.Parameter{
tools.NewStringParameter("tableName", "some description."),
tools.NewArrayParameter("fieldArray", "The columns to return for the query.", tools.NewStringParameter("column", "A column name that will be returned from the query.")),
},
},
},
},
{
desc: "without optional fields",
in: `
tools:
example_tool:
kind: cassandra-cql
source: my-cassandra-instance
description: some description
statement: |
SELECT * FROM CQL_STATEMENT;
`,
want: server.ToolConfigs{
"example_tool": cassandracql.Config{
Name: "example_tool",
Kind: "cassandra-cql",
Source: "my-cassandra-instance",
Description: "some description",
Statement: "SELECT * FROM CQL_STATEMENT;\n",
AuthRequired: []string{},
Parameters: nil,
TemplateParameters: nil,
},
},
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Tools server.ToolConfigs `yaml:"tools"`
}{}
// Parse contents
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
t.Fatalf("incorrect parse: diff %v", diff)
}
})
}
}

View File

@@ -0,0 +1,284 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cassandra
import (
"context"
"fmt"
"log"
"os"
"regexp"
"strings"
"testing"
"time"
"github.com/gocql/gocql"
"github.com/google/uuid"
"github.com/googleapis/genai-toolbox/internal/testutils"
"github.com/googleapis/genai-toolbox/tests"
)
var (
CassandraSourceKind = "cassandra"
CassandraToolKind = "cassandra-cql"
Hosts = os.Getenv("CASSANDRA_HOST")
Keyspace = "example_keyspace"
Username = os.Getenv("CASSANDRA_USER")
Password = os.Getenv("CASSANDRA_PASS")
)
func getCassandraVars(t *testing.T) map[string]any {
switch "" {
case Hosts:
t.Fatal("'Hosts' not set")
case Username:
t.Fatal("'Username' not set")
case Password:
t.Fatal("'Password' not set")
}
return map[string]any{
"kind": CassandraSourceKind,
"hosts": strings.Split(Hosts, ","),
"keyspace": Keyspace,
"username": Username,
"password": Password,
}
}
func initCassandraSession() (*gocql.Session, error) {
hostStrings := strings.Split(Hosts, ",")
var hosts []string
for _, h := range hostStrings {
trimmedHost := strings.TrimSpace(h)
if trimmedHost != "" {
hosts = append(hosts, trimmedHost)
}
}
if len(hosts) == 0 {
return nil, fmt.Errorf("no valid hosts found in CASSANDRA_HOSTS env var")
}
// Configure cluster connection
cluster := gocql.NewCluster(hosts...)
cluster.Consistency = gocql.Quorum
cluster.ProtoVersion = 4
cluster.DisableInitialHostLookup = true
cluster.ConnectTimeout = 10 * time.Second
cluster.NumConns = 2
cluster.Authenticator = gocql.PasswordAuthenticator{
Username: Username,
Password: Password,
}
cluster.RetryPolicy = &gocql.ExponentialBackoffRetryPolicy{
NumRetries: 3,
Min: 200 * time.Millisecond,
Max: 2 * time.Second,
}
// Create session
session, err := cluster.CreateSession()
if err != nil {
return nil, fmt.Errorf("Failed to create session: %v", err)
}
// Create keyspace
err = session.Query(fmt.Sprintf(`
CREATE KEYSPACE IF NOT EXISTS %s
WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}
`, Keyspace)).Exec()
if err != nil {
return nil, fmt.Errorf("Failed to create keyspace: %v", err)
}
return session, nil
}
func initTable(tableName string, session *gocql.Session) error {
// Create table with additional columns
err := session.Query(fmt.Sprintf(`
CREATE TABLE IF NOT EXISTS %s.%s (
id int PRIMARY KEY,
name text,
email text,
age int,
is_active boolean,
created_at timestamp
)
`, Keyspace, tableName)).Exec()
if err != nil {
return fmt.Errorf("Failed to create table: %v", err)
}
// Use fixed timestamps for reproducibility
fixedTime, _ := time.Parse(time.RFC3339, "2025-07-25T12:00:00Z")
dayAgo := fixedTime.Add(-24 * time.Hour)
twelveHoursAgo := fixedTime.Add(-12 * time.Hour)
// Insert minimal diverse data with fixed time.Time for timestamps
err = session.Query(fmt.Sprintf(`
INSERT INTO %s.%s (id, name,email, age, is_active, created_at)
VALUES (?, ?, ?, ?, ?, ?)`, Keyspace, tableName),
3, "Alice", tests.ServiceAccountEmail, 25, true, dayAgo,
).Exec()
if err != nil {
return fmt.Errorf("Failed to insert user: %v", err)
}
err = session.Query(fmt.Sprintf(`
INSERT INTO %s.%s (id, name,email, age, is_active, created_at)
VALUES (?, ?, ?, ?, ?, ?)`, Keyspace, tableName),
2, "Alex", "janedoe@gmail.com", 30, false, twelveHoursAgo,
).Exec()
if err != nil {
return fmt.Errorf("Failed to insert user: %v", err)
}
err = session.Query(fmt.Sprintf(`
INSERT INTO %s.%s (id, name,email, age, is_active, created_at)
VALUES (?, ?, ?, ?, ?, ?)`, Keyspace, tableName),
1, "Sid", "sid@gmail.com", 10, true, fixedTime,
).Exec()
if err != nil {
return fmt.Errorf("Failed to insert user: %v", err)
}
err = session.Query(fmt.Sprintf(`
INSERT INTO %s.%s (id, name,email, age, is_active, created_at)
VALUES (?, ?, ?, ?, ?, ?)`, Keyspace, tableName),
4, nil, "a@gmail.com", 40, false, fixedTime,
).Exec()
if err != nil {
return fmt.Errorf("Failed to insert user: %v", err)
}
return nil
}
func dropTable(session *gocql.Session, tableName string) {
err := session.Query(fmt.Sprintf("drop table %s.%s", Keyspace, tableName)).Exec()
if err != nil {
log.Printf("Failed to drop table %s: %v", tableName, err)
}
}
func TestCassandra(t *testing.T) {
session, err := initCassandraSession()
if err != nil {
t.Fatal(err)
}
defer session.Close()
sourceConfig := getCassandraVars(t)
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
var args []string
paramTableName := "param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
tableNameAuth := "auth_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
err = initTable(paramTableName, session)
if err != nil {
t.Fatal(err)
}
defer dropTable(session, paramTableName)
err = initTable(tableNameAuth, session)
if err != nil {
t.Fatal(err)
}
defer dropTable(session, tableNameAuth)
err = initTable(tableNameTemplateParam, session)
if err != nil {
t.Fatal(err)
}
defer dropTable(session, tableNameTemplateParam)
paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt := createParamToolInfo(paramTableName)
_, _, authToolStmt := getCassandraAuthToolInfo(tableNameAuth)
toolsFile := tests.GetToolsConfig(sourceConfig, CassandraToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
tmplSelectCombined, tmplSelectFilterCombined := getCassandraTmplToolInfo()
tmpSelectAll := "SELECT * FROM {{.tableName}} where id = 1"
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, CassandraToolKind, tmplSelectCombined, tmplSelectFilterCombined, tmpSelectAll)
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
if err != nil {
t.Fatalf("command initialization returned an error: %s", err)
}
defer cleanup()
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
if err != nil {
t.Logf("toolbox command logs: \n%s", out)
t.Fatalf("toolbox didn't start successfully: %s", err)
}
selectIdNameWant, selectIdNullWant, selectArrayParamWant, mcpMyFailToolWant, mcpSelect1Want, mcpMyToolIdWant := getCassandraWants()
selectAllWant, selectIdWant, selectNameWant := getCassandraTmplWants()
tests.RunToolGetTest(t)
tests.RunToolInvokeTest(t, "", tests.DisableSelect1Test(),
tests.DisableOptionalNullParamTest(),
tests.WithMyToolId3NameAliceWant(selectIdNameWant),
tests.WithMyToolById4Want(selectIdNullWant),
tests.WithMyArrayToolWant(selectArrayParamWant),
tests.DisableSelect1AuthTest())
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam,
tests.DisableSelectFilterTest(),
tests.WithSelectAllWant(selectAllWant),
tests.DisableDdlTest(), tests.DisableInsertTest(), tests.WithTmplSelectId1Want(selectIdWant), tests.WithTmplSelectNameWant(selectNameWant))
tests.RunMCPToolCallMethod(t, mcpMyFailToolWant, mcpSelect1Want,
tests.WithMcpMyToolId3NameAliceWant(mcpMyToolIdWant),
tests.DisableMcpSelect1AuthTest())
}
func createParamToolInfo(tableName string) (string, string, string, string) {
toolStatement := fmt.Sprintf("SELECT id, name FROM %s WHERE id = ? AND name = ? ALLOW FILTERING;", tableName)
idParamStatement := fmt.Sprintf("SELECT id,name FROM %s WHERE id = ?;", tableName)
nameParamStatement := fmt.Sprintf("SELECT id, name FROM %s WHERE name = ? ALLOW FILTERING;", tableName)
arrayToolStatement := fmt.Sprintf("SELECT id, name FROM %s WHERE id IN ? AND name IN ? ALLOW FILTERING;", tableName)
return toolStatement, idParamStatement, nameParamStatement, arrayToolStatement
}
func getCassandraAuthToolInfo(tableName string) (string, string, string) {
createStatement := fmt.Sprintf("CREATE TABLE %s (id UUID PRIMARY KEY, name TEXT, email TEXT);", tableName)
insertStatement := fmt.Sprintf("INSERT INTO %s (id, name, email) VALUES (uuid(), ?, ?), (uuid(), ?, ?);", tableName)
toolStatement := fmt.Sprintf("SELECT name FROM %s WHERE email = ? ALLOW FILTERING;", tableName)
return createStatement, insertStatement, toolStatement
}
func getCassandraTmplToolInfo() (string, string) {
selectAllTemplateStmt := "SELECT age, id, name FROM {{.tableName}} where id = ?;"
selectByIdTemplateStmt := "SELECT id, name FROM {{.tableName}} WHERE name = ? ALLOW FILTERING;"
return selectAllTemplateStmt, selectByIdTemplateStmt
}
func getCassandraWants() (string, string, string, string, string, string) {
selectIdNameWant := "[{\"id\":3,\"name\":\"Alice\"}]"
selectIdNullWant := "[{\"id\":4,\"name\":\"\"}]"
selectArrayParamWant := "[{\"id\":1,\"name\":\"Sid\"},{\"id\":3,\"name\":\"Alice\"}]"
mcpMyFailToolWant := "{\"jsonrpc\":\"2.0\",\"id\":\"invoke-fail-tool\",\"result\":{\"content\":[{\"type\":\"text\",\"text\":\"unable to parse rows: line 1:0 no viable alternative at input 'SELEC' ([SELEC]...)\"}],\"isError\":true}}"
mcpMyToolIdWant := "{\"jsonrpc\":\"2.0\",\"id\":\"my-tool\",\"result\":{\"content\":[{\"type\":\"text\",\"text\":\"[{\\\"id\\\":3,\\\"name\\\":\\\"Alice\\\"}]\"}]}}"
return selectIdNameWant, selectIdNullWant, selectArrayParamWant, mcpMyFailToolWant, "nil", mcpMyToolIdWant
}
func getCassandraTmplWants() (string, string, string) {
selectAllWant := "[{\"age\":10,\"created_at\":\"2025-07-25T12:00:00Z\",\"email\":\"sid@gmail.com\",\"id\":1,\"is_active\":true,\"name\":\"Sid\"}]"
selectIdWant := "[{\"age\":10,\"id\":1,\"name\":\"Sid\"}]"
selectNameWant := "[{\"id\":2,\"name\":\"Alex\"}]"
return selectAllWant, selectIdWant, selectNameWant
}

View File

@@ -110,6 +110,7 @@ func TestMongoDBToolEndpoints(t *testing.T) {
tests.RunToolGetTest(t)
tests.RunToolInvokeTest(t, select1Want,
tests.WithMyToolId3NameAliceWant(myToolId3NameAliceWant),
tests.WithMyArrayToolWant(myToolId3NameAliceWant),
tests.WithMyToolById4Want(myToolById4Want),
)
tests.RunMCPToolCallMethod(t, mcpMyFailToolWant, select1Want,

View File

@@ -21,9 +21,12 @@ type InvokeTestConfig struct {
myToolId3NameAliceWant string
myToolById4Want string
nullWant string
myArrayToolWant string
supportSelect1Want bool
supportOptionalNullParam bool
supportArrayParam bool
supportClientAuth bool
supportSelect1Auth bool
}
type InvokeTestOption func(*InvokeTestConfig)
@@ -36,6 +39,14 @@ func WithMyToolId3NameAliceWant(s string) InvokeTestOption {
}
}
// WithMyArrayToolWant represents the response value for my-array-tool.
// e.g. tests.RunToolInvokeTest(t, select1Want, tests.WithMyArrayToolWant("custom"))
func WithMyArrayToolWant(s string) InvokeTestOption {
return func(c *InvokeTestConfig) {
c.myArrayToolWant = s
}
}
// WithMyToolById4Want represents the response value for my-tool-by-id with id=4.
// This response includes a null value column.
// e.g. tests.RunToolInvokeTest(t, select1Want, tests.WithMyToolById4Want("custom"))
@@ -69,6 +80,22 @@ func DisableArrayTest() InvokeTestOption {
}
}
// DisableSelect1Test disables tests for sources that do not support SELECT 1 query.
// e.g. tests.RunToolInvokeTest(t, "", tests.DisableSelect1Test())
func DisableSelect1Test() InvokeTestOption {
return func(c *InvokeTestConfig) {
c.supportSelect1Want = false
}
}
// DisableSelect1AuthTest disables auth tests for sources that do not support SELECT 1 query.
// e.g. tests.RunToolInvokeTest(t, "", tests.DisableSelect1AuthTest())
func DisableSelect1AuthTest() InvokeTestOption {
return func(c *InvokeTestConfig) {
c.supportSelect1Auth = false
}
}
// EnableClientAuthTest runs the client authorization tests.
// Only enable it if your source supports the `useClientOAuth` configuration.
// Currently, this should only be used with the BigQuery tests.
@@ -84,6 +111,7 @@ func EnableClientAuthTest() InvokeTestOption {
type MCPTestConfig struct {
myToolId3NameAliceWant string
supportClientAuth bool
supportSelect1Auth bool
}
type McpTestOption func(*MCPTestConfig)
@@ -105,6 +133,13 @@ func EnableMcpClientAuthTest() McpTestOption {
}
}
// DisableMcpSelect1AuthTest disables the auth tool tests which use select 1.
func DisableMcpSelect1AuthTest() McpTestOption {
return func(c *MCPTestConfig) {
c.supportSelect1Auth = false
}
}
/* Configurations for RunExecuteSqlToolInvokeTest() */
// ExecuteSqlTestConfig represents the various configuration options for RunExecuteSqlToolInvokeTest()
@@ -129,6 +164,7 @@ type TemplateParameterTestConfig struct {
ddlWant string
selectAllWant string
selectId1Want string
selectNameWant string
selectEmptyWant string
insert1Want string
@@ -136,8 +172,9 @@ type TemplateParameterTestConfig struct {
nameColFilter string
createColArray string
supportDdl bool
supportInsert bool
supportDdl bool
supportInsert bool
supportSelectFields bool
}
type TemplateParamOption func(*TemplateParameterTestConfig)
@@ -166,6 +203,14 @@ func WithTmplSelectId1Want(s string) TemplateParamOption {
}
}
// WithTmplSelectNameWant represents the response value of select-filter-templateParams-combined-tool with name.
// e.g. tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tests.WithTmplSelectNameWant("custom"))
func WithTmplSelectNameWant(s string) TemplateParamOption {
return func(c *TemplateParameterTestConfig) {
c.selectNameWant = s
}
}
// WithSelectEmptyWant represents the response value of select-templateParams-combined-tool with no results.
// e.g. tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tests.WithSelectEmptyWant("custom"))
func WithSelectEmptyWant(s string) TemplateParamOption {
@@ -221,3 +266,11 @@ func DisableInsertTest() TemplateParamOption {
c.supportInsert = false
}
}
// DisableInsertTest disables tests of select-fields-templateParams-tool test.
// e.g. tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tests.DisableSelectFilterTest())
func DisableSelectFilterTest() TemplateParamOption {
return func(c *TemplateParameterTestConfig) {
c.supportSelectFields = false
}
}

View File

@@ -104,6 +104,7 @@ func TestRedisToolEndpoints(t *testing.T) {
tests.RunToolGetTest(t)
tests.RunToolInvokeTest(t, select1Want,
tests.WithMyToolId3NameAliceWant(invokeParamWant),
tests.WithMyArrayToolWant(invokeParamWant),
tests.WithMyToolById4Want(invokeIdNullWant),
tests.WithNullWant(nullWant),
)

View File

@@ -164,6 +164,7 @@ func TestSpannerToolEndpoints(t *testing.T) {
tests.RunToolGetTest(t)
tests.RunToolInvokeTest(t, select1Want,
tests.WithMyToolId3NameAliceWant(invokeParamWant),
tests.WithMyArrayToolWant(invokeParamWant),
tests.WithMyToolById4Want(toolInvokeMyToolById4Want),
)
tests.RunMCPToolCallMethod(t, mcpMyFailToolWant, mcpSelect1Want, tests.WithMcpMyToolId3NameAliceWant(mcpMyToolId3NameAliceWant))

View File

@@ -257,10 +257,13 @@ func RunToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTestOp
configs := &InvokeTestConfig{
myToolId3NameAliceWant: "[{\"id\":1,\"name\":\"Alice\"},{\"id\":3,\"name\":\"Sid\"}]",
myToolById4Want: "[{\"id\":4,\"name\":null}]",
myArrayToolWant: "[{\"id\":1,\"name\":\"Alice\"},{\"id\":3,\"name\":\"Sid\"}]",
nullWant: "null",
supportOptionalNullParam: true,
supportArrayParam: true,
supportClientAuth: false,
supportSelect1Want: true,
supportSelect1Auth: true,
}
// Apply provided options
@@ -294,7 +297,7 @@ func RunToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTestOp
{
name: "invoke my-simple-tool",
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/invoke",
enabled: true,
enabled: configs.supportSelect1Want,
requestHeader: map[string]string{},
requestBody: bytes.NewBuffer([]byte(`{}`)),
wantBody: select1Want,
@@ -351,13 +354,13 @@ func RunToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTestOp
enabled: configs.supportArrayParam,
requestHeader: map[string]string{},
requestBody: bytes.NewBuffer([]byte(`{"idArray": [1,2,3], "nameArray": ["Alice", "Sid", "RandomName"], "cmdArray": ["HGETALL", "row3"]}`)),
wantBody: configs.myToolId3NameAliceWant,
wantBody: configs.myArrayToolWant,
wantStatusCode: http.StatusOK,
},
{
name: "Invoke my-auth-tool with auth token",
api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke",
enabled: true,
enabled: configs.supportSelect1Auth,
requestHeader: map[string]string{"my-google-auth_token": idToken},
requestBody: bytes.NewBuffer([]byte(`{}`)),
wantBody: "[{\"name\":\"Alice\"}]",
@@ -366,7 +369,7 @@ func RunToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTestOp
{
name: "Invoke my-auth-tool with invalid auth token",
api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke",
enabled: true,
enabled: configs.supportSelect1Auth,
requestHeader: map[string]string{"my-google-auth_token": "INVALID_TOKEN"},
requestBody: bytes.NewBuffer([]byte(`{}`)),
wantStatusCode: http.StatusUnauthorized,
@@ -382,7 +385,7 @@ func RunToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTestOp
{
name: "Invoke my-auth-required-tool with auth token",
api: "http://127.0.0.1:5000/api/tool/my-auth-required-tool/invoke",
enabled: true,
enabled: configs.supportSelect1Auth,
requestHeader: map[string]string{"my-google-auth_token": idToken},
requestBody: bytes.NewBuffer([]byte(`{}`)),
@@ -491,6 +494,7 @@ func RunToolInvokeWithTemplateParameters(t *testing.T, tableName string, options
ddlWant: "null",
selectAllWant: "[{\"age\":21,\"id\":1,\"name\":\"Alex\"},{\"age\":100,\"id\":2,\"name\":\"Alice\"}]",
selectId1Want: "[{\"age\":21,\"id\":1,\"name\":\"Alex\"}]",
selectNameWant: "[{\"age\":21,\"id\":1,\"name\":\"Alex\"}]",
selectEmptyWant: "null",
insert1Want: "null",
@@ -512,6 +516,7 @@ func RunToolInvokeWithTemplateParameters(t *testing.T, tableName string, options
// Test tool invoke endpoint
invokeTcs := []struct {
name string
enabled bool
ddl bool
insert bool
api string
@@ -573,6 +578,7 @@ func RunToolInvokeWithTemplateParameters(t *testing.T, tableName string, options
},
{
name: "invoke select-fields-templateParams-tool",
enabled: configs.supportSelectFields,
api: "http://127.0.0.1:5000/api/tool/select-fields-templateParams-tool/invoke",
requestHeader: map[string]string{},
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"tableName": "%s", "fields":%s}`, tableName, configs.nameFieldArray))),
@@ -584,7 +590,7 @@ func RunToolInvokeWithTemplateParameters(t *testing.T, tableName string, options
api: "http://127.0.0.1:5000/api/tool/select-filter-templateParams-combined-tool/invoke",
requestHeader: map[string]string{},
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"name": "Alex", "tableName": "%s", "columnFilter": "%s"}`, tableName, configs.nameColFilter))),
want: configs.selectId1Want,
want: configs.selectNameWant,
isErr: false,
},
{
@@ -599,6 +605,9 @@ func RunToolInvokeWithTemplateParameters(t *testing.T, tableName string, options
}
for _, tc := range invokeTcs {
t.Run(tc.name, func(t *testing.T) {
if !tc.enabled {
return
}
// if test case is DDL and source support ddl test cases
ddlAllow := !tc.ddl || (tc.ddl && configs.supportDdl)
// if test case is insert statement and source support insert test cases
@@ -834,6 +843,7 @@ func RunMCPToolCallMethod(t *testing.T, myFailToolWant, select1Want string, opti
configs := &MCPTestConfig{
myToolId3NameAliceWant: `{"jsonrpc":"2.0","id":"my-tool","result":{"content":[{"type":"text","text":"{\"id\":1,\"name\":\"Alice\"}"},{"type":"text","text":"{\"id\":3,\"name\":\"Sid\"}"}]}}`,
supportClientAuth: false,
supportSelect1Auth: true,
}
// Apply provided options
@@ -947,7 +957,7 @@ func RunMCPToolCallMethod(t *testing.T, myFailToolWant, select1Want string, opti
{
name: "MCP Invoke my-auth-required-tool",
api: "http://127.0.0.1:5000/mcp",
enabled: true,
enabled: configs.supportSelect1Auth,
requestHeader: map[string]string{"my-google-auth_token": idToken},
requestBody: jsonrpc.JSONRPCRequest{
Jsonrpc: "2.0",

View File

@@ -107,6 +107,7 @@ func TestValkeyToolEndpoints(t *testing.T) {
tests.RunToolGetTest(t)
tests.RunToolInvokeTest(t, select1Want,
tests.WithMyToolId3NameAliceWant(invokeParamWant),
tests.WithMyArrayToolWant(invokeParamWant),
tests.WithMyToolById4Want(invokeIdNullWant),
tests.WithNullWant(nullWant),
)