From 6e420534ee894da4a8d226acb6cdb63d0d5d9ce5 Mon Sep 17 00:00:00 2001 From: Pranava B Date: Wed, 24 Sep 2025 02:18:43 +0530 Subject: [PATCH] feat(cassandra): add Cassandra Source and Tool (#1012) [Cassandra](https://cassandra.apache.org/_/cassandra-basics.html) is a NoSQL distributed database. By design, NoSQL databases are lightweight, open-source, non-relational, and largely distributed. Counted among their strengths are horizontal scalability, distributed architectures, and a flexible approach to schema definition. Cassandra go driver link - https://pkg.go.dev/github.com/apache/cassandra-gocql-driver/v2 This PR - adds a new source for cassandra - adds a new tool _cassandra-cql_ with support for executing predefined parameterized CQL queries on cassandra - adds unit and integration tests for the tool and the source - adds documentation for the cassandra source and cassandra-cql tool --------- Co-authored-by: Wenxin Du <117315983+duwenxin99@users.noreply.github.com> Co-authored-by: duwenxin --- .ci/integration.cloudbuild.yaml | 26 ++ cmd/root.go | 2 + docs/en/resources/sources/cassandra.md | 57 ++++ docs/en/resources/tools/cassandra/_index.md | 7 + .../tools/cassandra/cassandra-cql.md | 96 ++++++ go.mod | 3 + go.sum | 10 + internal/sources/cassandra/cassandra.go | 134 +++++++++ internal/sources/cassandra/cassandra_test.go | 158 ++++++++++ .../cassandra/cassandracql/cassandracql.go | 182 +++++++++++ .../cassandracql/cassandracql_test.go | 171 +++++++++++ tests/cassandra/cassandra_integration_test.go | 284 ++++++++++++++++++ tests/mongodb/mongodb_integration_test.go | 1 + tests/option.go | 57 +++- tests/redis/redis_test.go | 1 + tests/spanner/spanner_integration_test.go | 1 + tests/tool.go | 24 +- tests/valkey/valkey_test.go | 1 + 18 files changed, 1206 insertions(+), 9 deletions(-) create mode 100644 docs/en/resources/sources/cassandra.md create mode 100644 docs/en/resources/tools/cassandra/_index.md create mode 100644 docs/en/resources/tools/cassandra/cassandra-cql.md create mode 100644 internal/sources/cassandra/cassandra.go create mode 100644 internal/sources/cassandra/cassandra_test.go create mode 100644 internal/tools/cassandra/cassandracql/cassandracql.go create mode 100644 internal/tools/cassandra/cassandracql/cassandracql_test.go create mode 100644 tests/cassandra/cassandra_integration_test.go diff --git a/.ci/integration.cloudbuild.yaml b/.ci/integration.cloudbuild.yaml index 16276e40c9..3d456fe77d 100644 --- a/.ci/integration.cloudbuild.yaml +++ b/.ci/integration.cloudbuild.yaml @@ -662,6 +662,26 @@ steps: - | ./yugabytedb.test -test.v + + - id: "cassandra" + name: golang:1 + waitFor: ["compile-test-binary"] + entrypoint: /bin/bash + env: + - "GOPATH=/gopath" + - "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL" + secretEnv: ["CLIENT_ID", "CASSANDRA_USER", "CASSANDRA_PASS", "CASSANDRA_HOST"] + volumes: + - name: "go" + path: "/gopath" + args: + - -c + - | + .ci/test_with_coverage.sh \ + "Cassandra" \ + cassandra \ + cassandra + availableSecrets: secretManager: - versionName: projects/$PROJECT_ID/secrets/cloud_sql_pg_user/versions/latest @@ -746,6 +766,12 @@ availableSecrets: env: YUGABYTEDB_USER - versionName: projects/$PROJECT_ID/secrets/yugabytedb_pass/versions/latest env: YUGABYTEDB_PASS + - versionName: projects/$PROJECT_ID/secrets/cassandra_user/versions/latest + env: CASSANDRA_USER + - versionName: projects/$PROJECT_ID/secrets/cassandra_pass/versions/latest + env: CASSANDRA_PASS + - versionName: projects/$PROJECT_ID/secrets/cassandra_host/versions/latest + env: CASSANDRA_HOST options: logging: CLOUD_LOGGING_ONLY diff --git a/cmd/root.go b/cmd/root.go index e38f8432ae..7d56e1a115 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -64,6 +64,7 @@ import ( _ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerysearchcatalog" _ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerysql" _ "github.com/googleapis/genai-toolbox/internal/tools/bigtable" + _ "github.com/googleapis/genai-toolbox/internal/tools/cassandra/cassandracql" _ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouseexecutesql" _ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouselistdatabases" _ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouselisttables" @@ -159,6 +160,7 @@ import ( _ "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg" _ "github.com/googleapis/genai-toolbox/internal/sources/bigquery" _ "github.com/googleapis/genai-toolbox/internal/sources/bigtable" + _ "github.com/googleapis/genai-toolbox/internal/sources/cassandra" _ "github.com/googleapis/genai-toolbox/internal/sources/clickhouse" _ "github.com/googleapis/genai-toolbox/internal/sources/cloudmonitoring" _ "github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin" diff --git a/docs/en/resources/sources/cassandra.md b/docs/en/resources/sources/cassandra.md new file mode 100644 index 0000000000..241eb08550 --- /dev/null +++ b/docs/en/resources/sources/cassandra.md @@ -0,0 +1,57 @@ +--- +title: "Cassandra" +type: docs +weight: 1 +description: > + Cassandra is a NoSQL distributed database known for its horizontal scalability, distributed architecture, and flexible schema definition. +--- + +## About + +[Cassandra][cassandra-docs] is a NoSQL distributed database. By design, NoSQL databases are lightweight, open-source, non-relational, and largely distributed. Counted among their strengths are horizontal scalability, distributed architectures, and a flexible approach to schema definition. + +[cassandra-docs]: https://cassandra.apache.org/ + +## Available Tools + +- [`cassandra-cql`](../tools/cassandra/cassandra-cql.md) + Run parameterized CQL queries in Cassandra. + + +## Example + +```yaml +sources: + my-cassandra-source: + kind: cassandra + hosts: + - 127.0.0.1 + keyspace: my_keyspace + protoVersion: 4 + username: ${USER_NAME} + password: ${PASSWORD} + caPath: /path/to/ca.crt # Optional: path to CA certificate + certPath: /path/to/client.crt # Optional: path to client certificate + keyPath: /path/to/client.key # Optional: path to client key + enableHostVerification: true # Optional: enable host verification +``` + +{{< notice tip >}} +Use environment variable replacement with the format ${ENV_NAME} +instead of hardcoding your secrets into the configuration file. +{{< /notice >}} + +## Reference + +| **field** | **type** | **required** | **description** | +|------------------------|:---------:|:------------:|-------------------------------------------------------------------------------------------------------| +| kind | string | true | Must be "cassandra". | +| hosts | string[] | true | List of IP addresses to connect to (e.g., ["192.168.1.1:9042", "192.168.1.2:9042","192.168.1.3:9042"]). The default port is 9042 if not specified. | +| keyspace | string | true | Name of the Cassandra keyspace to connect to (e.g., "my_keyspace"). | +| protoVersion | integer | false | Protocol version for the Cassandra connection (e.g., 4). | +| username | string | false | Name of the Cassandra user to connect as (e.g., "my-cassandra-user"). | +| password | string | false | Password of the Cassandra user (e.g., "my-password"). | +| caPath | string | false | Path to the CA certificate for SSL/TLS (e.g., "/path/to/ca.crt"). | +| certPath | string | false | Path to the client certificate for SSL/TLS (e.g., "/path/to/client.crt"). | +| keyPath | string | false | Path to the client key for SSL/TLS (e.g., "/path/to/client.key"). | +| enableHostVerification | boolean | false | Enable host verification for SSL/TLS (e.g., true). By default, host verification is disabled. | diff --git a/docs/en/resources/tools/cassandra/_index.md b/docs/en/resources/tools/cassandra/_index.md new file mode 100644 index 0000000000..3e1e07fd25 --- /dev/null +++ b/docs/en/resources/tools/cassandra/_index.md @@ -0,0 +1,7 @@ +--- +title: "Cassandra" +type: docs +weight: 1 +description: > + Tools that work with Cassandra Sources. +--- \ No newline at end of file diff --git a/docs/en/resources/tools/cassandra/cassandra-cql.md b/docs/en/resources/tools/cassandra/cassandra-cql.md new file mode 100644 index 0000000000..5b083d9b10 --- /dev/null +++ b/docs/en/resources/tools/cassandra/cassandra-cql.md @@ -0,0 +1,96 @@ +--- +title: "cassandra-cql" +type: docs +weight: 1 +description: > + A "cassandra-cql" tool executes a pre-defined CQL statement against a Cassandra + database. +aliases: +- /resources/tools/cassandra-cql +--- + +## About + +A `cassandra-cql` tool executes a pre-defined CQL statement against a Cassandra +database. It's compatible with any of the following sources: + +- [cassandra](../sources/cassandra.md) + +The specified CQL statement is executed as a [prepared statement][cassandra-prepare], +and expects parameters in the CQL query to be in the form of placeholders `?`. + +[cassandra-prepare]: https://docs.datastax.com/en/developer/go-driver/4.8/cql-prepared-statements/ + +## Example + +> **Note:** This tool uses parameterized queries to prevent CQL injections. +> Query parameters can be used as substitutes for arbitrary expressions. +> Parameters cannot be used as substitutes for keyspaces, table names, column names, +> or other parts of the query. + +```yaml +tools: + search_users_by_email: + kind: cassandra-cql + source: my-cassandra-cluster + statement: | + SELECT user_id, email, first_name, last_name, created_at + FROM users + WHERE email = ? + description: | + Use this tool to retrieve specific user information by their email address. + Takes an email address and returns user details including user ID, email, + first name, last name, and account creation timestamp. + Do NOT use this tool with a user ID or other identifiers. + Example: + {{ + "email": "user@example.com", + }} + parameters: + - name: email + type: string + description: User's email address +``` + +### Example with Template Parameters + +> **Note:** This tool allows direct modifications to the CQL statement, +> including keyspaces, table names, and column names. **This makes it more +> vulnerable to CQL injections**. Using basic parameters only (see above) is +> recommended for performance and safety reasons. For more details, please check +> [templateParameters](../#template-parameters). + +```yaml +tools: + list_keyspace_table: + kind: cassandra-cql + source: my-cassandra-cluster + statement: | + SELECT * FROM {{.keyspace}}.{{.tableName}}; + description: | + Use this tool to list all information from a specific table in a keyspace. + Example: + {{ + "keyspace": "my_keyspace", + "tableName": "users", + }} + templateParameters: + - name: keyspace + type: string + description: Keyspace containing the table + - name: tableName + type: string + description: Table to select from +``` + +## Reference + +| **field** | **type** | **required** | **description** | +|--------------------|:------------------------------------------------:|:------------:|--------------------------------------------------------------------------------------------------------------------------------------------| +| kind | string | true | Must be "cassandra-cql". | +| source | string | true | Name of the source the CQL should execute on. | +| description | string | true | Description of the tool that is passed to the LLM. | +| statement | string | true | CQL statement to execute. | +| authRequired | []string | false | List of authentication requirements for the source. | +| parameters | [parameters](../#specifying-parameters) | false | List of [parameters](../#specifying-parameters) that will be inserted into the CQL statement. | +| templateParameters | [templateParameters](../#template-parameters) | false | List of [templateParameters](../#template-parameters) that will be inserted into the CQL statement before executing prepared statement. | diff --git a/go.mod b/go.mod index bc3f55978b..e893ea5f6b 100644 --- a/go.mod +++ b/go.mod @@ -26,6 +26,7 @@ require ( github.com/go-playground/validator/v10 v10.27.0 github.com/go-sql-driver/mysql v1.9.3 github.com/goccy/go-yaml v1.18.0 + github.com/gocql/gocql v1.7.0 github.com/google/go-cmp v0.7.0 github.com/google/uuid v1.6.0 github.com/jackc/pgx/v5 v5.7.6 @@ -115,6 +116,7 @@ require ( github.com/gorilla/websocket v1.5.3 // indirect github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.1 // indirect + github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed // indirect github.com/hashicorp/go-uuid v1.0.3 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect @@ -176,6 +178,7 @@ require ( google.golang.org/genproto/googleapis/rpc v0.0.0-20250818200422-3122310a409c // indirect google.golang.org/grpc v1.75.0 // indirect google.golang.org/protobuf v1.36.8 // indirect + gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/ini.v1 v1.67.0 // indirect modernc.org/libc v1.66.3 // indirect modernc.org/mathutil v1.7.1 // indirect diff --git a/go.sum b/go.sum index daa9e6194c..13923a9d97 100644 --- a/go.sum +++ b/go.sum @@ -737,6 +737,10 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.38.4/go.mod h1:Z+Gd23v97pX9zK97+tX4p github.com/aws/smithy-go v1.23.0 h1:8n6I3gXzWJB2DxBDnfxgBaSX6oe0d/t10qGz7OKqMCE= github.com/aws/smithy-go v1.23.0/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= +github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 h1:mXoPYz/Ul5HYEDvkta6I8/rnYM5gSdSV2tJ6XbZuEtY= +github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCSz6Q9T7+igc/hlvDOUdtWKryOrtFyIVABv/p7k= +github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY= +github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4= github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/boombuler/barcode v1.0.1/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= @@ -897,6 +901,8 @@ github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw= github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= +github.com/gocql/gocql v1.7.0 h1:O+7U7/1gSN7QTEAaMEsJc1Oq2QHXvCWoF3DFK9HDHus= +github.com/gocql/gocql v1.7.0/go.mod h1:vnlvXyFZeLBF0Wy+RS8hrOdbn0UWsWtdg07XJnFxZ+4= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= @@ -1047,6 +1053,8 @@ github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.1/go.mod h1:Zanoh4+gvIgluNqcfMVT github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8= github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= +github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed h1:5upAirOpQc1Q53c0bnx2ufif5kANL7bfZWcc6VJWJd8= +github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed/go.mod h1:tMWxXQ9wFIaZeTI9F+hmhFiGpFmhOHzyShyFUhRm0H4= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/iancoleman/strcase v0.2.0/go.mod h1:iwCmte+B7n89clKwxIoIXy/HfoL7AsD47ZCWhYzw7ho= @@ -2042,6 +2050,8 @@ gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc= +gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/internal/sources/cassandra/cassandra.go b/internal/sources/cassandra/cassandra.go new file mode 100644 index 0000000000..29b8d2e23a --- /dev/null +++ b/internal/sources/cassandra/cassandra.go @@ -0,0 +1,134 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cassandra + +import ( + "context" + "fmt" + + "github.com/goccy/go-yaml" + "github.com/gocql/gocql" + "github.com/googleapis/genai-toolbox/internal/sources" + "go.opentelemetry.io/otel/trace" +) + +const SourceKind string = "cassandra" + +func init() { + if !sources.Register(SourceKind, newConfig) { + panic(fmt.Sprintf("source kind %q already registered", SourceKind)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources.SourceConfig, error) { + actual := Config{Name: name} + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + return actual, nil +} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + Hosts []string `yaml:"hosts" validate:"required"` + Keyspace string `yaml:"keyspace"` + ProtoVersion int `yaml:"protoVersion"` + Username string `yaml:"username"` + Password string `yaml:"password"` + CAPath string `yaml:"caPath"` + CertPath string `yaml:"certPath"` + KeyPath string `yaml:"keyPath"` + EnableHostVerification bool `yaml:"enableHostVerification"` +} + +// Initialize implements sources.SourceConfig. +func (c Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) { + session, err := initCassandraSession(ctx, tracer, c) + if err != nil { + return nil, fmt.Errorf("unable to create session: %v", err) + } + s := &Source{ + Name: c.Name, + Kind: SourceKind, + Session: session, + } + return s, nil +} + +// SourceConfigKind implements sources.SourceConfig. +func (c Config) SourceConfigKind() string { + return SourceKind +} + +var _ sources.SourceConfig = Config{} + +type Source struct { + Name string `yaml:"name"` + Kind string `yaml:"kind"` + Session *gocql.Session +} + +// CassandraSession implements cassandra.compatibleSource. +func (s *Source) CassandraSession() *gocql.Session { + return s.Session +} + +// SourceKind implements sources.Source. +func (s Source) SourceKind() string { + return SourceKind +} + +var _ sources.Source = &Source{} + +func initCassandraSession(ctx context.Context, tracer trace.Tracer, c Config) (*gocql.Session, error) { + //nolint:all // Reassigned ctx + ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, c.Name) + defer span.End() + + // Validate authentication configuration + if c.Password != "" && c.Username == "" { + return nil, fmt.Errorf("invalid Cassandra configuration: password provided without a username") + } + + cluster := gocql.NewCluster(c.Hosts...) + cluster.ProtoVersion = c.ProtoVersion + cluster.Keyspace = c.Keyspace + + // Configure authentication if username is provided + if c.Username != "" { + cluster.Authenticator = gocql.PasswordAuthenticator{ + Username: c.Username, + Password: c.Password, + } + } + + // Configure SSL options if any are specified + if c.CAPath != "" || c.CertPath != "" || c.KeyPath != "" || c.EnableHostVerification { + cluster.SslOpts = &gocql.SslOptions{ + CaPath: c.CAPath, + CertPath: c.CertPath, + KeyPath: c.KeyPath, + EnableHostVerification: c.EnableHostVerification, + } + } + + // Create session + session, err := cluster.CreateSession() + if err != nil { + return nil, fmt.Errorf("failed to create Cassandra session: %w", err) + } + return session, nil +} diff --git a/internal/sources/cassandra/cassandra_test.go b/internal/sources/cassandra/cassandra_test.go new file mode 100644 index 0000000000..de112c9eb6 --- /dev/null +++ b/internal/sources/cassandra/cassandra_test.go @@ -0,0 +1,158 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cassandra_test + +import ( + "testing" + + yaml "github.com/goccy/go-yaml" + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/sources/cassandra" + "github.com/googleapis/genai-toolbox/internal/testutils" +) + +func TestParseFromYamlCassandra(t *testing.T) { + tcs := []struct { + desc string + in string + want server.SourceConfigs + }{ + { + desc: "basic example (without optional fields)", + in: ` + sources: + my-cassandra-instance: + kind: cassandra + hosts: + - "my-host1" + - "my-host2" + `, + want: server.SourceConfigs{ + "my-cassandra-instance": cassandra.Config{ + Name: "my-cassandra-instance", + Kind: cassandra.SourceKind, + Hosts: []string{"my-host1", "my-host2"}, + Username: "", + Password: "", + ProtoVersion: 0, + CAPath: "", + CertPath: "", + KeyPath: "", + Keyspace: "", + EnableHostVerification: false, + }, + }, + }, + { + desc: "with optional fields", + in: ` + sources: + my-cassandra-instance: + kind: cassandra + hosts: + - "my-host1" + - "my-host2" + username: "user" + password: "pass" + keyspace: "example_keyspace" + protoVersion: 4 + caPath: "path/to/ca.crt" + certPath: "path/to/cert" + keyPath: "path/to/key" + enableHostVerification: true + `, + want: server.SourceConfigs{ + "my-cassandra-instance": cassandra.Config{ + Name: "my-cassandra-instance", + Kind: cassandra.SourceKind, + Hosts: []string{"my-host1", "my-host2"}, + Username: "user", + Password: "pass", + Keyspace: "example_keyspace", + ProtoVersion: 4, + CAPath: "path/to/ca.crt", + CertPath: "path/to/cert", + KeyPath: "path/to/key", + EnableHostVerification: true, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Sources server.SourceConfigs `yaml:"sources"` + }{} + // Parse contents + err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if !cmp.Equal(tc.want, got.Sources) { + t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources) + } + }) + } + +} + +func TestFailParseFromYaml(t *testing.T) { + tcs := []struct { + desc string + in string + err string + }{ + { + desc: "extra field", + in: ` + sources: + my-cassandra-instance: + kind: cassandra + host: + - "my-host" + foo: bar + `, + err: "unable to parse source \"my-cassandra-instance\" as \"cassandra\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | host:\n 3 | - my-host\n 4 | kind: cassandra", + }, + { + desc: "missing required field", + in: ` + sources: + my-cassandra-instance: + kind: cassandra + `, + err: "unable to parse source \"my-cassandra-instance\" as \"cassandra\": Key: 'Config.Hosts' Error:Field validation for 'Hosts' failed on the 'required' tag", + }, + } + + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Sources server.SourceConfigs `yaml:"sources"` + }{} + // Parse contents + err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got) + if err == nil { + t.Fatalf("expect parsing to fail") + } + errStr := err.Error() + if errStr != tc.err { + t.Fatalf("unexpected error: got %q, want %q", errStr, tc.err) + } + }) + } + +} diff --git a/internal/tools/cassandra/cassandracql/cassandracql.go b/internal/tools/cassandra/cassandracql/cassandracql.go new file mode 100644 index 0000000000..7df80e0ccf --- /dev/null +++ b/internal/tools/cassandra/cassandracql/cassandracql.go @@ -0,0 +1,182 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cassandracql + +import ( + "context" + "fmt" + + yaml "github.com/goccy/go-yaml" + "github.com/gocql/gocql" + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/sources/cassandra" + "github.com/googleapis/genai-toolbox/internal/tools" +) + +const kind string = "cassandra-cql" + +func init() { + if !tools.Register(kind, newConfig) { + panic(fmt.Sprintf("tool kind %q already registered", kind)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { + actual := Config{Name: name} + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + return actual, nil +} + +type compatibleSource interface { + CassandraSession() *gocql.Session +} + +var _ compatibleSource = &cassandra.Source{} + +var compatibleSources = [...]string{cassandra.SourceKind} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description" validate:"required"` + Statement string `yaml:"statement" validate:"required"` + AuthRequired []string `yaml:"authRequired"` + Parameters tools.Parameters `yaml:"parameters"` + TemplateParameters tools.Parameters `yaml:"templateParameters"` +} + +// Initialize implements tools.ToolConfig. +func (c Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { + // verify source exists + rawS, ok := srcs[c.Source] + if !ok { + return nil, fmt.Errorf("no source named %q configured", c.Source) + } + + // verify the source is compatible + s, ok := rawS.(compatibleSource) + if !ok { + return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + } + + allParameters, paramManifest, paramMcpManifest, err := tools.ProcessParameters(c.TemplateParameters, c.Parameters) + if err != nil { + return nil, err + } + + mcpManifest := tools.McpManifest{ + Name: c.Name, + Description: c.Description, + InputSchema: paramMcpManifest, + } + + t := Tool{ + Name: c.Name, + Kind: kind, + Parameters: c.Parameters, + TemplateParameters: c.TemplateParameters, + AllParams: allParameters, + Statement: c.Statement, + AuthRequired: c.AuthRequired, + Session: s.CassandraSession(), + manifest: tools.Manifest{Description: c.Description, Parameters: paramManifest, AuthRequired: c.AuthRequired}, + mcpManifest: mcpManifest, + } + return t, nil +} + +// ToolConfigKind implements tools.ToolConfig. +func (c Config) ToolConfigKind() string { + return kind +} + +var _ tools.ToolConfig = Config{} + +type Tool struct { + Name string `yaml:"name"` + Kind string `yaml:"kind"` + AuthRequired []string `yaml:"authRequired"` + Parameters tools.Parameters `yaml:"parameters"` + TemplateParameters tools.Parameters `yaml:"templateParameters"` + AllParams tools.Parameters `yaml:"allParams"` + + Session *gocql.Session + Statement string + manifest tools.Manifest + mcpManifest tools.McpManifest +} + +// RequiresClientAuthorization implements tools.Tool. +func (t Tool) RequiresClientAuthorization() bool { + return false +} + +// Authorized implements tools.Tool. +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +} + +// Invoke implements tools.Tool. +func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) { + paramsMap := params.AsMap() + newStatement, err := tools.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) + if err != nil { + return nil, fmt.Errorf("unable to extract template params %w", err) + } + + newParams, err := tools.GetParams(t.Parameters, paramsMap) + if err != nil { + return nil, fmt.Errorf("unable to extract standard params %w", err) + } + sliceParams := newParams.AsSlice() + iter := t.Session.Query(newStatement, sliceParams...).WithContext(ctx).Iter() + + // Create a slice to store the out + var out []map[string]interface{} + + // Scan results into a map and append to the slice + for { + row := make(map[string]interface{}) // Create a new map for each row + if !iter.MapScan(row) { + break // No more rows + } + out = append(out, row) + } + + if err := iter.Close(); err != nil { + return nil, fmt.Errorf("unable to parse rows: %w", err) + } + return out, nil +} + +// Manifest implements tools.Tool. +func (t Tool) Manifest() tools.Manifest { + return t.manifest +} + +// McpManifest implements tools.Tool. +func (t Tool) McpManifest() tools.McpManifest { + return t.mcpManifest +} + +// ParseParams implements tools.Tool. +func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) { + return tools.ParseParams(t.AllParams, data, claims) +} + +var _ tools.Tool = Tool{} diff --git a/internal/tools/cassandra/cassandracql/cassandracql_test.go b/internal/tools/cassandra/cassandracql/cassandracql_test.go new file mode 100644 index 0000000000..cf9a729f83 --- /dev/null +++ b/internal/tools/cassandra/cassandracql/cassandracql_test.go @@ -0,0 +1,171 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cassandracql_test + +import ( + "testing" + + yaml "github.com/goccy/go-yaml" + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/tools/cassandra/cassandracql" +) + +func TestParseFromYamlCassandra(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example", + in: ` + tools: + example_tool: + kind: cassandra-cql + source: my-cassandra-instance + description: some description + statement: | + SELECT * FROM CQL_STATEMENT; + authRequired: + - my-google-auth-service + - other-auth-service + parameters: + - name: country + type: string + description: some description + authServices: + - name: my-google-auth-service + field: user_id + - name: other-auth-service + field: user_id + `, + want: server.ToolConfigs{ + "example_tool": cassandracql.Config{ + Name: "example_tool", + Kind: "cassandra-cql", + Source: "my-cassandra-instance", + Description: "some description", + Statement: "SELECT * FROM CQL_STATEMENT;\n", + AuthRequired: []string{"my-google-auth-service", "other-auth-service"}, + Parameters: []tools.Parameter{ + tools.NewStringParameterWithAuth("country", "some description", + []tools.ParamAuthService{{Name: "my-google-auth-service", Field: "user_id"}, + {Name: "other-auth-service", Field: "user_id"}}), + }, + }, + }, + }, + { + desc: "with template parameters", + in: ` + tools: + example_tool: + kind: cassandra-cql + source: my-cassandra-instance + description: some description + statement: | + SELECT * FROM CQL_STATEMENT; + authRequired: + - my-google-auth-service + - other-auth-service + parameters: + - name: country + type: string + description: some description + authServices: + - name: my-google-auth-service + field: user_id + - name: other-auth-service + field: user_id + templateParameters: + - name: tableName + type: string + description: some description. + - name: fieldArray + type: array + description: The columns to return for the query. + items: + name: column + type: string + description: A column name that will be returned from the query. + `, + want: server.ToolConfigs{ + "example_tool": cassandracql.Config{ + Name: "example_tool", + Kind: "cassandra-cql", + Source: "my-cassandra-instance", + Description: "some description", + Statement: "SELECT * FROM CQL_STATEMENT;\n", + AuthRequired: []string{"my-google-auth-service", "other-auth-service"}, + Parameters: []tools.Parameter{ + tools.NewStringParameterWithAuth("country", "some description", + []tools.ParamAuthService{{Name: "my-google-auth-service", Field: "user_id"}, + {Name: "other-auth-service", Field: "user_id"}}), + }, + TemplateParameters: []tools.Parameter{ + tools.NewStringParameter("tableName", "some description."), + tools.NewArrayParameter("fieldArray", "The columns to return for the query.", tools.NewStringParameter("column", "A column name that will be returned from the query.")), + }, + }, + }, + }, + { + desc: "without optional fields", + in: ` + tools: + example_tool: + kind: cassandra-cql + source: my-cassandra-instance + description: some description + statement: | + SELECT * FROM CQL_STATEMENT; + `, + want: server.ToolConfigs{ + "example_tool": cassandracql.Config{ + Name: "example_tool", + Kind: "cassandra-cql", + Source: "my-cassandra-instance", + Description: "some description", + Statement: "SELECT * FROM CQL_STATEMENT;\n", + AuthRequired: []string{}, + Parameters: nil, + TemplateParameters: nil, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Tools server.ToolConfigs `yaml:"tools"` + }{} + // Parse contents + err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if diff := cmp.Diff(tc.want, got.Tools); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } +} diff --git a/tests/cassandra/cassandra_integration_test.go b/tests/cassandra/cassandra_integration_test.go new file mode 100644 index 0000000000..85eee1f89d --- /dev/null +++ b/tests/cassandra/cassandra_integration_test.go @@ -0,0 +1,284 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cassandra + +import ( + "context" + "fmt" + "log" + "os" + "regexp" + "strings" + "testing" + "time" + + "github.com/gocql/gocql" + "github.com/google/uuid" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/tests" +) + +var ( + CassandraSourceKind = "cassandra" + CassandraToolKind = "cassandra-cql" + Hosts = os.Getenv("CASSANDRA_HOST") + Keyspace = "example_keyspace" + Username = os.Getenv("CASSANDRA_USER") + Password = os.Getenv("CASSANDRA_PASS") +) + +func getCassandraVars(t *testing.T) map[string]any { + switch "" { + case Hosts: + t.Fatal("'Hosts' not set") + case Username: + t.Fatal("'Username' not set") + case Password: + t.Fatal("'Password' not set") + } + return map[string]any{ + "kind": CassandraSourceKind, + "hosts": strings.Split(Hosts, ","), + "keyspace": Keyspace, + "username": Username, + "password": Password, + } +} + +func initCassandraSession() (*gocql.Session, error) { + hostStrings := strings.Split(Hosts, ",") + + var hosts []string + for _, h := range hostStrings { + trimmedHost := strings.TrimSpace(h) + if trimmedHost != "" { + hosts = append(hosts, trimmedHost) + } + } + if len(hosts) == 0 { + return nil, fmt.Errorf("no valid hosts found in CASSANDRA_HOSTS env var") + } + // Configure cluster connection + cluster := gocql.NewCluster(hosts...) + cluster.Consistency = gocql.Quorum + cluster.ProtoVersion = 4 + cluster.DisableInitialHostLookup = true + cluster.ConnectTimeout = 10 * time.Second + cluster.NumConns = 2 + cluster.Authenticator = gocql.PasswordAuthenticator{ + Username: Username, + Password: Password, + } + cluster.RetryPolicy = &gocql.ExponentialBackoffRetryPolicy{ + NumRetries: 3, + Min: 200 * time.Millisecond, + Max: 2 * time.Second, + } + + // Create session + session, err := cluster.CreateSession() + if err != nil { + return nil, fmt.Errorf("Failed to create session: %v", err) + } + + // Create keyspace + err = session.Query(fmt.Sprintf(` + CREATE KEYSPACE IF NOT EXISTS %s + WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1} + `, Keyspace)).Exec() + if err != nil { + return nil, fmt.Errorf("Failed to create keyspace: %v", err) + } + + return session, nil +} + +func initTable(tableName string, session *gocql.Session) error { + + // Create table with additional columns + err := session.Query(fmt.Sprintf(` + CREATE TABLE IF NOT EXISTS %s.%s ( + id int PRIMARY KEY, + name text, + email text, + age int, + is_active boolean, + created_at timestamp + ) + `, Keyspace, tableName)).Exec() + if err != nil { + return fmt.Errorf("Failed to create table: %v", err) + } + + // Use fixed timestamps for reproducibility + fixedTime, _ := time.Parse(time.RFC3339, "2025-07-25T12:00:00Z") + dayAgo := fixedTime.Add(-24 * time.Hour) + twelveHoursAgo := fixedTime.Add(-12 * time.Hour) + + // Insert minimal diverse data with fixed time.Time for timestamps + err = session.Query(fmt.Sprintf(` + INSERT INTO %s.%s (id, name,email, age, is_active, created_at) + VALUES (?, ?, ?, ?, ?, ?)`, Keyspace, tableName), + 3, "Alice", tests.ServiceAccountEmail, 25, true, dayAgo, + ).Exec() + if err != nil { + return fmt.Errorf("Failed to insert user: %v", err) + } + err = session.Query(fmt.Sprintf(` + INSERT INTO %s.%s (id, name,email, age, is_active, created_at) + VALUES (?, ?, ?, ?, ?, ?)`, Keyspace, tableName), + 2, "Alex", "janedoe@gmail.com", 30, false, twelveHoursAgo, + ).Exec() + if err != nil { + return fmt.Errorf("Failed to insert user: %v", err) + } + err = session.Query(fmt.Sprintf(` + INSERT INTO %s.%s (id, name,email, age, is_active, created_at) + VALUES (?, ?, ?, ?, ?, ?)`, Keyspace, tableName), + 1, "Sid", "sid@gmail.com", 10, true, fixedTime, + ).Exec() + if err != nil { + return fmt.Errorf("Failed to insert user: %v", err) + } + err = session.Query(fmt.Sprintf(` + INSERT INTO %s.%s (id, name,email, age, is_active, created_at) + VALUES (?, ?, ?, ?, ?, ?)`, Keyspace, tableName), + 4, nil, "a@gmail.com", 40, false, fixedTime, + ).Exec() + if err != nil { + return fmt.Errorf("Failed to insert user: %v", err) + } + return nil +} + +func dropTable(session *gocql.Session, tableName string) { + err := session.Query(fmt.Sprintf("drop table %s.%s", Keyspace, tableName)).Exec() + if err != nil { + log.Printf("Failed to drop table %s: %v", tableName, err) + } +} + +func TestCassandra(t *testing.T) { + session, err := initCassandraSession() + if err != nil { + t.Fatal(err) + } + defer session.Close() + sourceConfig := getCassandraVars(t) + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + var args []string + paramTableName := "param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "") + tableNameAuth := "auth_table_" + strings.ReplaceAll(uuid.New().String(), "-", "") + tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "") + err = initTable(paramTableName, session) + if err != nil { + t.Fatal(err) + } + defer dropTable(session, paramTableName) + + err = initTable(tableNameAuth, session) + if err != nil { + t.Fatal(err) + } + defer dropTable(session, tableNameAuth) + + err = initTable(tableNameTemplateParam, session) + if err != nil { + t.Fatal(err) + } + defer dropTable(session, tableNameTemplateParam) + + paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt := createParamToolInfo(paramTableName) + _, _, authToolStmt := getCassandraAuthToolInfo(tableNameAuth) + toolsFile := tests.GetToolsConfig(sourceConfig, CassandraToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt) + + tmplSelectCombined, tmplSelectFilterCombined := getCassandraTmplToolInfo() + tmpSelectAll := "SELECT * FROM {{.tableName}} where id = 1" + + toolsFile = tests.AddTemplateParamConfig(t, toolsFile, CassandraToolKind, tmplSelectCombined, tmplSelectFilterCombined, tmpSelectAll) + + cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...) + if err != nil { + t.Fatalf("command initialization returned an error: %s", err) + } + defer cleanup() + + waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out) + if err != nil { + t.Logf("toolbox command logs: \n%s", out) + t.Fatalf("toolbox didn't start successfully: %s", err) + } + selectIdNameWant, selectIdNullWant, selectArrayParamWant, mcpMyFailToolWant, mcpSelect1Want, mcpMyToolIdWant := getCassandraWants() + selectAllWant, selectIdWant, selectNameWant := getCassandraTmplWants() + + tests.RunToolGetTest(t) + tests.RunToolInvokeTest(t, "", tests.DisableSelect1Test(), + tests.DisableOptionalNullParamTest(), + tests.WithMyToolId3NameAliceWant(selectIdNameWant), + tests.WithMyToolById4Want(selectIdNullWant), + tests.WithMyArrayToolWant(selectArrayParamWant), + tests.DisableSelect1AuthTest()) + tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, + tests.DisableSelectFilterTest(), + tests.WithSelectAllWant(selectAllWant), + tests.DisableDdlTest(), tests.DisableInsertTest(), tests.WithTmplSelectId1Want(selectIdWant), tests.WithTmplSelectNameWant(selectNameWant)) + + tests.RunMCPToolCallMethod(t, mcpMyFailToolWant, mcpSelect1Want, + tests.WithMcpMyToolId3NameAliceWant(mcpMyToolIdWant), + tests.DisableMcpSelect1AuthTest()) + +} + +func createParamToolInfo(tableName string) (string, string, string, string) { + toolStatement := fmt.Sprintf("SELECT id, name FROM %s WHERE id = ? AND name = ? ALLOW FILTERING;", tableName) + idParamStatement := fmt.Sprintf("SELECT id,name FROM %s WHERE id = ?;", tableName) + nameParamStatement := fmt.Sprintf("SELECT id, name FROM %s WHERE name = ? ALLOW FILTERING;", tableName) + arrayToolStatement := fmt.Sprintf("SELECT id, name FROM %s WHERE id IN ? AND name IN ? ALLOW FILTERING;", tableName) + return toolStatement, idParamStatement, nameParamStatement, arrayToolStatement + +} + +func getCassandraAuthToolInfo(tableName string) (string, string, string) { + createStatement := fmt.Sprintf("CREATE TABLE %s (id UUID PRIMARY KEY, name TEXT, email TEXT);", tableName) + insertStatement := fmt.Sprintf("INSERT INTO %s (id, name, email) VALUES (uuid(), ?, ?), (uuid(), ?, ?);", tableName) + toolStatement := fmt.Sprintf("SELECT name FROM %s WHERE email = ? ALLOW FILTERING;", tableName) + return createStatement, insertStatement, toolStatement +} + +func getCassandraTmplToolInfo() (string, string) { + selectAllTemplateStmt := "SELECT age, id, name FROM {{.tableName}} where id = ?;" + selectByIdTemplateStmt := "SELECT id, name FROM {{.tableName}} WHERE name = ? ALLOW FILTERING;" + return selectAllTemplateStmt, selectByIdTemplateStmt +} + +func getCassandraWants() (string, string, string, string, string, string) { + selectIdNameWant := "[{\"id\":3,\"name\":\"Alice\"}]" + selectIdNullWant := "[{\"id\":4,\"name\":\"\"}]" + selectArrayParamWant := "[{\"id\":1,\"name\":\"Sid\"},{\"id\":3,\"name\":\"Alice\"}]" + mcpMyFailToolWant := "{\"jsonrpc\":\"2.0\",\"id\":\"invoke-fail-tool\",\"result\":{\"content\":[{\"type\":\"text\",\"text\":\"unable to parse rows: line 1:0 no viable alternative at input 'SELEC' ([SELEC]...)\"}],\"isError\":true}}" + mcpMyToolIdWant := "{\"jsonrpc\":\"2.0\",\"id\":\"my-tool\",\"result\":{\"content\":[{\"type\":\"text\",\"text\":\"[{\\\"id\\\":3,\\\"name\\\":\\\"Alice\\\"}]\"}]}}" + return selectIdNameWant, selectIdNullWant, selectArrayParamWant, mcpMyFailToolWant, "nil", mcpMyToolIdWant +} + +func getCassandraTmplWants() (string, string, string) { + selectAllWant := "[{\"age\":10,\"created_at\":\"2025-07-25T12:00:00Z\",\"email\":\"sid@gmail.com\",\"id\":1,\"is_active\":true,\"name\":\"Sid\"}]" + selectIdWant := "[{\"age\":10,\"id\":1,\"name\":\"Sid\"}]" + selectNameWant := "[{\"id\":2,\"name\":\"Alex\"}]" + return selectAllWant, selectIdWant, selectNameWant +} diff --git a/tests/mongodb/mongodb_integration_test.go b/tests/mongodb/mongodb_integration_test.go index 84561e514d..16b13c14bf 100644 --- a/tests/mongodb/mongodb_integration_test.go +++ b/tests/mongodb/mongodb_integration_test.go @@ -110,6 +110,7 @@ func TestMongoDBToolEndpoints(t *testing.T) { tests.RunToolGetTest(t) tests.RunToolInvokeTest(t, select1Want, tests.WithMyToolId3NameAliceWant(myToolId3NameAliceWant), + tests.WithMyArrayToolWant(myToolId3NameAliceWant), tests.WithMyToolById4Want(myToolById4Want), ) tests.RunMCPToolCallMethod(t, mcpMyFailToolWant, select1Want, diff --git a/tests/option.go b/tests/option.go index 4aa29bacb5..f6ad8534a6 100644 --- a/tests/option.go +++ b/tests/option.go @@ -21,9 +21,12 @@ type InvokeTestConfig struct { myToolId3NameAliceWant string myToolById4Want string nullWant string + myArrayToolWant string + supportSelect1Want bool supportOptionalNullParam bool supportArrayParam bool supportClientAuth bool + supportSelect1Auth bool } type InvokeTestOption func(*InvokeTestConfig) @@ -36,6 +39,14 @@ func WithMyToolId3NameAliceWant(s string) InvokeTestOption { } } +// WithMyArrayToolWant represents the response value for my-array-tool. +// e.g. tests.RunToolInvokeTest(t, select1Want, tests.WithMyArrayToolWant("custom")) +func WithMyArrayToolWant(s string) InvokeTestOption { + return func(c *InvokeTestConfig) { + c.myArrayToolWant = s + } +} + // WithMyToolById4Want represents the response value for my-tool-by-id with id=4. // This response includes a null value column. // e.g. tests.RunToolInvokeTest(t, select1Want, tests.WithMyToolById4Want("custom")) @@ -69,6 +80,22 @@ func DisableArrayTest() InvokeTestOption { } } +// DisableSelect1Test disables tests for sources that do not support SELECT 1 query. +// e.g. tests.RunToolInvokeTest(t, "", tests.DisableSelect1Test()) +func DisableSelect1Test() InvokeTestOption { + return func(c *InvokeTestConfig) { + c.supportSelect1Want = false + } +} + +// DisableSelect1AuthTest disables auth tests for sources that do not support SELECT 1 query. +// e.g. tests.RunToolInvokeTest(t, "", tests.DisableSelect1AuthTest()) +func DisableSelect1AuthTest() InvokeTestOption { + return func(c *InvokeTestConfig) { + c.supportSelect1Auth = false + } +} + // EnableClientAuthTest runs the client authorization tests. // Only enable it if your source supports the `useClientOAuth` configuration. // Currently, this should only be used with the BigQuery tests. @@ -84,6 +111,7 @@ func EnableClientAuthTest() InvokeTestOption { type MCPTestConfig struct { myToolId3NameAliceWant string supportClientAuth bool + supportSelect1Auth bool } type McpTestOption func(*MCPTestConfig) @@ -105,6 +133,13 @@ func EnableMcpClientAuthTest() McpTestOption { } } +// DisableMcpSelect1AuthTest disables the auth tool tests which use select 1. +func DisableMcpSelect1AuthTest() McpTestOption { + return func(c *MCPTestConfig) { + c.supportSelect1Auth = false + } +} + /* Configurations for RunExecuteSqlToolInvokeTest() */ // ExecuteSqlTestConfig represents the various configuration options for RunExecuteSqlToolInvokeTest() @@ -129,6 +164,7 @@ type TemplateParameterTestConfig struct { ddlWant string selectAllWant string selectId1Want string + selectNameWant string selectEmptyWant string insert1Want string @@ -136,8 +172,9 @@ type TemplateParameterTestConfig struct { nameColFilter string createColArray string - supportDdl bool - supportInsert bool + supportDdl bool + supportInsert bool + supportSelectFields bool } type TemplateParamOption func(*TemplateParameterTestConfig) @@ -166,6 +203,14 @@ func WithTmplSelectId1Want(s string) TemplateParamOption { } } +// WithTmplSelectNameWant represents the response value of select-filter-templateParams-combined-tool with name. +// e.g. tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tests.WithTmplSelectNameWant("custom")) +func WithTmplSelectNameWant(s string) TemplateParamOption { + return func(c *TemplateParameterTestConfig) { + c.selectNameWant = s + } +} + // WithSelectEmptyWant represents the response value of select-templateParams-combined-tool with no results. // e.g. tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tests.WithSelectEmptyWant("custom")) func WithSelectEmptyWant(s string) TemplateParamOption { @@ -221,3 +266,11 @@ func DisableInsertTest() TemplateParamOption { c.supportInsert = false } } + +// DisableInsertTest disables tests of select-fields-templateParams-tool test. +// e.g. tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tests.DisableSelectFilterTest()) +func DisableSelectFilterTest() TemplateParamOption { + return func(c *TemplateParameterTestConfig) { + c.supportSelectFields = false + } +} diff --git a/tests/redis/redis_test.go b/tests/redis/redis_test.go index 9b8c6f0c78..6ee611da51 100644 --- a/tests/redis/redis_test.go +++ b/tests/redis/redis_test.go @@ -104,6 +104,7 @@ func TestRedisToolEndpoints(t *testing.T) { tests.RunToolGetTest(t) tests.RunToolInvokeTest(t, select1Want, tests.WithMyToolId3NameAliceWant(invokeParamWant), + tests.WithMyArrayToolWant(invokeParamWant), tests.WithMyToolById4Want(invokeIdNullWant), tests.WithNullWant(nullWant), ) diff --git a/tests/spanner/spanner_integration_test.go b/tests/spanner/spanner_integration_test.go index 5944787d83..499d0d7f75 100644 --- a/tests/spanner/spanner_integration_test.go +++ b/tests/spanner/spanner_integration_test.go @@ -164,6 +164,7 @@ func TestSpannerToolEndpoints(t *testing.T) { tests.RunToolGetTest(t) tests.RunToolInvokeTest(t, select1Want, tests.WithMyToolId3NameAliceWant(invokeParamWant), + tests.WithMyArrayToolWant(invokeParamWant), tests.WithMyToolById4Want(toolInvokeMyToolById4Want), ) tests.RunMCPToolCallMethod(t, mcpMyFailToolWant, mcpSelect1Want, tests.WithMcpMyToolId3NameAliceWant(mcpMyToolId3NameAliceWant)) diff --git a/tests/tool.go b/tests/tool.go index e12b2809a1..e54400b035 100644 --- a/tests/tool.go +++ b/tests/tool.go @@ -257,10 +257,13 @@ func RunToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTestOp configs := &InvokeTestConfig{ myToolId3NameAliceWant: "[{\"id\":1,\"name\":\"Alice\"},{\"id\":3,\"name\":\"Sid\"}]", myToolById4Want: "[{\"id\":4,\"name\":null}]", + myArrayToolWant: "[{\"id\":1,\"name\":\"Alice\"},{\"id\":3,\"name\":\"Sid\"}]", nullWant: "null", supportOptionalNullParam: true, supportArrayParam: true, supportClientAuth: false, + supportSelect1Want: true, + supportSelect1Auth: true, } // Apply provided options @@ -294,7 +297,7 @@ func RunToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTestOp { name: "invoke my-simple-tool", api: "http://127.0.0.1:5000/api/tool/my-simple-tool/invoke", - enabled: true, + enabled: configs.supportSelect1Want, requestHeader: map[string]string{}, requestBody: bytes.NewBuffer([]byte(`{}`)), wantBody: select1Want, @@ -351,13 +354,13 @@ func RunToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTestOp enabled: configs.supportArrayParam, requestHeader: map[string]string{}, requestBody: bytes.NewBuffer([]byte(`{"idArray": [1,2,3], "nameArray": ["Alice", "Sid", "RandomName"], "cmdArray": ["HGETALL", "row3"]}`)), - wantBody: configs.myToolId3NameAliceWant, + wantBody: configs.myArrayToolWant, wantStatusCode: http.StatusOK, }, { name: "Invoke my-auth-tool with auth token", api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke", - enabled: true, + enabled: configs.supportSelect1Auth, requestHeader: map[string]string{"my-google-auth_token": idToken}, requestBody: bytes.NewBuffer([]byte(`{}`)), wantBody: "[{\"name\":\"Alice\"}]", @@ -366,7 +369,7 @@ func RunToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTestOp { name: "Invoke my-auth-tool with invalid auth token", api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke", - enabled: true, + enabled: configs.supportSelect1Auth, requestHeader: map[string]string{"my-google-auth_token": "INVALID_TOKEN"}, requestBody: bytes.NewBuffer([]byte(`{}`)), wantStatusCode: http.StatusUnauthorized, @@ -382,7 +385,7 @@ func RunToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTestOp { name: "Invoke my-auth-required-tool with auth token", api: "http://127.0.0.1:5000/api/tool/my-auth-required-tool/invoke", - enabled: true, + enabled: configs.supportSelect1Auth, requestHeader: map[string]string{"my-google-auth_token": idToken}, requestBody: bytes.NewBuffer([]byte(`{}`)), @@ -491,6 +494,7 @@ func RunToolInvokeWithTemplateParameters(t *testing.T, tableName string, options ddlWant: "null", selectAllWant: "[{\"age\":21,\"id\":1,\"name\":\"Alex\"},{\"age\":100,\"id\":2,\"name\":\"Alice\"}]", selectId1Want: "[{\"age\":21,\"id\":1,\"name\":\"Alex\"}]", + selectNameWant: "[{\"age\":21,\"id\":1,\"name\":\"Alex\"}]", selectEmptyWant: "null", insert1Want: "null", @@ -512,6 +516,7 @@ func RunToolInvokeWithTemplateParameters(t *testing.T, tableName string, options // Test tool invoke endpoint invokeTcs := []struct { name string + enabled bool ddl bool insert bool api string @@ -573,6 +578,7 @@ func RunToolInvokeWithTemplateParameters(t *testing.T, tableName string, options }, { name: "invoke select-fields-templateParams-tool", + enabled: configs.supportSelectFields, api: "http://127.0.0.1:5000/api/tool/select-fields-templateParams-tool/invoke", requestHeader: map[string]string{}, requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"tableName": "%s", "fields":%s}`, tableName, configs.nameFieldArray))), @@ -584,7 +590,7 @@ func RunToolInvokeWithTemplateParameters(t *testing.T, tableName string, options api: "http://127.0.0.1:5000/api/tool/select-filter-templateParams-combined-tool/invoke", requestHeader: map[string]string{}, requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"name": "Alex", "tableName": "%s", "columnFilter": "%s"}`, tableName, configs.nameColFilter))), - want: configs.selectId1Want, + want: configs.selectNameWant, isErr: false, }, { @@ -599,6 +605,9 @@ func RunToolInvokeWithTemplateParameters(t *testing.T, tableName string, options } for _, tc := range invokeTcs { t.Run(tc.name, func(t *testing.T) { + if !tc.enabled { + return + } // if test case is DDL and source support ddl test cases ddlAllow := !tc.ddl || (tc.ddl && configs.supportDdl) // if test case is insert statement and source support insert test cases @@ -834,6 +843,7 @@ func RunMCPToolCallMethod(t *testing.T, myFailToolWant, select1Want string, opti configs := &MCPTestConfig{ myToolId3NameAliceWant: `{"jsonrpc":"2.0","id":"my-tool","result":{"content":[{"type":"text","text":"{\"id\":1,\"name\":\"Alice\"}"},{"type":"text","text":"{\"id\":3,\"name\":\"Sid\"}"}]}}`, supportClientAuth: false, + supportSelect1Auth: true, } // Apply provided options @@ -947,7 +957,7 @@ func RunMCPToolCallMethod(t *testing.T, myFailToolWant, select1Want string, opti { name: "MCP Invoke my-auth-required-tool", api: "http://127.0.0.1:5000/mcp", - enabled: true, + enabled: configs.supportSelect1Auth, requestHeader: map[string]string{"my-google-auth_token": idToken}, requestBody: jsonrpc.JSONRPCRequest{ Jsonrpc: "2.0", diff --git a/tests/valkey/valkey_test.go b/tests/valkey/valkey_test.go index e3922e14ff..23f2b09df2 100644 --- a/tests/valkey/valkey_test.go +++ b/tests/valkey/valkey_test.go @@ -107,6 +107,7 @@ func TestValkeyToolEndpoints(t *testing.T) { tests.RunToolGetTest(t) tests.RunToolInvokeTest(t, select1Want, tests.WithMyToolId3NameAliceWant(invokeParamWant), + tests.WithMyArrayToolWant(invokeParamWant), tests.WithMyToolById4Want(invokeIdNullWant), tests.WithNullWant(nullWant), )