mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-09 07:28:05 -05:00
feat(cassandra): add Cassandra Source and Tool (#1012)
[Cassandra](https://cassandra.apache.org/_/cassandra-basics.html) is a NoSQL distributed database. By design, NoSQL databases are lightweight, open-source, non-relational, and largely distributed. Counted among their strengths are horizontal scalability, distributed architectures, and a flexible approach to schema definition. Cassandra go driver link - https://pkg.go.dev/github.com/apache/cassandra-gocql-driver/v2 This PR - adds a new source for cassandra - adds a new tool _cassandra-cql_ with support for executing predefined parameterized CQL queries on cassandra - adds unit and integration tests for the tool and the source - adds documentation for the cassandra source and cassandra-cql tool --------- Co-authored-by: Wenxin Du <117315983+duwenxin99@users.noreply.github.com> Co-authored-by: duwenxin <duwenxin@google.com>
This commit is contained in:
@@ -662,6 +662,26 @@ steps:
|
||||
- |
|
||||
./yugabytedb.test -test.v
|
||||
|
||||
|
||||
- id: "cassandra"
|
||||
name: golang:1
|
||||
waitFor: ["compile-test-binary"]
|
||||
entrypoint: /bin/bash
|
||||
env:
|
||||
- "GOPATH=/gopath"
|
||||
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
|
||||
secretEnv: ["CLIENT_ID", "CASSANDRA_USER", "CASSANDRA_PASS", "CASSANDRA_HOST"]
|
||||
volumes:
|
||||
- name: "go"
|
||||
path: "/gopath"
|
||||
args:
|
||||
- -c
|
||||
- |
|
||||
.ci/test_with_coverage.sh \
|
||||
"Cassandra" \
|
||||
cassandra \
|
||||
cassandra
|
||||
|
||||
availableSecrets:
|
||||
secretManager:
|
||||
- versionName: projects/$PROJECT_ID/secrets/cloud_sql_pg_user/versions/latest
|
||||
@@ -746,6 +766,12 @@ availableSecrets:
|
||||
env: YUGABYTEDB_USER
|
||||
- versionName: projects/$PROJECT_ID/secrets/yugabytedb_pass/versions/latest
|
||||
env: YUGABYTEDB_PASS
|
||||
- versionName: projects/$PROJECT_ID/secrets/cassandra_user/versions/latest
|
||||
env: CASSANDRA_USER
|
||||
- versionName: projects/$PROJECT_ID/secrets/cassandra_pass/versions/latest
|
||||
env: CASSANDRA_PASS
|
||||
- versionName: projects/$PROJECT_ID/secrets/cassandra_host/versions/latest
|
||||
env: CASSANDRA_HOST
|
||||
|
||||
options:
|
||||
logging: CLOUD_LOGGING_ONLY
|
||||
|
||||
@@ -64,6 +64,7 @@ import (
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerysearchcatalog"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerysql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/bigtable"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cassandra/cassandracql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouseexecutesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouselistdatabases"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouselisttables"
|
||||
@@ -159,6 +160,7 @@ import (
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/bigtable"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/cassandra"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/clickhouse"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/cloudmonitoring"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
|
||||
|
||||
57
docs/en/resources/sources/cassandra.md
Normal file
57
docs/en/resources/sources/cassandra.md
Normal file
@@ -0,0 +1,57 @@
|
||||
---
|
||||
title: "Cassandra"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
Cassandra is a NoSQL distributed database known for its horizontal scalability, distributed architecture, and flexible schema definition.
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
[Cassandra][cassandra-docs] is a NoSQL distributed database. By design, NoSQL databases are lightweight, open-source, non-relational, and largely distributed. Counted among their strengths are horizontal scalability, distributed architectures, and a flexible approach to schema definition.
|
||||
|
||||
[cassandra-docs]: https://cassandra.apache.org/
|
||||
|
||||
## Available Tools
|
||||
|
||||
- [`cassandra-cql`](../tools/cassandra/cassandra-cql.md)
|
||||
Run parameterized CQL queries in Cassandra.
|
||||
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
sources:
|
||||
my-cassandra-source:
|
||||
kind: cassandra
|
||||
hosts:
|
||||
- 127.0.0.1
|
||||
keyspace: my_keyspace
|
||||
protoVersion: 4
|
||||
username: ${USER_NAME}
|
||||
password: ${PASSWORD}
|
||||
caPath: /path/to/ca.crt # Optional: path to CA certificate
|
||||
certPath: /path/to/client.crt # Optional: path to client certificate
|
||||
keyPath: /path/to/client.key # Optional: path to client key
|
||||
enableHostVerification: true # Optional: enable host verification
|
||||
```
|
||||
|
||||
{{< notice tip >}}
|
||||
Use environment variable replacement with the format ${ENV_NAME}
|
||||
instead of hardcoding your secrets into the configuration file.
|
||||
{{< /notice >}}
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|------------------------|:---------:|:------------:|-------------------------------------------------------------------------------------------------------|
|
||||
| kind | string | true | Must be "cassandra". |
|
||||
| hosts | string[] | true | List of IP addresses to connect to (e.g., ["192.168.1.1:9042", "192.168.1.2:9042","192.168.1.3:9042"]). The default port is 9042 if not specified. |
|
||||
| keyspace | string | true | Name of the Cassandra keyspace to connect to (e.g., "my_keyspace"). |
|
||||
| protoVersion | integer | false | Protocol version for the Cassandra connection (e.g., 4). |
|
||||
| username | string | false | Name of the Cassandra user to connect as (e.g., "my-cassandra-user"). |
|
||||
| password | string | false | Password of the Cassandra user (e.g., "my-password"). |
|
||||
| caPath | string | false | Path to the CA certificate for SSL/TLS (e.g., "/path/to/ca.crt"). |
|
||||
| certPath | string | false | Path to the client certificate for SSL/TLS (e.g., "/path/to/client.crt"). |
|
||||
| keyPath | string | false | Path to the client key for SSL/TLS (e.g., "/path/to/client.key"). |
|
||||
| enableHostVerification | boolean | false | Enable host verification for SSL/TLS (e.g., true). By default, host verification is disabled. |
|
||||
7
docs/en/resources/tools/cassandra/_index.md
Normal file
7
docs/en/resources/tools/cassandra/_index.md
Normal file
@@ -0,0 +1,7 @@
|
||||
---
|
||||
title: "Cassandra"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
Tools that work with Cassandra Sources.
|
||||
---
|
||||
96
docs/en/resources/tools/cassandra/cassandra-cql.md
Normal file
96
docs/en/resources/tools/cassandra/cassandra-cql.md
Normal file
@@ -0,0 +1,96 @@
|
||||
---
|
||||
title: "cassandra-cql"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
A "cassandra-cql" tool executes a pre-defined CQL statement against a Cassandra
|
||||
database.
|
||||
aliases:
|
||||
- /resources/tools/cassandra-cql
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
A `cassandra-cql` tool executes a pre-defined CQL statement against a Cassandra
|
||||
database. It's compatible with any of the following sources:
|
||||
|
||||
- [cassandra](../sources/cassandra.md)
|
||||
|
||||
The specified CQL statement is executed as a [prepared statement][cassandra-prepare],
|
||||
and expects parameters in the CQL query to be in the form of placeholders `?`.
|
||||
|
||||
[cassandra-prepare]: https://docs.datastax.com/en/developer/go-driver/4.8/cql-prepared-statements/
|
||||
|
||||
## Example
|
||||
|
||||
> **Note:** This tool uses parameterized queries to prevent CQL injections.
|
||||
> Query parameters can be used as substitutes for arbitrary expressions.
|
||||
> Parameters cannot be used as substitutes for keyspaces, table names, column names,
|
||||
> or other parts of the query.
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
search_users_by_email:
|
||||
kind: cassandra-cql
|
||||
source: my-cassandra-cluster
|
||||
statement: |
|
||||
SELECT user_id, email, first_name, last_name, created_at
|
||||
FROM users
|
||||
WHERE email = ?
|
||||
description: |
|
||||
Use this tool to retrieve specific user information by their email address.
|
||||
Takes an email address and returns user details including user ID, email,
|
||||
first name, last name, and account creation timestamp.
|
||||
Do NOT use this tool with a user ID or other identifiers.
|
||||
Example:
|
||||
{{
|
||||
"email": "user@example.com",
|
||||
}}
|
||||
parameters:
|
||||
- name: email
|
||||
type: string
|
||||
description: User's email address
|
||||
```
|
||||
|
||||
### Example with Template Parameters
|
||||
|
||||
> **Note:** This tool allows direct modifications to the CQL statement,
|
||||
> including keyspaces, table names, and column names. **This makes it more
|
||||
> vulnerable to CQL injections**. Using basic parameters only (see above) is
|
||||
> recommended for performance and safety reasons. For more details, please check
|
||||
> [templateParameters](../#template-parameters).
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
list_keyspace_table:
|
||||
kind: cassandra-cql
|
||||
source: my-cassandra-cluster
|
||||
statement: |
|
||||
SELECT * FROM {{.keyspace}}.{{.tableName}};
|
||||
description: |
|
||||
Use this tool to list all information from a specific table in a keyspace.
|
||||
Example:
|
||||
{{
|
||||
"keyspace": "my_keyspace",
|
||||
"tableName": "users",
|
||||
}}
|
||||
templateParameters:
|
||||
- name: keyspace
|
||||
type: string
|
||||
description: Keyspace containing the table
|
||||
- name: tableName
|
||||
type: string
|
||||
description: Table to select from
|
||||
```
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|--------------------|:------------------------------------------------:|:------------:|--------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| kind | string | true | Must be "cassandra-cql". |
|
||||
| source | string | true | Name of the source the CQL should execute on. |
|
||||
| description | string | true | Description of the tool that is passed to the LLM. |
|
||||
| statement | string | true | CQL statement to execute. |
|
||||
| authRequired | []string | false | List of authentication requirements for the source. |
|
||||
| parameters | [parameters](../#specifying-parameters) | false | List of [parameters](../#specifying-parameters) that will be inserted into the CQL statement. |
|
||||
| templateParameters | [templateParameters](../#template-parameters) | false | List of [templateParameters](../#template-parameters) that will be inserted into the CQL statement before executing prepared statement. |
|
||||
3
go.mod
3
go.mod
@@ -26,6 +26,7 @@ require (
|
||||
github.com/go-playground/validator/v10 v10.27.0
|
||||
github.com/go-sql-driver/mysql v1.9.3
|
||||
github.com/goccy/go-yaml v1.18.0
|
||||
github.com/gocql/gocql v1.7.0
|
||||
github.com/google/go-cmp v0.7.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/jackc/pgx/v5 v5.7.6
|
||||
@@ -115,6 +116,7 @@ require (
|
||||
github.com/gorilla/websocket v1.5.3 // indirect
|
||||
github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.1 // indirect
|
||||
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed // indirect
|
||||
github.com/hashicorp/go-uuid v1.0.3 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
@@ -176,6 +178,7 @@ require (
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250818200422-3122310a409c // indirect
|
||||
google.golang.org/grpc v1.75.0 // indirect
|
||||
google.golang.org/protobuf v1.36.8 // indirect
|
||||
gopkg.in/inf.v0 v0.9.1 // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
modernc.org/libc v1.66.3 // indirect
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
|
||||
10
go.sum
10
go.sum
@@ -737,6 +737,10 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.38.4/go.mod h1:Z+Gd23v97pX9zK97+tX4p
|
||||
github.com/aws/smithy-go v1.23.0 h1:8n6I3gXzWJB2DxBDnfxgBaSX6oe0d/t10qGz7OKqMCE=
|
||||
github.com/aws/smithy-go v1.23.0/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI=
|
||||
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
|
||||
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 h1:mXoPYz/Ul5HYEDvkta6I8/rnYM5gSdSV2tJ6XbZuEtY=
|
||||
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCSz6Q9T7+igc/hlvDOUdtWKryOrtFyIVABv/p7k=
|
||||
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY=
|
||||
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4=
|
||||
github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
|
||||
github.com/boombuler/barcode v1.0.1/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
||||
@@ -897,6 +901,8 @@ github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
|
||||
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||
github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw=
|
||||
github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
|
||||
github.com/gocql/gocql v1.7.0 h1:O+7U7/1gSN7QTEAaMEsJc1Oq2QHXvCWoF3DFK9HDHus=
|
||||
github.com/gocql/gocql v1.7.0/go.mod h1:vnlvXyFZeLBF0Wy+RS8hrOdbn0UWsWtdg07XJnFxZ+4=
|
||||
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||
@@ -1047,6 +1053,8 @@ github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.1/go.mod h1:Zanoh4+gvIgluNqcfMVT
|
||||
github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
|
||||
github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8=
|
||||
github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
|
||||
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed h1:5upAirOpQc1Q53c0bnx2ufif5kANL7bfZWcc6VJWJd8=
|
||||
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed/go.mod h1:tMWxXQ9wFIaZeTI9F+hmhFiGpFmhOHzyShyFUhRm0H4=
|
||||
github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
|
||||
github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
|
||||
github.com/iancoleman/strcase v0.2.0/go.mod h1:iwCmte+B7n89clKwxIoIXy/HfoL7AsD47ZCWhYzw7ho=
|
||||
@@ -2042,6 +2050,8 @@ gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
|
||||
gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc=
|
||||
gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw=
|
||||
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
|
||||
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
|
||||
134
internal/sources/cassandra/cassandra.go
Normal file
134
internal/sources/cassandra/cassandra.go
Normal file
@@ -0,0 +1,134 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cassandra
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/gocql/gocql"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceKind string = "cassandra"
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources.SourceConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Hosts []string `yaml:"hosts" validate:"required"`
|
||||
Keyspace string `yaml:"keyspace"`
|
||||
ProtoVersion int `yaml:"protoVersion"`
|
||||
Username string `yaml:"username"`
|
||||
Password string `yaml:"password"`
|
||||
CAPath string `yaml:"caPath"`
|
||||
CertPath string `yaml:"certPath"`
|
||||
KeyPath string `yaml:"keyPath"`
|
||||
EnableHostVerification bool `yaml:"enableHostVerification"`
|
||||
}
|
||||
|
||||
// Initialize implements sources.SourceConfig.
|
||||
func (c Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
session, err := initCassandraSession(ctx, tracer, c)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create session: %v", err)
|
||||
}
|
||||
s := &Source{
|
||||
Name: c.Name,
|
||||
Kind: SourceKind,
|
||||
Session: session,
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// SourceConfigKind implements sources.SourceConfig.
|
||||
func (c Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
type Source struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Session *gocql.Session
|
||||
}
|
||||
|
||||
// CassandraSession implements cassandra.compatibleSource.
|
||||
func (s *Source) CassandraSession() *gocql.Session {
|
||||
return s.Session
|
||||
}
|
||||
|
||||
// SourceKind implements sources.Source.
|
||||
func (s Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
var _ sources.Source = &Source{}
|
||||
|
||||
func initCassandraSession(ctx context.Context, tracer trace.Tracer, c Config) (*gocql.Session, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, c.Name)
|
||||
defer span.End()
|
||||
|
||||
// Validate authentication configuration
|
||||
if c.Password != "" && c.Username == "" {
|
||||
return nil, fmt.Errorf("invalid Cassandra configuration: password provided without a username")
|
||||
}
|
||||
|
||||
cluster := gocql.NewCluster(c.Hosts...)
|
||||
cluster.ProtoVersion = c.ProtoVersion
|
||||
cluster.Keyspace = c.Keyspace
|
||||
|
||||
// Configure authentication if username is provided
|
||||
if c.Username != "" {
|
||||
cluster.Authenticator = gocql.PasswordAuthenticator{
|
||||
Username: c.Username,
|
||||
Password: c.Password,
|
||||
}
|
||||
}
|
||||
|
||||
// Configure SSL options if any are specified
|
||||
if c.CAPath != "" || c.CertPath != "" || c.KeyPath != "" || c.EnableHostVerification {
|
||||
cluster.SslOpts = &gocql.SslOptions{
|
||||
CaPath: c.CAPath,
|
||||
CertPath: c.CertPath,
|
||||
KeyPath: c.KeyPath,
|
||||
EnableHostVerification: c.EnableHostVerification,
|
||||
}
|
||||
}
|
||||
|
||||
// Create session
|
||||
session, err := cluster.CreateSession()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create Cassandra session: %w", err)
|
||||
}
|
||||
return session, nil
|
||||
}
|
||||
158
internal/sources/cassandra/cassandra_test.go
Normal file
158
internal/sources/cassandra/cassandra_test.go
Normal file
@@ -0,0 +1,158 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cassandra_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cassandra"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
|
||||
func TestParseFromYamlCassandra(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.SourceConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example (without optional fields)",
|
||||
in: `
|
||||
sources:
|
||||
my-cassandra-instance:
|
||||
kind: cassandra
|
||||
hosts:
|
||||
- "my-host1"
|
||||
- "my-host2"
|
||||
`,
|
||||
want: server.SourceConfigs{
|
||||
"my-cassandra-instance": cassandra.Config{
|
||||
Name: "my-cassandra-instance",
|
||||
Kind: cassandra.SourceKind,
|
||||
Hosts: []string{"my-host1", "my-host2"},
|
||||
Username: "",
|
||||
Password: "",
|
||||
ProtoVersion: 0,
|
||||
CAPath: "",
|
||||
CertPath: "",
|
||||
KeyPath: "",
|
||||
Keyspace: "",
|
||||
EnableHostVerification: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "with optional fields",
|
||||
in: `
|
||||
sources:
|
||||
my-cassandra-instance:
|
||||
kind: cassandra
|
||||
hosts:
|
||||
- "my-host1"
|
||||
- "my-host2"
|
||||
username: "user"
|
||||
password: "pass"
|
||||
keyspace: "example_keyspace"
|
||||
protoVersion: 4
|
||||
caPath: "path/to/ca.crt"
|
||||
certPath: "path/to/cert"
|
||||
keyPath: "path/to/key"
|
||||
enableHostVerification: true
|
||||
`,
|
||||
want: server.SourceConfigs{
|
||||
"my-cassandra-instance": cassandra.Config{
|
||||
Name: "my-cassandra-instance",
|
||||
Kind: cassandra.SourceKind,
|
||||
Hosts: []string{"my-host1", "my-host2"},
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
Keyspace: "example_keyspace",
|
||||
ProtoVersion: 4,
|
||||
CAPath: "path/to/ca.crt",
|
||||
CertPath: "path/to/cert",
|
||||
KeyPath: "path/to/key",
|
||||
EnableHostVerification: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestFailParseFromYaml(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
err string
|
||||
}{
|
||||
{
|
||||
desc: "extra field",
|
||||
in: `
|
||||
sources:
|
||||
my-cassandra-instance:
|
||||
kind: cassandra
|
||||
host:
|
||||
- "my-host"
|
||||
foo: bar
|
||||
`,
|
||||
err: "unable to parse source \"my-cassandra-instance\" as \"cassandra\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | host:\n 3 | - my-host\n 4 | kind: cassandra",
|
||||
},
|
||||
{
|
||||
desc: "missing required field",
|
||||
in: `
|
||||
sources:
|
||||
my-cassandra-instance:
|
||||
kind: cassandra
|
||||
`,
|
||||
err: "unable to parse source \"my-cassandra-instance\" as \"cassandra\": Key: 'Config.Hosts' Error:Field validation for 'Hosts' failed on the 'required' tag",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
errStr := err.Error()
|
||||
if errStr != tc.err {
|
||||
t.Fatalf("unexpected error: got %q, want %q", errStr, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
182
internal/tools/cassandra/cassandracql/cassandracql.go
Normal file
182
internal/tools/cassandra/cassandracql/cassandracql.go
Normal file
@@ -0,0 +1,182 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cassandracql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/gocql/gocql"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cassandra"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
)
|
||||
|
||||
const kind string = "cassandra-cql"
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
CassandraSession() *gocql.Session
|
||||
}
|
||||
|
||||
var _ compatibleSource = &cassandra.Source{}
|
||||
|
||||
var compatibleSources = [...]string{cassandra.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
Statement string `yaml:"statement" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
TemplateParameters tools.Parameters `yaml:"templateParameters"`
|
||||
}
|
||||
|
||||
// Initialize implements tools.ToolConfig.
|
||||
func (c Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[c.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", c.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
allParameters, paramManifest, paramMcpManifest, err := tools.ProcessParameters(c.TemplateParameters, c.Parameters)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: c.Name,
|
||||
Description: c.Description,
|
||||
InputSchema: paramMcpManifest,
|
||||
}
|
||||
|
||||
t := Tool{
|
||||
Name: c.Name,
|
||||
Kind: kind,
|
||||
Parameters: c.Parameters,
|
||||
TemplateParameters: c.TemplateParameters,
|
||||
AllParams: allParameters,
|
||||
Statement: c.Statement,
|
||||
AuthRequired: c.AuthRequired,
|
||||
Session: s.CassandraSession(),
|
||||
manifest: tools.Manifest{Description: c.Description, Parameters: paramManifest, AuthRequired: c.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// ToolConfigKind implements tools.ToolConfig.
|
||||
func (c Config) ToolConfigKind() string {
|
||||
return kind
|
||||
}
|
||||
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
TemplateParameters tools.Parameters `yaml:"templateParameters"`
|
||||
AllParams tools.Parameters `yaml:"allParams"`
|
||||
|
||||
Session *gocql.Session
|
||||
Statement string
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
// RequiresClientAuthorization implements tools.Tool.
|
||||
func (t Tool) RequiresClientAuthorization() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Authorized implements tools.Tool.
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
// Invoke implements tools.Tool.
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
paramsMap := params.AsMap()
|
||||
newStatement, err := tools.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to extract template params %w", err)
|
||||
}
|
||||
|
||||
newParams, err := tools.GetParams(t.Parameters, paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to extract standard params %w", err)
|
||||
}
|
||||
sliceParams := newParams.AsSlice()
|
||||
iter := t.Session.Query(newStatement, sliceParams...).WithContext(ctx).Iter()
|
||||
|
||||
// Create a slice to store the out
|
||||
var out []map[string]interface{}
|
||||
|
||||
// Scan results into a map and append to the slice
|
||||
for {
|
||||
row := make(map[string]interface{}) // Create a new map for each row
|
||||
if !iter.MapScan(row) {
|
||||
break // No more rows
|
||||
}
|
||||
out = append(out, row)
|
||||
}
|
||||
|
||||
if err := iter.Close(); err != nil {
|
||||
return nil, fmt.Errorf("unable to parse rows: %w", err)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// Manifest implements tools.Tool.
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
// McpManifest implements tools.Tool.
|
||||
func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
// ParseParams implements tools.Tool.
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
return tools.ParseParams(t.AllParams, data, claims)
|
||||
}
|
||||
|
||||
var _ tools.Tool = Tool{}
|
||||
171
internal/tools/cassandra/cassandracql/cassandracql_test.go
Normal file
171
internal/tools/cassandra/cassandracql/cassandracql_test.go
Normal file
@@ -0,0 +1,171 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cassandracql_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/cassandra/cassandracql"
|
||||
)
|
||||
|
||||
func TestParseFromYamlCassandra(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: cassandra-cql
|
||||
source: my-cassandra-instance
|
||||
description: some description
|
||||
statement: |
|
||||
SELECT * FROM CQL_STATEMENT;
|
||||
authRequired:
|
||||
- my-google-auth-service
|
||||
- other-auth-service
|
||||
parameters:
|
||||
- name: country
|
||||
type: string
|
||||
description: some description
|
||||
authServices:
|
||||
- name: my-google-auth-service
|
||||
field: user_id
|
||||
- name: other-auth-service
|
||||
field: user_id
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": cassandracql.Config{
|
||||
Name: "example_tool",
|
||||
Kind: "cassandra-cql",
|
||||
Source: "my-cassandra-instance",
|
||||
Description: "some description",
|
||||
Statement: "SELECT * FROM CQL_STATEMENT;\n",
|
||||
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
|
||||
Parameters: []tools.Parameter{
|
||||
tools.NewStringParameterWithAuth("country", "some description",
|
||||
[]tools.ParamAuthService{{Name: "my-google-auth-service", Field: "user_id"},
|
||||
{Name: "other-auth-service", Field: "user_id"}}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "with template parameters",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: cassandra-cql
|
||||
source: my-cassandra-instance
|
||||
description: some description
|
||||
statement: |
|
||||
SELECT * FROM CQL_STATEMENT;
|
||||
authRequired:
|
||||
- my-google-auth-service
|
||||
- other-auth-service
|
||||
parameters:
|
||||
- name: country
|
||||
type: string
|
||||
description: some description
|
||||
authServices:
|
||||
- name: my-google-auth-service
|
||||
field: user_id
|
||||
- name: other-auth-service
|
||||
field: user_id
|
||||
templateParameters:
|
||||
- name: tableName
|
||||
type: string
|
||||
description: some description.
|
||||
- name: fieldArray
|
||||
type: array
|
||||
description: The columns to return for the query.
|
||||
items:
|
||||
name: column
|
||||
type: string
|
||||
description: A column name that will be returned from the query.
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": cassandracql.Config{
|
||||
Name: "example_tool",
|
||||
Kind: "cassandra-cql",
|
||||
Source: "my-cassandra-instance",
|
||||
Description: "some description",
|
||||
Statement: "SELECT * FROM CQL_STATEMENT;\n",
|
||||
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
|
||||
Parameters: []tools.Parameter{
|
||||
tools.NewStringParameterWithAuth("country", "some description",
|
||||
[]tools.ParamAuthService{{Name: "my-google-auth-service", Field: "user_id"},
|
||||
{Name: "other-auth-service", Field: "user_id"}}),
|
||||
},
|
||||
TemplateParameters: []tools.Parameter{
|
||||
tools.NewStringParameter("tableName", "some description."),
|
||||
tools.NewArrayParameter("fieldArray", "The columns to return for the query.", tools.NewStringParameter("column", "A column name that will be returned from the query.")),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "without optional fields",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: cassandra-cql
|
||||
source: my-cassandra-instance
|
||||
description: some description
|
||||
statement: |
|
||||
SELECT * FROM CQL_STATEMENT;
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": cassandracql.Config{
|
||||
Name: "example_tool",
|
||||
Kind: "cassandra-cql",
|
||||
Source: "my-cassandra-instance",
|
||||
Description: "some description",
|
||||
Statement: "SELECT * FROM CQL_STATEMENT;\n",
|
||||
AuthRequired: []string{},
|
||||
Parameters: nil,
|
||||
TemplateParameters: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
284
tests/cassandra/cassandra_integration_test.go
Normal file
284
tests/cassandra/cassandra_integration_test.go
Normal file
@@ -0,0 +1,284 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cassandra
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gocql/gocql"
|
||||
"github.com/google/uuid"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/tests"
|
||||
)
|
||||
|
||||
var (
|
||||
CassandraSourceKind = "cassandra"
|
||||
CassandraToolKind = "cassandra-cql"
|
||||
Hosts = os.Getenv("CASSANDRA_HOST")
|
||||
Keyspace = "example_keyspace"
|
||||
Username = os.Getenv("CASSANDRA_USER")
|
||||
Password = os.Getenv("CASSANDRA_PASS")
|
||||
)
|
||||
|
||||
func getCassandraVars(t *testing.T) map[string]any {
|
||||
switch "" {
|
||||
case Hosts:
|
||||
t.Fatal("'Hosts' not set")
|
||||
case Username:
|
||||
t.Fatal("'Username' not set")
|
||||
case Password:
|
||||
t.Fatal("'Password' not set")
|
||||
}
|
||||
return map[string]any{
|
||||
"kind": CassandraSourceKind,
|
||||
"hosts": strings.Split(Hosts, ","),
|
||||
"keyspace": Keyspace,
|
||||
"username": Username,
|
||||
"password": Password,
|
||||
}
|
||||
}
|
||||
|
||||
func initCassandraSession() (*gocql.Session, error) {
|
||||
hostStrings := strings.Split(Hosts, ",")
|
||||
|
||||
var hosts []string
|
||||
for _, h := range hostStrings {
|
||||
trimmedHost := strings.TrimSpace(h)
|
||||
if trimmedHost != "" {
|
||||
hosts = append(hosts, trimmedHost)
|
||||
}
|
||||
}
|
||||
if len(hosts) == 0 {
|
||||
return nil, fmt.Errorf("no valid hosts found in CASSANDRA_HOSTS env var")
|
||||
}
|
||||
// Configure cluster connection
|
||||
cluster := gocql.NewCluster(hosts...)
|
||||
cluster.Consistency = gocql.Quorum
|
||||
cluster.ProtoVersion = 4
|
||||
cluster.DisableInitialHostLookup = true
|
||||
cluster.ConnectTimeout = 10 * time.Second
|
||||
cluster.NumConns = 2
|
||||
cluster.Authenticator = gocql.PasswordAuthenticator{
|
||||
Username: Username,
|
||||
Password: Password,
|
||||
}
|
||||
cluster.RetryPolicy = &gocql.ExponentialBackoffRetryPolicy{
|
||||
NumRetries: 3,
|
||||
Min: 200 * time.Millisecond,
|
||||
Max: 2 * time.Second,
|
||||
}
|
||||
|
||||
// Create session
|
||||
session, err := cluster.CreateSession()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to create session: %v", err)
|
||||
}
|
||||
|
||||
// Create keyspace
|
||||
err = session.Query(fmt.Sprintf(`
|
||||
CREATE KEYSPACE IF NOT EXISTS %s
|
||||
WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}
|
||||
`, Keyspace)).Exec()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to create keyspace: %v", err)
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func initTable(tableName string, session *gocql.Session) error {
|
||||
|
||||
// Create table with additional columns
|
||||
err := session.Query(fmt.Sprintf(`
|
||||
CREATE TABLE IF NOT EXISTS %s.%s (
|
||||
id int PRIMARY KEY,
|
||||
name text,
|
||||
email text,
|
||||
age int,
|
||||
is_active boolean,
|
||||
created_at timestamp
|
||||
)
|
||||
`, Keyspace, tableName)).Exec()
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to create table: %v", err)
|
||||
}
|
||||
|
||||
// Use fixed timestamps for reproducibility
|
||||
fixedTime, _ := time.Parse(time.RFC3339, "2025-07-25T12:00:00Z")
|
||||
dayAgo := fixedTime.Add(-24 * time.Hour)
|
||||
twelveHoursAgo := fixedTime.Add(-12 * time.Hour)
|
||||
|
||||
// Insert minimal diverse data with fixed time.Time for timestamps
|
||||
err = session.Query(fmt.Sprintf(`
|
||||
INSERT INTO %s.%s (id, name,email, age, is_active, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)`, Keyspace, tableName),
|
||||
3, "Alice", tests.ServiceAccountEmail, 25, true, dayAgo,
|
||||
).Exec()
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to insert user: %v", err)
|
||||
}
|
||||
err = session.Query(fmt.Sprintf(`
|
||||
INSERT INTO %s.%s (id, name,email, age, is_active, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)`, Keyspace, tableName),
|
||||
2, "Alex", "janedoe@gmail.com", 30, false, twelveHoursAgo,
|
||||
).Exec()
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to insert user: %v", err)
|
||||
}
|
||||
err = session.Query(fmt.Sprintf(`
|
||||
INSERT INTO %s.%s (id, name,email, age, is_active, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)`, Keyspace, tableName),
|
||||
1, "Sid", "sid@gmail.com", 10, true, fixedTime,
|
||||
).Exec()
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to insert user: %v", err)
|
||||
}
|
||||
err = session.Query(fmt.Sprintf(`
|
||||
INSERT INTO %s.%s (id, name,email, age, is_active, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)`, Keyspace, tableName),
|
||||
4, nil, "a@gmail.com", 40, false, fixedTime,
|
||||
).Exec()
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to insert user: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func dropTable(session *gocql.Session, tableName string) {
|
||||
err := session.Query(fmt.Sprintf("drop table %s.%s", Keyspace, tableName)).Exec()
|
||||
if err != nil {
|
||||
log.Printf("Failed to drop table %s: %v", tableName, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCassandra(t *testing.T) {
|
||||
session, err := initCassandraSession()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer session.Close()
|
||||
sourceConfig := getCassandraVars(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
|
||||
var args []string
|
||||
paramTableName := "param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
tableNameAuth := "auth_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
err = initTable(paramTableName, session)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer dropTable(session, paramTableName)
|
||||
|
||||
err = initTable(tableNameAuth, session)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer dropTable(session, tableNameAuth)
|
||||
|
||||
err = initTable(tableNameTemplateParam, session)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer dropTable(session, tableNameTemplateParam)
|
||||
|
||||
paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt := createParamToolInfo(paramTableName)
|
||||
_, _, authToolStmt := getCassandraAuthToolInfo(tableNameAuth)
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, CassandraToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
|
||||
|
||||
tmplSelectCombined, tmplSelectFilterCombined := getCassandraTmplToolInfo()
|
||||
tmpSelectAll := "SELECT * FROM {{.tableName}} where id = 1"
|
||||
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, CassandraToolKind, tmplSelectCombined, tmplSelectFilterCombined, tmpSelectAll)
|
||||
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
t.Fatalf("command initialization returned an error: %s", err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||
if err != nil {
|
||||
t.Logf("toolbox command logs: \n%s", out)
|
||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||
}
|
||||
selectIdNameWant, selectIdNullWant, selectArrayParamWant, mcpMyFailToolWant, mcpSelect1Want, mcpMyToolIdWant := getCassandraWants()
|
||||
selectAllWant, selectIdWant, selectNameWant := getCassandraTmplWants()
|
||||
|
||||
tests.RunToolGetTest(t)
|
||||
tests.RunToolInvokeTest(t, "", tests.DisableSelect1Test(),
|
||||
tests.DisableOptionalNullParamTest(),
|
||||
tests.WithMyToolId3NameAliceWant(selectIdNameWant),
|
||||
tests.WithMyToolById4Want(selectIdNullWant),
|
||||
tests.WithMyArrayToolWant(selectArrayParamWant),
|
||||
tests.DisableSelect1AuthTest())
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam,
|
||||
tests.DisableSelectFilterTest(),
|
||||
tests.WithSelectAllWant(selectAllWant),
|
||||
tests.DisableDdlTest(), tests.DisableInsertTest(), tests.WithTmplSelectId1Want(selectIdWant), tests.WithTmplSelectNameWant(selectNameWant))
|
||||
|
||||
tests.RunMCPToolCallMethod(t, mcpMyFailToolWant, mcpSelect1Want,
|
||||
tests.WithMcpMyToolId3NameAliceWant(mcpMyToolIdWant),
|
||||
tests.DisableMcpSelect1AuthTest())
|
||||
|
||||
}
|
||||
|
||||
func createParamToolInfo(tableName string) (string, string, string, string) {
|
||||
toolStatement := fmt.Sprintf("SELECT id, name FROM %s WHERE id = ? AND name = ? ALLOW FILTERING;", tableName)
|
||||
idParamStatement := fmt.Sprintf("SELECT id,name FROM %s WHERE id = ?;", tableName)
|
||||
nameParamStatement := fmt.Sprintf("SELECT id, name FROM %s WHERE name = ? ALLOW FILTERING;", tableName)
|
||||
arrayToolStatement := fmt.Sprintf("SELECT id, name FROM %s WHERE id IN ? AND name IN ? ALLOW FILTERING;", tableName)
|
||||
return toolStatement, idParamStatement, nameParamStatement, arrayToolStatement
|
||||
|
||||
}
|
||||
|
||||
func getCassandraAuthToolInfo(tableName string) (string, string, string) {
|
||||
createStatement := fmt.Sprintf("CREATE TABLE %s (id UUID PRIMARY KEY, name TEXT, email TEXT);", tableName)
|
||||
insertStatement := fmt.Sprintf("INSERT INTO %s (id, name, email) VALUES (uuid(), ?, ?), (uuid(), ?, ?);", tableName)
|
||||
toolStatement := fmt.Sprintf("SELECT name FROM %s WHERE email = ? ALLOW FILTERING;", tableName)
|
||||
return createStatement, insertStatement, toolStatement
|
||||
}
|
||||
|
||||
func getCassandraTmplToolInfo() (string, string) {
|
||||
selectAllTemplateStmt := "SELECT age, id, name FROM {{.tableName}} where id = ?;"
|
||||
selectByIdTemplateStmt := "SELECT id, name FROM {{.tableName}} WHERE name = ? ALLOW FILTERING;"
|
||||
return selectAllTemplateStmt, selectByIdTemplateStmt
|
||||
}
|
||||
|
||||
func getCassandraWants() (string, string, string, string, string, string) {
|
||||
selectIdNameWant := "[{\"id\":3,\"name\":\"Alice\"}]"
|
||||
selectIdNullWant := "[{\"id\":4,\"name\":\"\"}]"
|
||||
selectArrayParamWant := "[{\"id\":1,\"name\":\"Sid\"},{\"id\":3,\"name\":\"Alice\"}]"
|
||||
mcpMyFailToolWant := "{\"jsonrpc\":\"2.0\",\"id\":\"invoke-fail-tool\",\"result\":{\"content\":[{\"type\":\"text\",\"text\":\"unable to parse rows: line 1:0 no viable alternative at input 'SELEC' ([SELEC]...)\"}],\"isError\":true}}"
|
||||
mcpMyToolIdWant := "{\"jsonrpc\":\"2.0\",\"id\":\"my-tool\",\"result\":{\"content\":[{\"type\":\"text\",\"text\":\"[{\\\"id\\\":3,\\\"name\\\":\\\"Alice\\\"}]\"}]}}"
|
||||
return selectIdNameWant, selectIdNullWant, selectArrayParamWant, mcpMyFailToolWant, "nil", mcpMyToolIdWant
|
||||
}
|
||||
|
||||
func getCassandraTmplWants() (string, string, string) {
|
||||
selectAllWant := "[{\"age\":10,\"created_at\":\"2025-07-25T12:00:00Z\",\"email\":\"sid@gmail.com\",\"id\":1,\"is_active\":true,\"name\":\"Sid\"}]"
|
||||
selectIdWant := "[{\"age\":10,\"id\":1,\"name\":\"Sid\"}]"
|
||||
selectNameWant := "[{\"id\":2,\"name\":\"Alex\"}]"
|
||||
return selectAllWant, selectIdWant, selectNameWant
|
||||
}
|
||||
@@ -110,6 +110,7 @@ func TestMongoDBToolEndpoints(t *testing.T) {
|
||||
tests.RunToolGetTest(t)
|
||||
tests.RunToolInvokeTest(t, select1Want,
|
||||
tests.WithMyToolId3NameAliceWant(myToolId3NameAliceWant),
|
||||
tests.WithMyArrayToolWant(myToolId3NameAliceWant),
|
||||
tests.WithMyToolById4Want(myToolById4Want),
|
||||
)
|
||||
tests.RunMCPToolCallMethod(t, mcpMyFailToolWant, select1Want,
|
||||
|
||||
@@ -21,9 +21,12 @@ type InvokeTestConfig struct {
|
||||
myToolId3NameAliceWant string
|
||||
myToolById4Want string
|
||||
nullWant string
|
||||
myArrayToolWant string
|
||||
supportSelect1Want bool
|
||||
supportOptionalNullParam bool
|
||||
supportArrayParam bool
|
||||
supportClientAuth bool
|
||||
supportSelect1Auth bool
|
||||
}
|
||||
|
||||
type InvokeTestOption func(*InvokeTestConfig)
|
||||
@@ -36,6 +39,14 @@ func WithMyToolId3NameAliceWant(s string) InvokeTestOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithMyArrayToolWant represents the response value for my-array-tool.
|
||||
// e.g. tests.RunToolInvokeTest(t, select1Want, tests.WithMyArrayToolWant("custom"))
|
||||
func WithMyArrayToolWant(s string) InvokeTestOption {
|
||||
return func(c *InvokeTestConfig) {
|
||||
c.myArrayToolWant = s
|
||||
}
|
||||
}
|
||||
|
||||
// WithMyToolById4Want represents the response value for my-tool-by-id with id=4.
|
||||
// This response includes a null value column.
|
||||
// e.g. tests.RunToolInvokeTest(t, select1Want, tests.WithMyToolById4Want("custom"))
|
||||
@@ -69,6 +80,22 @@ func DisableArrayTest() InvokeTestOption {
|
||||
}
|
||||
}
|
||||
|
||||
// DisableSelect1Test disables tests for sources that do not support SELECT 1 query.
|
||||
// e.g. tests.RunToolInvokeTest(t, "", tests.DisableSelect1Test())
|
||||
func DisableSelect1Test() InvokeTestOption {
|
||||
return func(c *InvokeTestConfig) {
|
||||
c.supportSelect1Want = false
|
||||
}
|
||||
}
|
||||
|
||||
// DisableSelect1AuthTest disables auth tests for sources that do not support SELECT 1 query.
|
||||
// e.g. tests.RunToolInvokeTest(t, "", tests.DisableSelect1AuthTest())
|
||||
func DisableSelect1AuthTest() InvokeTestOption {
|
||||
return func(c *InvokeTestConfig) {
|
||||
c.supportSelect1Auth = false
|
||||
}
|
||||
}
|
||||
|
||||
// EnableClientAuthTest runs the client authorization tests.
|
||||
// Only enable it if your source supports the `useClientOAuth` configuration.
|
||||
// Currently, this should only be used with the BigQuery tests.
|
||||
@@ -84,6 +111,7 @@ func EnableClientAuthTest() InvokeTestOption {
|
||||
type MCPTestConfig struct {
|
||||
myToolId3NameAliceWant string
|
||||
supportClientAuth bool
|
||||
supportSelect1Auth bool
|
||||
}
|
||||
|
||||
type McpTestOption func(*MCPTestConfig)
|
||||
@@ -105,6 +133,13 @@ func EnableMcpClientAuthTest() McpTestOption {
|
||||
}
|
||||
}
|
||||
|
||||
// DisableMcpSelect1AuthTest disables the auth tool tests which use select 1.
|
||||
func DisableMcpSelect1AuthTest() McpTestOption {
|
||||
return func(c *MCPTestConfig) {
|
||||
c.supportSelect1Auth = false
|
||||
}
|
||||
}
|
||||
|
||||
/* Configurations for RunExecuteSqlToolInvokeTest() */
|
||||
|
||||
// ExecuteSqlTestConfig represents the various configuration options for RunExecuteSqlToolInvokeTest()
|
||||
@@ -129,6 +164,7 @@ type TemplateParameterTestConfig struct {
|
||||
ddlWant string
|
||||
selectAllWant string
|
||||
selectId1Want string
|
||||
selectNameWant string
|
||||
selectEmptyWant string
|
||||
insert1Want string
|
||||
|
||||
@@ -136,8 +172,9 @@ type TemplateParameterTestConfig struct {
|
||||
nameColFilter string
|
||||
createColArray string
|
||||
|
||||
supportDdl bool
|
||||
supportInsert bool
|
||||
supportDdl bool
|
||||
supportInsert bool
|
||||
supportSelectFields bool
|
||||
}
|
||||
|
||||
type TemplateParamOption func(*TemplateParameterTestConfig)
|
||||
@@ -166,6 +203,14 @@ func WithTmplSelectId1Want(s string) TemplateParamOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithTmplSelectNameWant represents the response value of select-filter-templateParams-combined-tool with name.
|
||||
// e.g. tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tests.WithTmplSelectNameWant("custom"))
|
||||
func WithTmplSelectNameWant(s string) TemplateParamOption {
|
||||
return func(c *TemplateParameterTestConfig) {
|
||||
c.selectNameWant = s
|
||||
}
|
||||
}
|
||||
|
||||
// WithSelectEmptyWant represents the response value of select-templateParams-combined-tool with no results.
|
||||
// e.g. tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tests.WithSelectEmptyWant("custom"))
|
||||
func WithSelectEmptyWant(s string) TemplateParamOption {
|
||||
@@ -221,3 +266,11 @@ func DisableInsertTest() TemplateParamOption {
|
||||
c.supportInsert = false
|
||||
}
|
||||
}
|
||||
|
||||
// DisableInsertTest disables tests of select-fields-templateParams-tool test.
|
||||
// e.g. tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tests.DisableSelectFilterTest())
|
||||
func DisableSelectFilterTest() TemplateParamOption {
|
||||
return func(c *TemplateParameterTestConfig) {
|
||||
c.supportSelectFields = false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -104,6 +104,7 @@ func TestRedisToolEndpoints(t *testing.T) {
|
||||
tests.RunToolGetTest(t)
|
||||
tests.RunToolInvokeTest(t, select1Want,
|
||||
tests.WithMyToolId3NameAliceWant(invokeParamWant),
|
||||
tests.WithMyArrayToolWant(invokeParamWant),
|
||||
tests.WithMyToolById4Want(invokeIdNullWant),
|
||||
tests.WithNullWant(nullWant),
|
||||
)
|
||||
|
||||
@@ -164,6 +164,7 @@ func TestSpannerToolEndpoints(t *testing.T) {
|
||||
tests.RunToolGetTest(t)
|
||||
tests.RunToolInvokeTest(t, select1Want,
|
||||
tests.WithMyToolId3NameAliceWant(invokeParamWant),
|
||||
tests.WithMyArrayToolWant(invokeParamWant),
|
||||
tests.WithMyToolById4Want(toolInvokeMyToolById4Want),
|
||||
)
|
||||
tests.RunMCPToolCallMethod(t, mcpMyFailToolWant, mcpSelect1Want, tests.WithMcpMyToolId3NameAliceWant(mcpMyToolId3NameAliceWant))
|
||||
|
||||
@@ -257,10 +257,13 @@ func RunToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTestOp
|
||||
configs := &InvokeTestConfig{
|
||||
myToolId3NameAliceWant: "[{\"id\":1,\"name\":\"Alice\"},{\"id\":3,\"name\":\"Sid\"}]",
|
||||
myToolById4Want: "[{\"id\":4,\"name\":null}]",
|
||||
myArrayToolWant: "[{\"id\":1,\"name\":\"Alice\"},{\"id\":3,\"name\":\"Sid\"}]",
|
||||
nullWant: "null",
|
||||
supportOptionalNullParam: true,
|
||||
supportArrayParam: true,
|
||||
supportClientAuth: false,
|
||||
supportSelect1Want: true,
|
||||
supportSelect1Auth: true,
|
||||
}
|
||||
|
||||
// Apply provided options
|
||||
@@ -294,7 +297,7 @@ func RunToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTestOp
|
||||
{
|
||||
name: "invoke my-simple-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/invoke",
|
||||
enabled: true,
|
||||
enabled: configs.supportSelect1Want,
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
wantBody: select1Want,
|
||||
@@ -351,13 +354,13 @@ func RunToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTestOp
|
||||
enabled: configs.supportArrayParam,
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"idArray": [1,2,3], "nameArray": ["Alice", "Sid", "RandomName"], "cmdArray": ["HGETALL", "row3"]}`)),
|
||||
wantBody: configs.myToolId3NameAliceWant,
|
||||
wantBody: configs.myArrayToolWant,
|
||||
wantStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "Invoke my-auth-tool with auth token",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke",
|
||||
enabled: true,
|
||||
enabled: configs.supportSelect1Auth,
|
||||
requestHeader: map[string]string{"my-google-auth_token": idToken},
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
wantBody: "[{\"name\":\"Alice\"}]",
|
||||
@@ -366,7 +369,7 @@ func RunToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTestOp
|
||||
{
|
||||
name: "Invoke my-auth-tool with invalid auth token",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke",
|
||||
enabled: true,
|
||||
enabled: configs.supportSelect1Auth,
|
||||
requestHeader: map[string]string{"my-google-auth_token": "INVALID_TOKEN"},
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
wantStatusCode: http.StatusUnauthorized,
|
||||
@@ -382,7 +385,7 @@ func RunToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTestOp
|
||||
{
|
||||
name: "Invoke my-auth-required-tool with auth token",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-required-tool/invoke",
|
||||
enabled: true,
|
||||
enabled: configs.supportSelect1Auth,
|
||||
requestHeader: map[string]string{"my-google-auth_token": idToken},
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
|
||||
@@ -491,6 +494,7 @@ func RunToolInvokeWithTemplateParameters(t *testing.T, tableName string, options
|
||||
ddlWant: "null",
|
||||
selectAllWant: "[{\"age\":21,\"id\":1,\"name\":\"Alex\"},{\"age\":100,\"id\":2,\"name\":\"Alice\"}]",
|
||||
selectId1Want: "[{\"age\":21,\"id\":1,\"name\":\"Alex\"}]",
|
||||
selectNameWant: "[{\"age\":21,\"id\":1,\"name\":\"Alex\"}]",
|
||||
selectEmptyWant: "null",
|
||||
insert1Want: "null",
|
||||
|
||||
@@ -512,6 +516,7 @@ func RunToolInvokeWithTemplateParameters(t *testing.T, tableName string, options
|
||||
// Test tool invoke endpoint
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
enabled bool
|
||||
ddl bool
|
||||
insert bool
|
||||
api string
|
||||
@@ -573,6 +578,7 @@ func RunToolInvokeWithTemplateParameters(t *testing.T, tableName string, options
|
||||
},
|
||||
{
|
||||
name: "invoke select-fields-templateParams-tool",
|
||||
enabled: configs.supportSelectFields,
|
||||
api: "http://127.0.0.1:5000/api/tool/select-fields-templateParams-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"tableName": "%s", "fields":%s}`, tableName, configs.nameFieldArray))),
|
||||
@@ -584,7 +590,7 @@ func RunToolInvokeWithTemplateParameters(t *testing.T, tableName string, options
|
||||
api: "http://127.0.0.1:5000/api/tool/select-filter-templateParams-combined-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"name": "Alex", "tableName": "%s", "columnFilter": "%s"}`, tableName, configs.nameColFilter))),
|
||||
want: configs.selectId1Want,
|
||||
want: configs.selectNameWant,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
@@ -599,6 +605,9 @@ func RunToolInvokeWithTemplateParameters(t *testing.T, tableName string, options
|
||||
}
|
||||
for _, tc := range invokeTcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if !tc.enabled {
|
||||
return
|
||||
}
|
||||
// if test case is DDL and source support ddl test cases
|
||||
ddlAllow := !tc.ddl || (tc.ddl && configs.supportDdl)
|
||||
// if test case is insert statement and source support insert test cases
|
||||
@@ -834,6 +843,7 @@ func RunMCPToolCallMethod(t *testing.T, myFailToolWant, select1Want string, opti
|
||||
configs := &MCPTestConfig{
|
||||
myToolId3NameAliceWant: `{"jsonrpc":"2.0","id":"my-tool","result":{"content":[{"type":"text","text":"{\"id\":1,\"name\":\"Alice\"}"},{"type":"text","text":"{\"id\":3,\"name\":\"Sid\"}"}]}}`,
|
||||
supportClientAuth: false,
|
||||
supportSelect1Auth: true,
|
||||
}
|
||||
|
||||
// Apply provided options
|
||||
@@ -947,7 +957,7 @@ func RunMCPToolCallMethod(t *testing.T, myFailToolWant, select1Want string, opti
|
||||
{
|
||||
name: "MCP Invoke my-auth-required-tool",
|
||||
api: "http://127.0.0.1:5000/mcp",
|
||||
enabled: true,
|
||||
enabled: configs.supportSelect1Auth,
|
||||
requestHeader: map[string]string{"my-google-auth_token": idToken},
|
||||
requestBody: jsonrpc.JSONRPCRequest{
|
||||
Jsonrpc: "2.0",
|
||||
|
||||
@@ -107,6 +107,7 @@ func TestValkeyToolEndpoints(t *testing.T) {
|
||||
tests.RunToolGetTest(t)
|
||||
tests.RunToolInvokeTest(t, select1Want,
|
||||
tests.WithMyToolId3NameAliceWant(invokeParamWant),
|
||||
tests.WithMyArrayToolWant(invokeParamWant),
|
||||
tests.WithMyToolById4Want(invokeIdNullWant),
|
||||
tests.WithNullWant(nullWant),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user