mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-13 01:18:19 -05:00
Compare commits
8 Commits
sts-python
...
enum-regex
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9af7970322 | ||
|
|
d96fceacd2 | ||
|
|
6e420534ee | ||
|
|
6387dd3efa | ||
|
|
a63d7aae2b | ||
|
|
4be5d165de | ||
|
|
a5ef166fcb | ||
|
|
8c268b20ae |
@@ -662,6 +662,26 @@ steps:
|
||||
- |
|
||||
./yugabytedb.test -test.v
|
||||
|
||||
|
||||
- id: "cassandra"
|
||||
name: golang:1
|
||||
waitFor: ["compile-test-binary"]
|
||||
entrypoint: /bin/bash
|
||||
env:
|
||||
- "GOPATH=/gopath"
|
||||
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
|
||||
secretEnv: ["CLIENT_ID", "CASSANDRA_USER", "CASSANDRA_PASS", "CASSANDRA_HOST"]
|
||||
volumes:
|
||||
- name: "go"
|
||||
path: "/gopath"
|
||||
args:
|
||||
- -c
|
||||
- |
|
||||
.ci/test_with_coverage.sh \
|
||||
"Cassandra" \
|
||||
cassandra \
|
||||
cassandra
|
||||
|
||||
availableSecrets:
|
||||
secretManager:
|
||||
- versionName: projects/$PROJECT_ID/secrets/cloud_sql_pg_user/versions/latest
|
||||
@@ -746,6 +766,12 @@ availableSecrets:
|
||||
env: YUGABYTEDB_USER
|
||||
- versionName: projects/$PROJECT_ID/secrets/yugabytedb_pass/versions/latest
|
||||
env: YUGABYTEDB_PASS
|
||||
- versionName: projects/$PROJECT_ID/secrets/cassandra_user/versions/latest
|
||||
env: CASSANDRA_USER
|
||||
- versionName: projects/$PROJECT_ID/secrets/cassandra_pass/versions/latest
|
||||
env: CASSANDRA_PASS
|
||||
- versionName: projects/$PROJECT_ID/secrets/cassandra_host/versions/latest
|
||||
env: CASSANDRA_HOST
|
||||
|
||||
options:
|
||||
logging: CLOUD_LOGGING_ONLY
|
||||
|
||||
47
.ci/quickstart_test/go.integration.cloudbuild.yaml
Normal file
47
.ci/quickstart_test/go.integration.cloudbuild.yaml
Normal file
@@ -0,0 +1,47 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
steps:
|
||||
- name: 'golang:1.25.1'
|
||||
id: 'go-quickstart-test'
|
||||
entrypoint: 'bash'
|
||||
args:
|
||||
# The '-c' flag tells bash to execute the following string as a command.
|
||||
# The 'set -ex' enables debug output and exits on error for easier troubleshooting.
|
||||
- -c
|
||||
- |
|
||||
set -ex
|
||||
export VERSION=$(cat ./cmd/version.txt)
|
||||
chmod +x .ci/quickstart_test/run_go_tests.sh
|
||||
.ci/quickstart_test/run_go_tests.sh
|
||||
env:
|
||||
- 'CLOUD_SQL_INSTANCE=${_CLOUD_SQL_INSTANCE}'
|
||||
- 'GCP_PROJECT=${_GCP_PROJECT}'
|
||||
- 'DATABASE_NAME=${_DATABASE_NAME}'
|
||||
- 'DB_USER=${_DB_USER}'
|
||||
secretEnv: ['TOOLS_YAML_CONTENT', 'GOOGLE_API_KEY', 'DB_PASSWORD']
|
||||
|
||||
availableSecrets:
|
||||
secretManager:
|
||||
- versionName: projects/${_GCP_PROJECT}/secrets/${_TOOLS_YAML_SECRET}/versions/7
|
||||
env: 'TOOLS_YAML_CONTENT'
|
||||
- versionName: projects/${_GCP_PROJECT_NUMBER}/secrets/${_API_KEY_SECRET}/versions/latest
|
||||
env: 'GOOGLE_API_KEY'
|
||||
- versionName: projects/${_GCP_PROJECT}/secrets/${_DB_PASS_SECRET}/versions/latest
|
||||
env: 'DB_PASSWORD'
|
||||
|
||||
timeout: 1000s
|
||||
|
||||
options:
|
||||
logging: CLOUD_LOGGING_ONLY
|
||||
47
.ci/quickstart_test/js.integration.cloudbuild.yaml
Normal file
47
.ci/quickstart_test/js.integration.cloudbuild.yaml
Normal file
@@ -0,0 +1,47 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
steps:
|
||||
- name: 'node:20'
|
||||
id: 'js-quickstart-test'
|
||||
entrypoint: 'bash'
|
||||
args:
|
||||
# The '-c' flag tells bash to execute the following string as a command.
|
||||
# The 'set -ex' enables debug output and exits on error for easier troubleshooting.
|
||||
- -c
|
||||
- |
|
||||
set -ex
|
||||
export VERSION=$(cat ./cmd/version.txt)
|
||||
chmod +x .ci/quickstart_test/run_js_tests.sh
|
||||
.ci/quickstart_test/run_js_tests.sh
|
||||
env:
|
||||
- 'CLOUD_SQL_INSTANCE=${_CLOUD_SQL_INSTANCE}'
|
||||
- 'GCP_PROJECT=${_GCP_PROJECT}'
|
||||
- 'DATABASE_NAME=${_DATABASE_NAME}'
|
||||
- 'DB_USER=${_DB_USER}'
|
||||
secretEnv: ['TOOLS_YAML_CONTENT', 'GOOGLE_API_KEY', 'DB_PASSWORD']
|
||||
|
||||
availableSecrets:
|
||||
secretManager:
|
||||
- versionName: projects/${_GCP_PROJECT}/secrets/${_TOOLS_YAML_SECRET}/versions/6
|
||||
env: 'TOOLS_YAML_CONTENT'
|
||||
- versionName: projects/${_GCP_PROJECT_NUMBER}/secrets/${_API_KEY_SECRET}/versions/latest
|
||||
env: 'GOOGLE_API_KEY'
|
||||
- versionName: projects/${_GCP_PROJECT}/secrets/${_DB_PASS_SECRET}/versions/latest
|
||||
env: 'DB_PASSWORD'
|
||||
|
||||
timeout: 1000s
|
||||
|
||||
options:
|
||||
logging: CLOUD_LOGGING_ONLY
|
||||
47
.ci/quickstart_test/py.integration.cloudbuild.yaml
Normal file
47
.ci/quickstart_test/py.integration.cloudbuild.yaml
Normal file
@@ -0,0 +1,47 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
steps:
|
||||
- name: 'gcr.io/google.com/cloudsdktool/cloud-sdk:537.0.0'
|
||||
id: 'python-quickstart-test'
|
||||
entrypoint: 'bash'
|
||||
args:
|
||||
# The '-c' flag tells bash to execute the following string as a command.
|
||||
# The 'set -ex' enables debug output and exits on error for easier troubleshooting.
|
||||
- -c
|
||||
- |
|
||||
set -ex
|
||||
export VERSION=$(cat ./cmd/version.txt)
|
||||
chmod +x .ci/quickstart_test/run_py_tests.sh
|
||||
.ci/quickstart_test/run_py_tests.sh
|
||||
env:
|
||||
- 'CLOUD_SQL_INSTANCE=${_CLOUD_SQL_INSTANCE}'
|
||||
- 'GCP_PROJECT=${_GCP_PROJECT}'
|
||||
- 'DATABASE_NAME=${_DATABASE_NAME}'
|
||||
- 'DB_USER=${_DB_USER}'
|
||||
secretEnv: ['TOOLS_YAML_CONTENT', 'GOOGLE_API_KEY', 'DB_PASSWORD']
|
||||
|
||||
availableSecrets:
|
||||
secretManager:
|
||||
- versionName: projects/${_GCP_PROJECT}/secrets/${_TOOLS_YAML_SECRET}/versions/5
|
||||
env: 'TOOLS_YAML_CONTENT'
|
||||
- versionName: projects/${_GCP_PROJECT_NUMBER}/secrets/${_API_KEY_SECRET}/versions/latest
|
||||
env: 'GOOGLE_API_KEY'
|
||||
- versionName: projects/${_GCP_PROJECT}/secrets/${_DB_PASS_SECRET}/versions/latest
|
||||
env: 'DB_PASSWORD'
|
||||
|
||||
timeout: 1000s
|
||||
|
||||
options:
|
||||
logging: CLOUD_LOGGING_ONLY
|
||||
125
.ci/quickstart_test/run_go_tests.sh
Normal file
125
.ci/quickstart_test/run_go_tests.sh
Normal file
@@ -0,0 +1,125 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
|
||||
TABLE_NAME="hotels_go"
|
||||
QUICKSTART_GO_DIR="docs/en/getting-started/quickstart/go"
|
||||
SQL_FILE=".ci/quickstart_test/setup_hotels_sample.sql"
|
||||
|
||||
PROXY_PID=""
|
||||
TOOLBOX_PID=""
|
||||
|
||||
install_system_packages() {
|
||||
apt-get update && apt-get install -y \
|
||||
postgresql-client \
|
||||
wget \
|
||||
gettext-base \
|
||||
netcat-openbsd
|
||||
}
|
||||
|
||||
start_cloud_sql_proxy() {
|
||||
wget "https://storage.googleapis.com/cloud-sql-connectors/cloud-sql-proxy/v2.10.0/cloud-sql-proxy.linux.amd64" -O /usr/local/bin/cloud-sql-proxy
|
||||
chmod +x /usr/local/bin/cloud-sql-proxy
|
||||
cloud-sql-proxy "${CLOUD_SQL_INSTANCE}" &
|
||||
PROXY_PID=$!
|
||||
|
||||
for i in {1..30}; do
|
||||
if nc -z 127.0.0.1 5432; then
|
||||
echo "Cloud SQL Proxy is up and running."
|
||||
return
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
|
||||
echo "Cloud SQL Proxy failed to start within the timeout period."
|
||||
exit 1
|
||||
}
|
||||
|
||||
setup_toolbox() {
|
||||
TOOLBOX_YAML="/tools.yaml"
|
||||
echo "${TOOLS_YAML_CONTENT}" > "$TOOLBOX_YAML"
|
||||
if [ ! -f "$TOOLBOX_YAML" ]; then echo "Failed to create tools.yaml"; exit 1; fi
|
||||
wget "https://storage.googleapis.com/genai-toolbox/v${VERSION}/linux/amd64/toolbox" -O "/toolbox"
|
||||
chmod +x "/toolbox"
|
||||
/toolbox --tools-file "$TOOLBOX_YAML" &
|
||||
TOOLBOX_PID=$!
|
||||
sleep 2
|
||||
}
|
||||
|
||||
setup_orch_table() {
|
||||
export TABLE_NAME
|
||||
envsubst < "$SQL_FILE" | psql -h "$PGHOST" -p "$PGPORT" -U "$DB_USER" -d "$DATABASE_NAME"
|
||||
}
|
||||
|
||||
run_orch_test() {
|
||||
local orch_dir="$1"
|
||||
local orch_name
|
||||
orch_name=$(basename "$orch_dir")
|
||||
|
||||
if [ "$orch_name" == "openAI" ]; then
|
||||
echo -e "\nSkipping framework '${orch_name}': Temporarily excluded."
|
||||
return
|
||||
fi
|
||||
|
||||
(
|
||||
set -e
|
||||
setup_orch_table
|
||||
|
||||
echo "--- Preparing module for $orch_name ---"
|
||||
cd "$orch_dir"
|
||||
|
||||
if [ -f "go.mod" ]; then
|
||||
go mod tidy
|
||||
fi
|
||||
|
||||
cd ..
|
||||
|
||||
export ORCH_NAME="$orch_name"
|
||||
|
||||
echo "--- Running tests for $orch_name ---"
|
||||
go test -v ./...
|
||||
)
|
||||
}
|
||||
|
||||
cleanup_all() {
|
||||
echo "--- Final cleanup: Shutting down processes and dropping table ---"
|
||||
if [ -n "$TOOLBOX_PID" ]; then
|
||||
kill $TOOLBOX_PID || true
|
||||
fi
|
||||
if [ -n "$PROXY_PID" ]; then
|
||||
kill $PROXY_PID || true
|
||||
fi
|
||||
}
|
||||
trap cleanup_all EXIT
|
||||
|
||||
# Main script execution
|
||||
install_system_packages
|
||||
start_cloud_sql_proxy
|
||||
|
||||
export PGHOST=127.0.0.1
|
||||
export PGPORT=5432
|
||||
export PGPASSWORD="$DB_PASSWORD"
|
||||
export GOOGLE_API_KEY="$GOOGLE_API_KEY"
|
||||
|
||||
setup_toolbox
|
||||
|
||||
for ORCH_DIR in "$QUICKSTART_GO_DIR"/*/; do
|
||||
if [ ! -d "$ORCH_DIR" ]; then
|
||||
continue
|
||||
fi
|
||||
run_orch_test "$ORCH_DIR"
|
||||
done
|
||||
123
.ci/quickstart_test/run_js_tests.sh
Normal file
123
.ci/quickstart_test/run_js_tests.sh
Normal file
@@ -0,0 +1,123 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
|
||||
TABLE_NAME="hotels_js"
|
||||
QUICKSTART_JS_DIR="docs/en/getting-started/quickstart/js"
|
||||
SQL_FILE=".ci/quickstart_test/setup_hotels_sample.sql"
|
||||
|
||||
# Initialize process IDs to empty at the top of the script
|
||||
PROXY_PID=""
|
||||
TOOLBOX_PID=""
|
||||
|
||||
install_system_packages() {
|
||||
apt-get update && apt-get install -y \
|
||||
postgresql-client \
|
||||
wget \
|
||||
gettext-base \
|
||||
netcat-openbsd
|
||||
}
|
||||
|
||||
start_cloud_sql_proxy() {
|
||||
wget "https://storage.googleapis.com/cloud-sql-connectors/cloud-sql-proxy/v2.10.0/cloud-sql-proxy.linux.amd64" -O /usr/local/bin/cloud-sql-proxy
|
||||
chmod +x /usr/local/bin/cloud-sql-proxy
|
||||
cloud-sql-proxy "${CLOUD_SQL_INSTANCE}" &
|
||||
PROXY_PID=$!
|
||||
|
||||
for i in {1..30}; do
|
||||
if nc -z 127.0.0.1 5432; then
|
||||
echo "Cloud SQL Proxy is up and running."
|
||||
return
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
|
||||
echo "Cloud SQL Proxy failed to start within the timeout period."
|
||||
exit 1
|
||||
}
|
||||
|
||||
setup_toolbox() {
|
||||
TOOLBOX_YAML="/tools.yaml"
|
||||
echo "${TOOLS_YAML_CONTENT}" > "$TOOLBOX_YAML"
|
||||
if [ ! -f "$TOOLBOX_YAML" ]; then echo "Failed to create tools.yaml"; exit 1; fi
|
||||
wget "https://storage.googleapis.com/genai-toolbox/v${VERSION}/linux/amd64/toolbox" -O "/toolbox"
|
||||
chmod +x "/toolbox"
|
||||
/toolbox --tools-file "$TOOLBOX_YAML" &
|
||||
TOOLBOX_PID=$!
|
||||
sleep 2
|
||||
}
|
||||
|
||||
setup_orch_table() {
|
||||
export TABLE_NAME
|
||||
envsubst < "$SQL_FILE" | psql -h "$PGHOST" -p "$PGPORT" -U "$DB_USER" -d "$DATABASE_NAME"
|
||||
}
|
||||
|
||||
run_orch_test() {
|
||||
local orch_dir="$1"
|
||||
local orch_name
|
||||
orch_name=$(basename "$orch_dir")
|
||||
|
||||
(
|
||||
set -e
|
||||
echo "--- Preparing environment for $orch_name ---"
|
||||
setup_orch_table
|
||||
|
||||
cd "$orch_dir"
|
||||
if [ -f "package.json" ]; then
|
||||
echo "Installing dependencies for $orch_name..."
|
||||
npm install
|
||||
fi
|
||||
|
||||
cd ..
|
||||
|
||||
echo "--- Running tests for $orch_name ---"
|
||||
export ORCH_NAME="$orch_name"
|
||||
node --test quickstart.test.js
|
||||
|
||||
echo "--- Cleaning environment for $orch_name ---"
|
||||
rm -rf "${orch_name}/node_modules"
|
||||
)
|
||||
}
|
||||
|
||||
cleanup_all() {
|
||||
echo "--- Final cleanup: Shutting down processes and dropping table ---"
|
||||
if [ -n "$TOOLBOX_PID" ]; then
|
||||
kill $TOOLBOX_PID || true
|
||||
fi
|
||||
if [ -n "$PROXY_PID" ]; then
|
||||
kill $PROXY_PID || true
|
||||
fi
|
||||
}
|
||||
trap cleanup_all EXIT
|
||||
|
||||
# Main script execution
|
||||
install_system_packages
|
||||
start_cloud_sql_proxy
|
||||
|
||||
export PGHOST=127.0.0.1
|
||||
export PGPORT=5432
|
||||
export PGPASSWORD="$DB_PASSWORD"
|
||||
export GOOGLE_API_KEY="$GOOGLE_API_KEY"
|
||||
|
||||
setup_toolbox
|
||||
|
||||
for ORCH_DIR in "$QUICKSTART_JS_DIR"/*/; do
|
||||
if [ ! -d "$ORCH_DIR" ]; then
|
||||
continue
|
||||
fi
|
||||
run_orch_test "$ORCH_DIR"
|
||||
done
|
||||
115
.ci/quickstart_test/run_py_tests.sh
Normal file
115
.ci/quickstart_test/run_py_tests.sh
Normal file
@@ -0,0 +1,115 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
|
||||
TABLE_NAME="hotels_python"
|
||||
QUICKSTART_PYTHON_DIR="docs/en/getting-started/quickstart/python"
|
||||
SQL_FILE=".ci/quickstart_test/setup_hotels_sample.sql"
|
||||
|
||||
PROXY_PID=""
|
||||
TOOLBOX_PID=""
|
||||
|
||||
install_system_packages() {
|
||||
apt-get update && apt-get install -y \
|
||||
postgresql-client \
|
||||
python3-venv \
|
||||
wget \
|
||||
gettext-base \
|
||||
netcat-openbsd
|
||||
}
|
||||
|
||||
start_cloud_sql_proxy() {
|
||||
wget "https://storage.googleapis.com/cloud-sql-connectors/cloud-sql-proxy/v2.10.0/cloud-sql-proxy.linux.amd64" -O /usr/local/bin/cloud-sql-proxy
|
||||
chmod +x /usr/local/bin/cloud-sql-proxy
|
||||
cloud-sql-proxy "${CLOUD_SQL_INSTANCE}" &
|
||||
PROXY_PID=$!
|
||||
|
||||
for i in {1..30}; do
|
||||
if nc -z 127.0.0.1 5432; then
|
||||
echo "Cloud SQL Proxy is up and running."
|
||||
return
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
|
||||
echo "Cloud SQL Proxy failed to start within the timeout period."
|
||||
exit 1
|
||||
}
|
||||
|
||||
setup_toolbox() {
|
||||
TOOLBOX_YAML="/tools.yaml"
|
||||
echo "${TOOLS_YAML_CONTENT}" > "$TOOLBOX_YAML"
|
||||
if [ ! -f "$TOOLBOX_YAML" ]; then echo "Failed to create tools.yaml"; exit 1; fi
|
||||
wget "https://storage.googleapis.com/genai-toolbox/v${VERSION}/linux/amd64/toolbox" -O "/toolbox"
|
||||
chmod +x "/toolbox"
|
||||
/toolbox --tools-file "$TOOLBOX_YAML" &
|
||||
TOOLBOX_PID=$!
|
||||
sleep 2
|
||||
}
|
||||
|
||||
setup_orch_table() {
|
||||
export TABLE_NAME
|
||||
envsubst < "$SQL_FILE" | psql -h "$PGHOST" -p "$PGPORT" -U "$DB_USER" -d "$DATABASE_NAME"
|
||||
}
|
||||
|
||||
run_orch_test() {
|
||||
local orch_dir="$1"
|
||||
local orch_name
|
||||
orch_name=$(basename "$orch_dir")
|
||||
(
|
||||
set -e
|
||||
setup_orch_table
|
||||
cd "$orch_dir"
|
||||
local VENV_DIR=".venv"
|
||||
python3 -m venv "$VENV_DIR"
|
||||
source "$VENV_DIR/bin/activate"
|
||||
pip install -r requirements.txt
|
||||
echo "--- Running tests for $orch_name ---"
|
||||
cd ..
|
||||
ORCH_NAME="$orch_name" pytest
|
||||
rm -rf "$VENV_DIR"
|
||||
)
|
||||
}
|
||||
|
||||
cleanup_all() {
|
||||
echo "--- Final cleanup: Shutting down processes and dropping table ---"
|
||||
if [ -n "$TOOLBOX_PID" ]; then
|
||||
kill $TOOLBOX_PID || true
|
||||
fi
|
||||
if [ -n "$PROXY_PID" ]; then
|
||||
kill $PROXY_PID || true
|
||||
fi
|
||||
}
|
||||
trap cleanup_all EXIT
|
||||
|
||||
# Main script execution
|
||||
install_system_packages
|
||||
start_cloud_sql_proxy
|
||||
|
||||
export PGHOST=127.0.0.1
|
||||
export PGPORT=5432
|
||||
export PGPASSWORD="$DB_PASSWORD"
|
||||
export GOOGLE_API_KEY="$GOOGLE_API_KEY"
|
||||
|
||||
setup_toolbox
|
||||
|
||||
for ORCH_DIR in "$QUICKSTART_PYTHON_DIR"/*/; do
|
||||
if [ ! -d "$ORCH_DIR" ]; then
|
||||
continue
|
||||
fi
|
||||
run_orch_test "$ORCH_DIR"
|
||||
done
|
||||
28
.ci/quickstart_test/setup_hotels_sample.sql
Normal file
28
.ci/quickstart_test/setup_hotels_sample.sql
Normal file
@@ -0,0 +1,28 @@
|
||||
-- Copyright 2025 Google LLC
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
TRUNCATE TABLE $TABLE_NAME;
|
||||
|
||||
INSERT INTO $TABLE_NAME (id, name, location, price_tier, checkin_date, checkout_date, booked)
|
||||
VALUES
|
||||
(1, 'Hilton Basel', 'Basel', 'Luxury', '2024-04-22', '2024-04-20', B'0'),
|
||||
(2, 'Marriott Zurich', 'Zurich', 'Upscale', '2024-04-14', '2024-04-21', B'0'),
|
||||
(3, 'Hyatt Regency Basel', 'Basel', 'Upper Upscale', '2024-04-02', '2024-04-20', B'0'),
|
||||
(4, 'Radisson Blu Lucerne', 'Lucerne', 'Midscale', '2024-04-24', '2024-04-05', B'0'),
|
||||
(5, 'Best Western Bern', 'Bern', 'Upper Midscale', '2024-04-23', '2024-04-01', B'0'),
|
||||
(6, 'InterContinental Geneva', 'Geneva', 'Luxury', '2024-04-23', '2024-04-28', B'0'),
|
||||
(7, 'Sheraton Zurich', 'Zurich', 'Upper Upscale', '2024-04-27', '2024-04-02', B'0'),
|
||||
(8, 'Holiday Inn Basel', 'Basel', 'Upper Midscale', '2024-04-24', '2024-04-09', B'0'),
|
||||
(9, 'Courtyard Zurich', 'Zurich', 'Upscale', '2024-04-03', '2024-04-13', B'0'),
|
||||
(10, 'Comfort Inn Bern', 'Bern', 'Midscale', '2024-04-04', '2024-04-16', B'0');
|
||||
@@ -64,6 +64,7 @@ import (
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerysearchcatalog"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerysql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/bigtable"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cassandra/cassandracql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouseexecutesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouselistdatabases"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouselisttables"
|
||||
@@ -159,6 +160,7 @@ import (
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/bigtable"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/cassandra"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/clickhouse"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/cloudmonitoring"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
|
||||
|
||||
@@ -11,6 +11,7 @@ import os
|
||||
# TODO(developer): replace this with your Google API key
|
||||
|
||||
api_key = os.environ.get("GOOGLE_API_KEY") or "your-api-key" # Set your API key here
|
||||
os.environ["GOOGLE_API_KEY"] = api_key
|
||||
|
||||
async def main():
|
||||
with ToolboxSyncClient("http://127.0.0.1:5000") as toolbox_client:
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
google-adk==1.14.1
|
||||
toolbox-core==0.5.0
|
||||
toolbox-core==0.5.2
|
||||
pytest==8.4.2
|
||||
@@ -1,3 +1,3 @@
|
||||
google-genai==1.38.0
|
||||
toolbox-core==0.5.0
|
||||
toolbox-core==0.5.2
|
||||
pytest==8.4.2
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
llama-index==0.14.2
|
||||
llama-index-llms-google-genai==0.5.0
|
||||
llama-index-llms-google-genai==0.5.1
|
||||
toolbox-llamaindex==0.5.2
|
||||
pytest==8.4.2
|
||||
|
||||
57
docs/en/resources/sources/cassandra.md
Normal file
57
docs/en/resources/sources/cassandra.md
Normal file
@@ -0,0 +1,57 @@
|
||||
---
|
||||
title: "Cassandra"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
Cassandra is a NoSQL distributed database known for its horizontal scalability, distributed architecture, and flexible schema definition.
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
[Cassandra][cassandra-docs] is a NoSQL distributed database. By design, NoSQL databases are lightweight, open-source, non-relational, and largely distributed. Counted among their strengths are horizontal scalability, distributed architectures, and a flexible approach to schema definition.
|
||||
|
||||
[cassandra-docs]: https://cassandra.apache.org/
|
||||
|
||||
## Available Tools
|
||||
|
||||
- [`cassandra-cql`](../tools/cassandra/cassandra-cql.md)
|
||||
Run parameterized CQL queries in Cassandra.
|
||||
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
sources:
|
||||
my-cassandra-source:
|
||||
kind: cassandra
|
||||
hosts:
|
||||
- 127.0.0.1
|
||||
keyspace: my_keyspace
|
||||
protoVersion: 4
|
||||
username: ${USER_NAME}
|
||||
password: ${PASSWORD}
|
||||
caPath: /path/to/ca.crt # Optional: path to CA certificate
|
||||
certPath: /path/to/client.crt # Optional: path to client certificate
|
||||
keyPath: /path/to/client.key # Optional: path to client key
|
||||
enableHostVerification: true # Optional: enable host verification
|
||||
```
|
||||
|
||||
{{< notice tip >}}
|
||||
Use environment variable replacement with the format ${ENV_NAME}
|
||||
instead of hardcoding your secrets into the configuration file.
|
||||
{{< /notice >}}
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|------------------------|:---------:|:------------:|-------------------------------------------------------------------------------------------------------|
|
||||
| kind | string | true | Must be "cassandra". |
|
||||
| hosts | string[] | true | List of IP addresses to connect to (e.g., ["192.168.1.1:9042", "192.168.1.2:9042","192.168.1.3:9042"]). The default port is 9042 if not specified. |
|
||||
| keyspace | string | true | Name of the Cassandra keyspace to connect to (e.g., "my_keyspace"). |
|
||||
| protoVersion | integer | false | Protocol version for the Cassandra connection (e.g., 4). |
|
||||
| username | string | false | Name of the Cassandra user to connect as (e.g., "my-cassandra-user"). |
|
||||
| password | string | false | Password of the Cassandra user (e.g., "my-password"). |
|
||||
| caPath | string | false | Path to the CA certificate for SSL/TLS (e.g., "/path/to/ca.crt"). |
|
||||
| certPath | string | false | Path to the client certificate for SSL/TLS (e.g., "/path/to/client.crt"). |
|
||||
| keyPath | string | false | Path to the client key for SSL/TLS (e.g., "/path/to/client.key"). |
|
||||
| enableHostVerification | boolean | false | Enable host verification for SSL/TLS (e.g., true). By default, host verification is disabled. |
|
||||
@@ -153,6 +153,32 @@ will be thrown in case of value type mismatch.
|
||||
valueType: integer # This enforces the value type for all entries.
|
||||
```
|
||||
|
||||
### Enum Parameters
|
||||
|
||||
The `enum` type allow users to specify a set of allowed values with that
|
||||
parameter. When toolbox parse the input of parameters, it will check against the
|
||||
allowed values.
|
||||
|
||||
```yaml
|
||||
parameter:
|
||||
- name: airline
|
||||
type: enum
|
||||
description: name of airline.
|
||||
enumType: string
|
||||
allowedValues:
|
||||
- cymbalair
|
||||
- delta
|
||||
```
|
||||
|
||||
Other than the regular fields required with the `enumType` specified, below are
|
||||
the additional fields that are needed when using `enum` type.
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|---------------|:--------:|:------------:|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| enumType | string | true | This indicates the type of the value. Must be one of the supported parameter type (e.g. `string`/ `integer` / `float` / `boolean` / `array` / `map`). |
|
||||
| escape | bool | false | Indicates if the value will be escaped if used with `templateParameters`. Escaping will add double quotes (or backticks/square brackets depending on the source) depending on the database. This is defaulted to `false`. |
|
||||
| allowedValues | []string | true | Input value will be checked against this field. |
|
||||
|
||||
### Authenticated Parameters
|
||||
|
||||
Authenticated parameters are automatically populated with user
|
||||
|
||||
7
docs/en/resources/tools/cassandra/_index.md
Normal file
7
docs/en/resources/tools/cassandra/_index.md
Normal file
@@ -0,0 +1,7 @@
|
||||
---
|
||||
title: "Cassandra"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
Tools that work with Cassandra Sources.
|
||||
---
|
||||
96
docs/en/resources/tools/cassandra/cassandra-cql.md
Normal file
96
docs/en/resources/tools/cassandra/cassandra-cql.md
Normal file
@@ -0,0 +1,96 @@
|
||||
---
|
||||
title: "cassandra-cql"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
A "cassandra-cql" tool executes a pre-defined CQL statement against a Cassandra
|
||||
database.
|
||||
aliases:
|
||||
- /resources/tools/cassandra-cql
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
A `cassandra-cql` tool executes a pre-defined CQL statement against a Cassandra
|
||||
database. It's compatible with any of the following sources:
|
||||
|
||||
- [cassandra](../sources/cassandra.md)
|
||||
|
||||
The specified CQL statement is executed as a [prepared statement][cassandra-prepare],
|
||||
and expects parameters in the CQL query to be in the form of placeholders `?`.
|
||||
|
||||
[cassandra-prepare]: https://docs.datastax.com/en/developer/go-driver/4.8/cql-prepared-statements/
|
||||
|
||||
## Example
|
||||
|
||||
> **Note:** This tool uses parameterized queries to prevent CQL injections.
|
||||
> Query parameters can be used as substitutes for arbitrary expressions.
|
||||
> Parameters cannot be used as substitutes for keyspaces, table names, column names,
|
||||
> or other parts of the query.
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
search_users_by_email:
|
||||
kind: cassandra-cql
|
||||
source: my-cassandra-cluster
|
||||
statement: |
|
||||
SELECT user_id, email, first_name, last_name, created_at
|
||||
FROM users
|
||||
WHERE email = ?
|
||||
description: |
|
||||
Use this tool to retrieve specific user information by their email address.
|
||||
Takes an email address and returns user details including user ID, email,
|
||||
first name, last name, and account creation timestamp.
|
||||
Do NOT use this tool with a user ID or other identifiers.
|
||||
Example:
|
||||
{{
|
||||
"email": "user@example.com",
|
||||
}}
|
||||
parameters:
|
||||
- name: email
|
||||
type: string
|
||||
description: User's email address
|
||||
```
|
||||
|
||||
### Example with Template Parameters
|
||||
|
||||
> **Note:** This tool allows direct modifications to the CQL statement,
|
||||
> including keyspaces, table names, and column names. **This makes it more
|
||||
> vulnerable to CQL injections**. Using basic parameters only (see above) is
|
||||
> recommended for performance and safety reasons. For more details, please check
|
||||
> [templateParameters](../#template-parameters).
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
list_keyspace_table:
|
||||
kind: cassandra-cql
|
||||
source: my-cassandra-cluster
|
||||
statement: |
|
||||
SELECT * FROM {{.keyspace}}.{{.tableName}};
|
||||
description: |
|
||||
Use this tool to list all information from a specific table in a keyspace.
|
||||
Example:
|
||||
{{
|
||||
"keyspace": "my_keyspace",
|
||||
"tableName": "users",
|
||||
}}
|
||||
templateParameters:
|
||||
- name: keyspace
|
||||
type: string
|
||||
description: Keyspace containing the table
|
||||
- name: tableName
|
||||
type: string
|
||||
description: Table to select from
|
||||
```
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|--------------------|:------------------------------------------------:|:------------:|--------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| kind | string | true | Must be "cassandra-cql". |
|
||||
| source | string | true | Name of the source the CQL should execute on. |
|
||||
| description | string | true | Description of the tool that is passed to the LLM. |
|
||||
| statement | string | true | CQL statement to execute. |
|
||||
| authRequired | []string | false | List of authentication requirements for the source. |
|
||||
| parameters | [parameters](../#specifying-parameters) | false | List of [parameters](../#specifying-parameters) that will be inserted into the CQL statement. |
|
||||
| templateParameters | [templateParameters](../#template-parameters) | false | List of [templateParameters](../#template-parameters) that will be inserted into the CQL statement before executing prepared statement. |
|
||||
3
go.mod
3
go.mod
@@ -26,6 +26,7 @@ require (
|
||||
github.com/go-playground/validator/v10 v10.27.0
|
||||
github.com/go-sql-driver/mysql v1.9.3
|
||||
github.com/goccy/go-yaml v1.18.0
|
||||
github.com/gocql/gocql v1.7.0
|
||||
github.com/google/go-cmp v0.7.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/jackc/pgx/v5 v5.7.6
|
||||
@@ -115,6 +116,7 @@ require (
|
||||
github.com/gorilla/websocket v1.5.3 // indirect
|
||||
github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.1 // indirect
|
||||
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed // indirect
|
||||
github.com/hashicorp/go-uuid v1.0.3 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
@@ -176,6 +178,7 @@ require (
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250818200422-3122310a409c // indirect
|
||||
google.golang.org/grpc v1.75.0 // indirect
|
||||
google.golang.org/protobuf v1.36.8 // indirect
|
||||
gopkg.in/inf.v0 v0.9.1 // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
modernc.org/libc v1.66.3 // indirect
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
|
||||
10
go.sum
10
go.sum
@@ -737,6 +737,10 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.38.4/go.mod h1:Z+Gd23v97pX9zK97+tX4p
|
||||
github.com/aws/smithy-go v1.23.0 h1:8n6I3gXzWJB2DxBDnfxgBaSX6oe0d/t10qGz7OKqMCE=
|
||||
github.com/aws/smithy-go v1.23.0/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI=
|
||||
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
|
||||
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 h1:mXoPYz/Ul5HYEDvkta6I8/rnYM5gSdSV2tJ6XbZuEtY=
|
||||
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCSz6Q9T7+igc/hlvDOUdtWKryOrtFyIVABv/p7k=
|
||||
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY=
|
||||
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4=
|
||||
github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
|
||||
github.com/boombuler/barcode v1.0.1/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
||||
@@ -897,6 +901,8 @@ github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
|
||||
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||
github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw=
|
||||
github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
|
||||
github.com/gocql/gocql v1.7.0 h1:O+7U7/1gSN7QTEAaMEsJc1Oq2QHXvCWoF3DFK9HDHus=
|
||||
github.com/gocql/gocql v1.7.0/go.mod h1:vnlvXyFZeLBF0Wy+RS8hrOdbn0UWsWtdg07XJnFxZ+4=
|
||||
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||
@@ -1047,6 +1053,8 @@ github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.1/go.mod h1:Zanoh4+gvIgluNqcfMVT
|
||||
github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
|
||||
github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8=
|
||||
github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
|
||||
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed h1:5upAirOpQc1Q53c0bnx2ufif5kANL7bfZWcc6VJWJd8=
|
||||
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed/go.mod h1:tMWxXQ9wFIaZeTI9F+hmhFiGpFmhOHzyShyFUhRm0H4=
|
||||
github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
|
||||
github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
|
||||
github.com/iancoleman/strcase v0.2.0/go.mod h1:iwCmte+B7n89clKwxIoIXy/HfoL7AsD47ZCWhYzw7ho=
|
||||
@@ -2042,6 +2050,8 @@ gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
|
||||
gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc=
|
||||
gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw=
|
||||
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
|
||||
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
|
||||
134
internal/sources/cassandra/cassandra.go
Normal file
134
internal/sources/cassandra/cassandra.go
Normal file
@@ -0,0 +1,134 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cassandra
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/gocql/gocql"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceKind string = "cassandra"
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources.SourceConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Hosts []string `yaml:"hosts" validate:"required"`
|
||||
Keyspace string `yaml:"keyspace"`
|
||||
ProtoVersion int `yaml:"protoVersion"`
|
||||
Username string `yaml:"username"`
|
||||
Password string `yaml:"password"`
|
||||
CAPath string `yaml:"caPath"`
|
||||
CertPath string `yaml:"certPath"`
|
||||
KeyPath string `yaml:"keyPath"`
|
||||
EnableHostVerification bool `yaml:"enableHostVerification"`
|
||||
}
|
||||
|
||||
// Initialize implements sources.SourceConfig.
|
||||
func (c Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
session, err := initCassandraSession(ctx, tracer, c)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create session: %v", err)
|
||||
}
|
||||
s := &Source{
|
||||
Name: c.Name,
|
||||
Kind: SourceKind,
|
||||
Session: session,
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// SourceConfigKind implements sources.SourceConfig.
|
||||
func (c Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
type Source struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Session *gocql.Session
|
||||
}
|
||||
|
||||
// CassandraSession implements cassandra.compatibleSource.
|
||||
func (s *Source) CassandraSession() *gocql.Session {
|
||||
return s.Session
|
||||
}
|
||||
|
||||
// SourceKind implements sources.Source.
|
||||
func (s Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
var _ sources.Source = &Source{}
|
||||
|
||||
func initCassandraSession(ctx context.Context, tracer trace.Tracer, c Config) (*gocql.Session, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, c.Name)
|
||||
defer span.End()
|
||||
|
||||
// Validate authentication configuration
|
||||
if c.Password != "" && c.Username == "" {
|
||||
return nil, fmt.Errorf("invalid Cassandra configuration: password provided without a username")
|
||||
}
|
||||
|
||||
cluster := gocql.NewCluster(c.Hosts...)
|
||||
cluster.ProtoVersion = c.ProtoVersion
|
||||
cluster.Keyspace = c.Keyspace
|
||||
|
||||
// Configure authentication if username is provided
|
||||
if c.Username != "" {
|
||||
cluster.Authenticator = gocql.PasswordAuthenticator{
|
||||
Username: c.Username,
|
||||
Password: c.Password,
|
||||
}
|
||||
}
|
||||
|
||||
// Configure SSL options if any are specified
|
||||
if c.CAPath != "" || c.CertPath != "" || c.KeyPath != "" || c.EnableHostVerification {
|
||||
cluster.SslOpts = &gocql.SslOptions{
|
||||
CaPath: c.CAPath,
|
||||
CertPath: c.CertPath,
|
||||
KeyPath: c.KeyPath,
|
||||
EnableHostVerification: c.EnableHostVerification,
|
||||
}
|
||||
}
|
||||
|
||||
// Create session
|
||||
session, err := cluster.CreateSession()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create Cassandra session: %w", err)
|
||||
}
|
||||
return session, nil
|
||||
}
|
||||
158
internal/sources/cassandra/cassandra_test.go
Normal file
158
internal/sources/cassandra/cassandra_test.go
Normal file
@@ -0,0 +1,158 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cassandra_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cassandra"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
|
||||
func TestParseFromYamlCassandra(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.SourceConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example (without optional fields)",
|
||||
in: `
|
||||
sources:
|
||||
my-cassandra-instance:
|
||||
kind: cassandra
|
||||
hosts:
|
||||
- "my-host1"
|
||||
- "my-host2"
|
||||
`,
|
||||
want: server.SourceConfigs{
|
||||
"my-cassandra-instance": cassandra.Config{
|
||||
Name: "my-cassandra-instance",
|
||||
Kind: cassandra.SourceKind,
|
||||
Hosts: []string{"my-host1", "my-host2"},
|
||||
Username: "",
|
||||
Password: "",
|
||||
ProtoVersion: 0,
|
||||
CAPath: "",
|
||||
CertPath: "",
|
||||
KeyPath: "",
|
||||
Keyspace: "",
|
||||
EnableHostVerification: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "with optional fields",
|
||||
in: `
|
||||
sources:
|
||||
my-cassandra-instance:
|
||||
kind: cassandra
|
||||
hosts:
|
||||
- "my-host1"
|
||||
- "my-host2"
|
||||
username: "user"
|
||||
password: "pass"
|
||||
keyspace: "example_keyspace"
|
||||
protoVersion: 4
|
||||
caPath: "path/to/ca.crt"
|
||||
certPath: "path/to/cert"
|
||||
keyPath: "path/to/key"
|
||||
enableHostVerification: true
|
||||
`,
|
||||
want: server.SourceConfigs{
|
||||
"my-cassandra-instance": cassandra.Config{
|
||||
Name: "my-cassandra-instance",
|
||||
Kind: cassandra.SourceKind,
|
||||
Hosts: []string{"my-host1", "my-host2"},
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
Keyspace: "example_keyspace",
|
||||
ProtoVersion: 4,
|
||||
CAPath: "path/to/ca.crt",
|
||||
CertPath: "path/to/cert",
|
||||
KeyPath: "path/to/key",
|
||||
EnableHostVerification: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestFailParseFromYaml(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
err string
|
||||
}{
|
||||
{
|
||||
desc: "extra field",
|
||||
in: `
|
||||
sources:
|
||||
my-cassandra-instance:
|
||||
kind: cassandra
|
||||
host:
|
||||
- "my-host"
|
||||
foo: bar
|
||||
`,
|
||||
err: "unable to parse source \"my-cassandra-instance\" as \"cassandra\": [1:1] unknown field \"foo\"\n> 1 | foo: bar\n ^\n 2 | host:\n 3 | - my-host\n 4 | kind: cassandra",
|
||||
},
|
||||
{
|
||||
desc: "missing required field",
|
||||
in: `
|
||||
sources:
|
||||
my-cassandra-instance:
|
||||
kind: cassandra
|
||||
`,
|
||||
err: "unable to parse source \"my-cassandra-instance\" as \"cassandra\": Key: 'Config.Hosts' Error:Field validation for 'Hosts' failed on the 'required' tag",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
errStr := err.Error()
|
||||
if errStr != tc.err {
|
||||
t.Fatalf("unexpected error: got %q, want %q", errStr, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
182
internal/tools/cassandra/cassandracql/cassandracql.go
Normal file
182
internal/tools/cassandra/cassandracql/cassandracql.go
Normal file
@@ -0,0 +1,182 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cassandracql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/gocql/gocql"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cassandra"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
)
|
||||
|
||||
const kind string = "cassandra-cql"
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
CassandraSession() *gocql.Session
|
||||
}
|
||||
|
||||
var _ compatibleSource = &cassandra.Source{}
|
||||
|
||||
var compatibleSources = [...]string{cassandra.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
Statement string `yaml:"statement" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
TemplateParameters tools.Parameters `yaml:"templateParameters"`
|
||||
}
|
||||
|
||||
// Initialize implements tools.ToolConfig.
|
||||
func (c Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[c.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", c.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
allParameters, paramManifest, paramMcpManifest, err := tools.ProcessParameters(c.TemplateParameters, c.Parameters)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: c.Name,
|
||||
Description: c.Description,
|
||||
InputSchema: paramMcpManifest,
|
||||
}
|
||||
|
||||
t := Tool{
|
||||
Name: c.Name,
|
||||
Kind: kind,
|
||||
Parameters: c.Parameters,
|
||||
TemplateParameters: c.TemplateParameters,
|
||||
AllParams: allParameters,
|
||||
Statement: c.Statement,
|
||||
AuthRequired: c.AuthRequired,
|
||||
Session: s.CassandraSession(),
|
||||
manifest: tools.Manifest{Description: c.Description, Parameters: paramManifest, AuthRequired: c.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// ToolConfigKind implements tools.ToolConfig.
|
||||
func (c Config) ToolConfigKind() string {
|
||||
return kind
|
||||
}
|
||||
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
TemplateParameters tools.Parameters `yaml:"templateParameters"`
|
||||
AllParams tools.Parameters `yaml:"allParams"`
|
||||
|
||||
Session *gocql.Session
|
||||
Statement string
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
// RequiresClientAuthorization implements tools.Tool.
|
||||
func (t Tool) RequiresClientAuthorization() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Authorized implements tools.Tool.
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
// Invoke implements tools.Tool.
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
paramsMap := params.AsMap()
|
||||
newStatement, err := tools.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to extract template params %w", err)
|
||||
}
|
||||
|
||||
newParams, err := tools.GetParams(t.Parameters, paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to extract standard params %w", err)
|
||||
}
|
||||
sliceParams := newParams.AsSlice()
|
||||
iter := t.Session.Query(newStatement, sliceParams...).WithContext(ctx).Iter()
|
||||
|
||||
// Create a slice to store the out
|
||||
var out []map[string]interface{}
|
||||
|
||||
// Scan results into a map and append to the slice
|
||||
for {
|
||||
row := make(map[string]interface{}) // Create a new map for each row
|
||||
if !iter.MapScan(row) {
|
||||
break // No more rows
|
||||
}
|
||||
out = append(out, row)
|
||||
}
|
||||
|
||||
if err := iter.Close(); err != nil {
|
||||
return nil, fmt.Errorf("unable to parse rows: %w", err)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// Manifest implements tools.Tool.
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
// McpManifest implements tools.Tool.
|
||||
func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
// ParseParams implements tools.Tool.
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
return tools.ParseParams(t.AllParams, data, claims)
|
||||
}
|
||||
|
||||
var _ tools.Tool = Tool{}
|
||||
171
internal/tools/cassandra/cassandracql/cassandracql_test.go
Normal file
171
internal/tools/cassandra/cassandracql/cassandracql_test.go
Normal file
@@ -0,0 +1,171 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cassandracql_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/cassandra/cassandracql"
|
||||
)
|
||||
|
||||
func TestParseFromYamlCassandra(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: cassandra-cql
|
||||
source: my-cassandra-instance
|
||||
description: some description
|
||||
statement: |
|
||||
SELECT * FROM CQL_STATEMENT;
|
||||
authRequired:
|
||||
- my-google-auth-service
|
||||
- other-auth-service
|
||||
parameters:
|
||||
- name: country
|
||||
type: string
|
||||
description: some description
|
||||
authServices:
|
||||
- name: my-google-auth-service
|
||||
field: user_id
|
||||
- name: other-auth-service
|
||||
field: user_id
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": cassandracql.Config{
|
||||
Name: "example_tool",
|
||||
Kind: "cassandra-cql",
|
||||
Source: "my-cassandra-instance",
|
||||
Description: "some description",
|
||||
Statement: "SELECT * FROM CQL_STATEMENT;\n",
|
||||
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
|
||||
Parameters: []tools.Parameter{
|
||||
tools.NewStringParameterWithAuth("country", "some description",
|
||||
[]tools.ParamAuthService{{Name: "my-google-auth-service", Field: "user_id"},
|
||||
{Name: "other-auth-service", Field: "user_id"}}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "with template parameters",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: cassandra-cql
|
||||
source: my-cassandra-instance
|
||||
description: some description
|
||||
statement: |
|
||||
SELECT * FROM CQL_STATEMENT;
|
||||
authRequired:
|
||||
- my-google-auth-service
|
||||
- other-auth-service
|
||||
parameters:
|
||||
- name: country
|
||||
type: string
|
||||
description: some description
|
||||
authServices:
|
||||
- name: my-google-auth-service
|
||||
field: user_id
|
||||
- name: other-auth-service
|
||||
field: user_id
|
||||
templateParameters:
|
||||
- name: tableName
|
||||
type: string
|
||||
description: some description.
|
||||
- name: fieldArray
|
||||
type: array
|
||||
description: The columns to return for the query.
|
||||
items:
|
||||
name: column
|
||||
type: string
|
||||
description: A column name that will be returned from the query.
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": cassandracql.Config{
|
||||
Name: "example_tool",
|
||||
Kind: "cassandra-cql",
|
||||
Source: "my-cassandra-instance",
|
||||
Description: "some description",
|
||||
Statement: "SELECT * FROM CQL_STATEMENT;\n",
|
||||
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
|
||||
Parameters: []tools.Parameter{
|
||||
tools.NewStringParameterWithAuth("country", "some description",
|
||||
[]tools.ParamAuthService{{Name: "my-google-auth-service", Field: "user_id"},
|
||||
{Name: "other-auth-service", Field: "user_id"}}),
|
||||
},
|
||||
TemplateParameters: []tools.Parameter{
|
||||
tools.NewStringParameter("tableName", "some description."),
|
||||
tools.NewArrayParameter("fieldArray", "The columns to return for the query.", tools.NewStringParameter("column", "A column name that will be returned from the query.")),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "without optional fields",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: cassandra-cql
|
||||
source: my-cassandra-instance
|
||||
description: some description
|
||||
statement: |
|
||||
SELECT * FROM CQL_STATEMENT;
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": cassandracql.Config{
|
||||
Name: "example_tool",
|
||||
Kind: "cassandra-cql",
|
||||
Source: "my-cassandra-instance",
|
||||
Description: "some description",
|
||||
Statement: "SELECT * FROM CQL_STATEMENT;\n",
|
||||
AuthRequired: []string{},
|
||||
Parameters: nil,
|
||||
TemplateParameters: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -19,10 +19,12 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
)
|
||||
|
||||
@@ -33,6 +35,7 @@ const (
|
||||
typeBool = "boolean"
|
||||
typeArray = "array"
|
||||
typeMap = "map"
|
||||
typeEnum = "enum"
|
||||
)
|
||||
|
||||
// ParamValues is an ordered list of ParamValue
|
||||
@@ -179,6 +182,11 @@ func GetParams(params Parameters, paramValuesMap map[string]any) (ParamValues, e
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("missing parameter %s", k)
|
||||
}
|
||||
if p.GetType() == typeEnum {
|
||||
if p.(*EnumParameter).GetEscape() {
|
||||
v = fmt.Sprintf(`"%s"`, v)
|
||||
}
|
||||
}
|
||||
resultParamValues = append(resultParamValues, ParamValue{Name: k, Value: v})
|
||||
}
|
||||
return resultParamValues, nil
|
||||
@@ -233,6 +241,7 @@ type Parameter interface {
|
||||
// but this is done to differentiate it from the fields in CommonParameter.
|
||||
GetName() string
|
||||
GetType() string
|
||||
GetDesc() string
|
||||
GetDefault() any
|
||||
GetRequired() bool
|
||||
GetAuthServices() []ParamAuthService
|
||||
@@ -278,7 +287,7 @@ func parseParamFromDelayedUnmarshaler(ctx context.Context, u *util.DelayedUnmars
|
||||
|
||||
t, ok := p["type"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("parameter is missing 'type' field: %w", err)
|
||||
return nil, fmt.Errorf("parameter is missing 'type' field")
|
||||
}
|
||||
|
||||
dec, err := util.NewStrictDecoder(p)
|
||||
@@ -356,6 +365,17 @@ func parseParamFromDelayedUnmarshaler(ctx context.Context, u *util.DelayedUnmars
|
||||
a.AuthSources = nil
|
||||
}
|
||||
return a, nil
|
||||
case typeEnum:
|
||||
a := &EnumParameter{}
|
||||
if err := dec.DecodeContext(ctx, a); err != nil {
|
||||
return nil, fmt.Errorf("unable to parse as %q: %w", t, err)
|
||||
}
|
||||
if a.AuthSources != nil {
|
||||
logger.WarnContext(ctx, "`authSources` is deprecated, use `authServices` for parameters instead")
|
||||
a.AuthServices = append(a.AuthServices, a.AuthSources...)
|
||||
a.AuthSources = nil
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
return nil, fmt.Errorf("%q is not valid type for a parameter", t)
|
||||
}
|
||||
@@ -427,6 +447,11 @@ func (p *CommonParameter) GetType() string {
|
||||
return p.Type
|
||||
}
|
||||
|
||||
// GetDesc returns the description specified for the Parameter.
|
||||
func (p *CommonParameter) GetDesc() string {
|
||||
return p.Desc
|
||||
}
|
||||
|
||||
// GetRequired returns the type specified for the Parameter.
|
||||
func (p *CommonParameter) GetRequired() bool {
|
||||
// parameters are defaulted to required
|
||||
@@ -1230,3 +1255,155 @@ func (p *MapParameter) McpManifest() ParameterMcpManifest {
|
||||
AdditionalProperties: additionalProperties,
|
||||
}
|
||||
}
|
||||
|
||||
// NewEnumParameter is a convenience function for initializing a EnumParameter.
|
||||
func NewEnumParameter(param Parameter, escape bool, allowedValues []any) *EnumParameter {
|
||||
d := param.GetDefault()
|
||||
r := param.GetRequired()
|
||||
return &EnumParameter{
|
||||
CommonParameter: CommonParameter{
|
||||
Name: param.GetName(),
|
||||
Type: typeEnum,
|
||||
Desc: param.GetDesc(),
|
||||
Required: &r,
|
||||
AuthServices: param.GetAuthServices(),
|
||||
},
|
||||
EnumType: param.GetType(),
|
||||
Escape: escape,
|
||||
AllowedValues: allowedValues,
|
||||
EnumItem: param,
|
||||
Default: &d,
|
||||
}
|
||||
}
|
||||
|
||||
// EnumParameter is a parameter that allow users to specify
|
||||
// allowedValues to provide a fixed set of values. This will
|
||||
// make parameter, especially templateParameter more secure and safe.
|
||||
type EnumParameter struct {
|
||||
CommonParameter `yaml:",inline"`
|
||||
Default *any `yaml:"default"`
|
||||
EnumType string `yaml:"enumType"`
|
||||
Escape bool `yaml:"escape"`
|
||||
AllowedValues []any `yaml:"allowedValues"`
|
||||
EnumItem Parameter
|
||||
}
|
||||
|
||||
// Ensure EnumParameter implements the Parameter interface.
|
||||
var _ Parameter = &EnumParameter{}
|
||||
|
||||
// UnmarshalYAML handles parsing the EnumParameter from YAML input.
|
||||
func (p *EnumParameter) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
|
||||
var rawItem map[string]any
|
||||
if err := unmarshal(&rawItem); err != nil {
|
||||
return fmt.Errorf("error parsing enum parameter: %w", err)
|
||||
}
|
||||
|
||||
// extract enum parameter known fields
|
||||
enumType, ok := rawItem["enumType"].(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("error parsing 'enumType' field")
|
||||
}
|
||||
escape := false
|
||||
if v, ok := rawItem["escape"]; ok {
|
||||
if escape, ok = v.(bool); !ok {
|
||||
return fmt.Errorf("error parsing 'escape' field")
|
||||
}
|
||||
}
|
||||
allowedValues, ok := rawItem["allowedValues"].([]any)
|
||||
if !ok {
|
||||
return fmt.Errorf("error parsing 'allowedValues' field")
|
||||
}
|
||||
rawItem["type"] = enumType
|
||||
|
||||
// remove the extracted field from the map
|
||||
delete(rawItem, "enumType")
|
||||
delete(rawItem, "escape")
|
||||
delete(rawItem, "allowedValues")
|
||||
|
||||
// create a util.DelayedUnmarshaler from the remaining fields
|
||||
m, err := yaml.Marshal(rawItem)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error marshaling remaining fields from enum parameter")
|
||||
}
|
||||
var delayedUnmarshaler util.DelayedUnmarshaler
|
||||
if err = yaml.UnmarshalContext(ctx, m, &delayedUnmarshaler); err != nil {
|
||||
return fmt.Errorf("error unmarhaling into DelayedUnmarshaler")
|
||||
}
|
||||
parameter, err := parseParamFromDelayedUnmarshaler(ctx, &delayedUnmarshaler)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
d := parameter.GetDefault()
|
||||
r := parameter.GetRequired()
|
||||
|
||||
p.Default = &d
|
||||
p.CommonParameter = CommonParameter{
|
||||
Name: parameter.GetName(),
|
||||
Type: "enum",
|
||||
Desc: parameter.GetDesc(),
|
||||
Required: &r,
|
||||
AuthServices: parameter.GetAuthServices(),
|
||||
}
|
||||
p.EnumType = enumType
|
||||
p.Escape = escape
|
||||
p.AllowedValues = allowedValues
|
||||
p.EnumItem = parameter
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse validates and parses an incoming value for enum parameter.
|
||||
func (p *EnumParameter) Parse(v any) (any, error) {
|
||||
input := fmt.Sprintf("%v", v)
|
||||
var exists bool
|
||||
for _, av := range p.AllowedValues {
|
||||
target := fmt.Sprintf("%v", av)
|
||||
if MatchStringOrRegex(input, target) {
|
||||
exists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("unable to parse enum parameter: input is not part of allowed values")
|
||||
}
|
||||
|
||||
return p.EnumItem.Parse(v)
|
||||
}
|
||||
|
||||
// MatchStringOrRegex checks if the input matches the target
|
||||
func MatchStringOrRegex(input, target string) bool {
|
||||
re, err := regexp.Compile(target)
|
||||
if err != nil {
|
||||
return strings.Contains(input, target)
|
||||
}
|
||||
return re.MatchString(input)
|
||||
}
|
||||
|
||||
func (p *EnumParameter) GetAuthServices() []ParamAuthService {
|
||||
return p.AuthServices
|
||||
}
|
||||
|
||||
func (p *EnumParameter) GetDefault() any {
|
||||
if p.Default == nil {
|
||||
return nil
|
||||
}
|
||||
return *p.Default
|
||||
}
|
||||
|
||||
func (p *EnumParameter) GetEnumType() string {
|
||||
return p.EnumType
|
||||
}
|
||||
|
||||
func (p *EnumParameter) GetEscape() bool {
|
||||
return p.Escape
|
||||
}
|
||||
|
||||
// Manifest returns the manifest for the EnumParameter.
|
||||
func (p *EnumParameter) Manifest() ParameterManifest {
|
||||
return p.EnumItem.Manifest()
|
||||
}
|
||||
|
||||
// McpManifest returns the MCP manifest for EnumParameter.
|
||||
func (p *EnumParameter) McpManifest() ParameterMcpManifest {
|
||||
return p.EnumItem.McpManifest()
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ package tools_test
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -351,6 +352,37 @@ func TestParametersMarshal(t *testing.T) {
|
||||
tools.NewMapParameter("my_generic_map", "this param is a generic map", ""),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "enum string",
|
||||
in: []map[string]any{
|
||||
{
|
||||
"name": "enum_string",
|
||||
"type": "enum",
|
||||
"enumType": "string",
|
||||
"description": "enum string parameter",
|
||||
"allowedValues": []any{"foo", "bar"},
|
||||
},
|
||||
},
|
||||
want: tools.Parameters{
|
||||
tools.NewEnumParameter(tools.NewStringParameter("enum_string", "enum string parameter"), false, []any{"foo", "bar"}),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "enum string with escape",
|
||||
in: []map[string]any{
|
||||
{
|
||||
"name": "enum_string",
|
||||
"type": "enum",
|
||||
"enumType": "string",
|
||||
"description": "enum string parameter",
|
||||
"allowedValues": []any{"foo", "bar"},
|
||||
"escape": true,
|
||||
},
|
||||
},
|
||||
want: tools.Parameters{
|
||||
tools.NewEnumParameter(tools.NewStringParameter("enum_string", "enum string parameter"), true, []any{"foo", "bar"}),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
@@ -669,6 +701,31 @@ func TestAuthParametersMarshal(t *testing.T) {
|
||||
tools.NewMapParameterWithAuth("my_map", "this param is a map of strings", "string", authServices),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "enum",
|
||||
in: []map[string]any{
|
||||
{
|
||||
"name": "enum_string",
|
||||
"type": "enum",
|
||||
"description": "enum of strings",
|
||||
"enumType": "string",
|
||||
"allowedValues": []any{"foo", "bar"},
|
||||
"authServices": []map[string]string{
|
||||
{
|
||||
"name": "my-google-auth-service",
|
||||
"field": "user_id",
|
||||
},
|
||||
{
|
||||
"name": "other-auth-service",
|
||||
"field": "user_id",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: tools.Parameters{
|
||||
tools.NewEnumParameter(tools.NewStringParameterWithAuth("enum_string", "enum of strings", authServices), false, []any{"foo", "bar"}),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
@@ -784,6 +841,35 @@ func TestParametersParse(t *testing.T) {
|
||||
"my_bool": 1.5,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "enum",
|
||||
params: tools.Parameters{
|
||||
tools.NewEnumParameter(tools.NewStringParameter("enum_string", "enum of strings"), false, []any{"foo", "bar"}),
|
||||
},
|
||||
in: map[string]any{
|
||||
"enum_string": "foo",
|
||||
},
|
||||
want: tools.ParamValues{tools.ParamValue{Name: "enum_string", Value: "foo"}},
|
||||
},
|
||||
{
|
||||
name: "enum not allowed",
|
||||
params: tools.Parameters{
|
||||
tools.NewEnumParameter(tools.NewStringParameter("enum_string", "enum of strings"), false, []any{"foo", "bar"}),
|
||||
},
|
||||
in: map[string]any{
|
||||
"enum_string": "invalid",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "enum with integer",
|
||||
params: tools.Parameters{
|
||||
tools.NewEnumParameter(tools.NewIntParameter("enum_int", "enum of int"), false, []any{"^[1-5]$"}),
|
||||
},
|
||||
in: map[string]any{
|
||||
"enum_int": 4,
|
||||
},
|
||||
want: tools.ParamValues{tools.ParamValue{Name: "enum_int", Value: 4}},
|
||||
},
|
||||
{
|
||||
name: "string default",
|
||||
params: tools.Parameters{
|
||||
@@ -824,6 +910,14 @@ func TestParametersParse(t *testing.T) {
|
||||
in: map[string]any{},
|
||||
want: tools.ParamValues{tools.ParamValue{Name: "my_bool", Value: true}},
|
||||
},
|
||||
{
|
||||
name: "enum default",
|
||||
params: tools.Parameters{
|
||||
tools.NewEnumParameter(tools.NewStringParameterWithDefault("enum_string", "foo", "enum of strings"), false, []any{"foo", "bar"}),
|
||||
},
|
||||
in: map[string]any{},
|
||||
want: tools.ParamValues{tools.ParamValue{Name: "enum_string", Value: "foo"}},
|
||||
},
|
||||
{
|
||||
name: "string not required",
|
||||
params: tools.Parameters{
|
||||
@@ -856,6 +950,14 @@ func TestParametersParse(t *testing.T) {
|
||||
in: map[string]any{},
|
||||
want: tools.ParamValues{tools.ParamValue{Name: "my_bool", Value: nil}},
|
||||
},
|
||||
{
|
||||
name: "enum not required",
|
||||
params: tools.Parameters{
|
||||
tools.NewEnumParameter(tools.NewStringParameterWithRequired("enum_string", "enum of strings", false), true, []any{"foo", "bar"}),
|
||||
},
|
||||
in: map[string]any{},
|
||||
want: tools.ParamValues{tools.ParamValue{Name: "enum_string", Value: nil}},
|
||||
},
|
||||
{
|
||||
name: "map",
|
||||
params: tools.Parameters{
|
||||
@@ -1197,6 +1299,17 @@ func TestParamManifest(t *testing.T) {
|
||||
Items: &tools.ParameterManifest{Name: "foo-string", Type: "string", Required: true, Description: "bar", AuthServices: []string{}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "enum with string",
|
||||
in: tools.NewEnumParameter(tools.NewStringParameter("foo-enum", "enum of strings"), false, []any{"foo", "bar"}),
|
||||
want: tools.ParameterManifest{
|
||||
Name: "foo-enum",
|
||||
Type: "string",
|
||||
Required: true,
|
||||
Description: "enum of strings",
|
||||
AuthServices: []string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "string default",
|
||||
in: tools.NewStringParameterWithDefault("foo-string", "foo", "bar"),
|
||||
@@ -1229,6 +1342,17 @@ func TestParamManifest(t *testing.T) {
|
||||
Items: &tools.ParameterManifest{Name: "foo-string", Type: "string", Required: false, Description: "bar", AuthServices: []string{}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "enum with string default",
|
||||
in: tools.NewEnumParameter(tools.NewStringParameterWithDefault("foo-enum", "foo", "enum of strings"), false, []any{"foo", "bar"}),
|
||||
want: tools.ParameterManifest{
|
||||
Name: "foo-enum",
|
||||
Type: "string",
|
||||
Required: false,
|
||||
Description: "enum of strings",
|
||||
AuthServices: []string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "string not required",
|
||||
in: tools.NewStringParameterWithRequired("foo-string", "bar", false),
|
||||
@@ -1261,6 +1385,17 @@ func TestParamManifest(t *testing.T) {
|
||||
Items: &tools.ParameterManifest{Name: "foo-string", Type: "string", Required: false, Description: "bar", AuthServices: []string{}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "enum with string not required",
|
||||
in: tools.NewEnumParameter(tools.NewStringParameterWithRequired("foo-enum", "enum of strings", false), false, []any{"foo", "bar"}),
|
||||
want: tools.ParameterManifest{
|
||||
Name: "foo-enum",
|
||||
Type: "string",
|
||||
Required: false,
|
||||
Description: "enum of strings",
|
||||
AuthServices: []string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "map with string values",
|
||||
in: tools.NewMapParameter("foo-map", "bar", "string"),
|
||||
@@ -1343,7 +1478,6 @@ func TestParamMcpManifest(t *testing.T) {
|
||||
Items: &tools.ParameterMcpManifest{Type: "string", Description: "bar"},
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
name: "map with string values",
|
||||
in: tools.NewMapParameter("foo-map", "bar", "string"),
|
||||
@@ -1362,6 +1496,11 @@ func TestParamMcpManifest(t *testing.T) {
|
||||
AdditionalProperties: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "enum param",
|
||||
in: tools.NewEnumParameter(tools.NewStringParameter("foo-enum", "enum of strings"), false, []any{"foo", "bar"}),
|
||||
want: tools.ParameterMcpManifest{Type: "string", Description: "enum of strings"},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
@@ -1389,6 +1528,7 @@ func TestMcpManifest(t *testing.T) {
|
||||
tools.NewArrayParameter("foo-array2", "bar", tools.NewStringParameter("foo-string", "bar")),
|
||||
tools.NewMapParameter("foo-map-int", "a map of ints", "integer"),
|
||||
tools.NewMapParameter("foo-map-any", "a map of any", ""),
|
||||
tools.NewEnumParameter(tools.NewStringParameter("foo-enum-string", "enum of strings"), false, []any{"foo", "bar"}),
|
||||
},
|
||||
want: tools.McpToolsSchema{
|
||||
Type: "object",
|
||||
@@ -1412,8 +1552,12 @@ func TestMcpManifest(t *testing.T) {
|
||||
Description: "a map of any",
|
||||
AdditionalProperties: true,
|
||||
},
|
||||
"foo-enum-string": {
|
||||
Type: "string",
|
||||
Description: "enum of strings",
|
||||
},
|
||||
},
|
||||
Required: []string{"foo-string2", "foo-int2", "foo-float", "foo-array2", "foo-map-int", "foo-map-any"},
|
||||
Required: []string{"foo-string2", "foo-int2", "foo-float", "foo-array2", "foo-map-int", "foo-map-any", "foo-enum-string"},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -1455,7 +1599,7 @@ func TestFailParametersUnmarshal(t *testing.T) {
|
||||
"description": "this is a param for string",
|
||||
},
|
||||
},
|
||||
err: "parameter is missing 'type' field: %!w(<nil>)",
|
||||
err: "parameter is missing 'type' field",
|
||||
},
|
||||
{
|
||||
name: "common parameter missing description",
|
||||
@@ -1663,6 +1807,18 @@ func TestGetParams(t *testing.T) {
|
||||
in: map[string]any{},
|
||||
want: tools.ParamValues{},
|
||||
},
|
||||
{
|
||||
name: "enum with escape",
|
||||
params: tools.Parameters{
|
||||
tools.NewEnumParameter(tools.NewStringParameter("my_string_enum", "string of enums"), true, []any{"foo", "bar"}),
|
||||
},
|
||||
in: map[string]any{
|
||||
"my_string_enum": "foo",
|
||||
},
|
||||
want: tools.ParamValues{
|
||||
tools.ParamValue{Name: "my_string_enum", Value: `"foo"`},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
@@ -1881,3 +2037,75 @@ func TestCheckParamRequired(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchStringOrRegex(t *testing.T) {
|
||||
tcs := []struct {
|
||||
name string
|
||||
input string
|
||||
target string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "exact string",
|
||||
input: "foo",
|
||||
target: "foo",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "exact integer",
|
||||
input: fmt.Sprintf("%v", 5),
|
||||
target: fmt.Sprintf("%v", 5),
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "wrong integer",
|
||||
input: fmt.Sprintf("%v", 4),
|
||||
target: fmt.Sprintf("%v", 5),
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "exact boolean",
|
||||
input: fmt.Sprintf("%v", true),
|
||||
target: fmt.Sprintf("%v", true),
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "target contains input",
|
||||
input: "foo",
|
||||
target: "foo bar",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "regex any string",
|
||||
input: "foo",
|
||||
target: ".*",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "regex",
|
||||
input: "foo6",
|
||||
target: `foo\d+`,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "regex of numbers",
|
||||
input: "4",
|
||||
target: "^[1-5]$",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "regex of numbers invalid",
|
||||
input: "7",
|
||||
target: "^[1-5]$",
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := tools.MatchStringOrRegex(tc.input, tc.target)
|
||||
if got != tc.want {
|
||||
t.Fatalf("got %v, want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
284
tests/cassandra/cassandra_integration_test.go
Normal file
284
tests/cassandra/cassandra_integration_test.go
Normal file
@@ -0,0 +1,284 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cassandra
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gocql/gocql"
|
||||
"github.com/google/uuid"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/tests"
|
||||
)
|
||||
|
||||
var (
|
||||
CassandraSourceKind = "cassandra"
|
||||
CassandraToolKind = "cassandra-cql"
|
||||
Hosts = os.Getenv("CASSANDRA_HOST")
|
||||
Keyspace = "example_keyspace"
|
||||
Username = os.Getenv("CASSANDRA_USER")
|
||||
Password = os.Getenv("CASSANDRA_PASS")
|
||||
)
|
||||
|
||||
func getCassandraVars(t *testing.T) map[string]any {
|
||||
switch "" {
|
||||
case Hosts:
|
||||
t.Fatal("'Hosts' not set")
|
||||
case Username:
|
||||
t.Fatal("'Username' not set")
|
||||
case Password:
|
||||
t.Fatal("'Password' not set")
|
||||
}
|
||||
return map[string]any{
|
||||
"kind": CassandraSourceKind,
|
||||
"hosts": strings.Split(Hosts, ","),
|
||||
"keyspace": Keyspace,
|
||||
"username": Username,
|
||||
"password": Password,
|
||||
}
|
||||
}
|
||||
|
||||
func initCassandraSession() (*gocql.Session, error) {
|
||||
hostStrings := strings.Split(Hosts, ",")
|
||||
|
||||
var hosts []string
|
||||
for _, h := range hostStrings {
|
||||
trimmedHost := strings.TrimSpace(h)
|
||||
if trimmedHost != "" {
|
||||
hosts = append(hosts, trimmedHost)
|
||||
}
|
||||
}
|
||||
if len(hosts) == 0 {
|
||||
return nil, fmt.Errorf("no valid hosts found in CASSANDRA_HOSTS env var")
|
||||
}
|
||||
// Configure cluster connection
|
||||
cluster := gocql.NewCluster(hosts...)
|
||||
cluster.Consistency = gocql.Quorum
|
||||
cluster.ProtoVersion = 4
|
||||
cluster.DisableInitialHostLookup = true
|
||||
cluster.ConnectTimeout = 10 * time.Second
|
||||
cluster.NumConns = 2
|
||||
cluster.Authenticator = gocql.PasswordAuthenticator{
|
||||
Username: Username,
|
||||
Password: Password,
|
||||
}
|
||||
cluster.RetryPolicy = &gocql.ExponentialBackoffRetryPolicy{
|
||||
NumRetries: 3,
|
||||
Min: 200 * time.Millisecond,
|
||||
Max: 2 * time.Second,
|
||||
}
|
||||
|
||||
// Create session
|
||||
session, err := cluster.CreateSession()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to create session: %v", err)
|
||||
}
|
||||
|
||||
// Create keyspace
|
||||
err = session.Query(fmt.Sprintf(`
|
||||
CREATE KEYSPACE IF NOT EXISTS %s
|
||||
WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}
|
||||
`, Keyspace)).Exec()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to create keyspace: %v", err)
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func initTable(tableName string, session *gocql.Session) error {
|
||||
|
||||
// Create table with additional columns
|
||||
err := session.Query(fmt.Sprintf(`
|
||||
CREATE TABLE IF NOT EXISTS %s.%s (
|
||||
id int PRIMARY KEY,
|
||||
name text,
|
||||
email text,
|
||||
age int,
|
||||
is_active boolean,
|
||||
created_at timestamp
|
||||
)
|
||||
`, Keyspace, tableName)).Exec()
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to create table: %v", err)
|
||||
}
|
||||
|
||||
// Use fixed timestamps for reproducibility
|
||||
fixedTime, _ := time.Parse(time.RFC3339, "2025-07-25T12:00:00Z")
|
||||
dayAgo := fixedTime.Add(-24 * time.Hour)
|
||||
twelveHoursAgo := fixedTime.Add(-12 * time.Hour)
|
||||
|
||||
// Insert minimal diverse data with fixed time.Time for timestamps
|
||||
err = session.Query(fmt.Sprintf(`
|
||||
INSERT INTO %s.%s (id, name,email, age, is_active, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)`, Keyspace, tableName),
|
||||
3, "Alice", tests.ServiceAccountEmail, 25, true, dayAgo,
|
||||
).Exec()
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to insert user: %v", err)
|
||||
}
|
||||
err = session.Query(fmt.Sprintf(`
|
||||
INSERT INTO %s.%s (id, name,email, age, is_active, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)`, Keyspace, tableName),
|
||||
2, "Alex", "janedoe@gmail.com", 30, false, twelveHoursAgo,
|
||||
).Exec()
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to insert user: %v", err)
|
||||
}
|
||||
err = session.Query(fmt.Sprintf(`
|
||||
INSERT INTO %s.%s (id, name,email, age, is_active, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)`, Keyspace, tableName),
|
||||
1, "Sid", "sid@gmail.com", 10, true, fixedTime,
|
||||
).Exec()
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to insert user: %v", err)
|
||||
}
|
||||
err = session.Query(fmt.Sprintf(`
|
||||
INSERT INTO %s.%s (id, name,email, age, is_active, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)`, Keyspace, tableName),
|
||||
4, nil, "a@gmail.com", 40, false, fixedTime,
|
||||
).Exec()
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to insert user: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func dropTable(session *gocql.Session, tableName string) {
|
||||
err := session.Query(fmt.Sprintf("drop table %s.%s", Keyspace, tableName)).Exec()
|
||||
if err != nil {
|
||||
log.Printf("Failed to drop table %s: %v", tableName, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCassandra(t *testing.T) {
|
||||
session, err := initCassandraSession()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer session.Close()
|
||||
sourceConfig := getCassandraVars(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
|
||||
var args []string
|
||||
paramTableName := "param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
tableNameAuth := "auth_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
err = initTable(paramTableName, session)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer dropTable(session, paramTableName)
|
||||
|
||||
err = initTable(tableNameAuth, session)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer dropTable(session, tableNameAuth)
|
||||
|
||||
err = initTable(tableNameTemplateParam, session)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer dropTable(session, tableNameTemplateParam)
|
||||
|
||||
paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt := createParamToolInfo(paramTableName)
|
||||
_, _, authToolStmt := getCassandraAuthToolInfo(tableNameAuth)
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, CassandraToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
|
||||
|
||||
tmplSelectCombined, tmplSelectFilterCombined := getCassandraTmplToolInfo()
|
||||
tmpSelectAll := "SELECT * FROM {{.tableName}} where id = 1"
|
||||
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, CassandraToolKind, tmplSelectCombined, tmplSelectFilterCombined, tmpSelectAll)
|
||||
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
t.Fatalf("command initialization returned an error: %s", err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||
if err != nil {
|
||||
t.Logf("toolbox command logs: \n%s", out)
|
||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||
}
|
||||
selectIdNameWant, selectIdNullWant, selectArrayParamWant, mcpMyFailToolWant, mcpSelect1Want, mcpMyToolIdWant := getCassandraWants()
|
||||
selectAllWant, selectIdWant, selectNameWant := getCassandraTmplWants()
|
||||
|
||||
tests.RunToolGetTest(t)
|
||||
tests.RunToolInvokeTest(t, "", tests.DisableSelect1Test(),
|
||||
tests.DisableOptionalNullParamTest(),
|
||||
tests.WithMyToolId3NameAliceWant(selectIdNameWant),
|
||||
tests.WithMyToolById4Want(selectIdNullWant),
|
||||
tests.WithMyArrayToolWant(selectArrayParamWant),
|
||||
tests.DisableSelect1AuthTest())
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam,
|
||||
tests.DisableSelectFilterTest(),
|
||||
tests.WithSelectAllWant(selectAllWant),
|
||||
tests.DisableDdlTest(), tests.DisableInsertTest(), tests.WithTmplSelectId1Want(selectIdWant), tests.WithTmplSelectNameWant(selectNameWant))
|
||||
|
||||
tests.RunMCPToolCallMethod(t, mcpMyFailToolWant, mcpSelect1Want,
|
||||
tests.WithMcpMyToolId3NameAliceWant(mcpMyToolIdWant),
|
||||
tests.DisableMcpSelect1AuthTest())
|
||||
|
||||
}
|
||||
|
||||
func createParamToolInfo(tableName string) (string, string, string, string) {
|
||||
toolStatement := fmt.Sprintf("SELECT id, name FROM %s WHERE id = ? AND name = ? ALLOW FILTERING;", tableName)
|
||||
idParamStatement := fmt.Sprintf("SELECT id,name FROM %s WHERE id = ?;", tableName)
|
||||
nameParamStatement := fmt.Sprintf("SELECT id, name FROM %s WHERE name = ? ALLOW FILTERING;", tableName)
|
||||
arrayToolStatement := fmt.Sprintf("SELECT id, name FROM %s WHERE id IN ? AND name IN ? ALLOW FILTERING;", tableName)
|
||||
return toolStatement, idParamStatement, nameParamStatement, arrayToolStatement
|
||||
|
||||
}
|
||||
|
||||
func getCassandraAuthToolInfo(tableName string) (string, string, string) {
|
||||
createStatement := fmt.Sprintf("CREATE TABLE %s (id UUID PRIMARY KEY, name TEXT, email TEXT);", tableName)
|
||||
insertStatement := fmt.Sprintf("INSERT INTO %s (id, name, email) VALUES (uuid(), ?, ?), (uuid(), ?, ?);", tableName)
|
||||
toolStatement := fmt.Sprintf("SELECT name FROM %s WHERE email = ? ALLOW FILTERING;", tableName)
|
||||
return createStatement, insertStatement, toolStatement
|
||||
}
|
||||
|
||||
func getCassandraTmplToolInfo() (string, string) {
|
||||
selectAllTemplateStmt := "SELECT age, id, name FROM {{.tableName}} where id = ?;"
|
||||
selectByIdTemplateStmt := "SELECT id, name FROM {{.tableName}} WHERE name = ? ALLOW FILTERING;"
|
||||
return selectAllTemplateStmt, selectByIdTemplateStmt
|
||||
}
|
||||
|
||||
func getCassandraWants() (string, string, string, string, string, string) {
|
||||
selectIdNameWant := "[{\"id\":3,\"name\":\"Alice\"}]"
|
||||
selectIdNullWant := "[{\"id\":4,\"name\":\"\"}]"
|
||||
selectArrayParamWant := "[{\"id\":1,\"name\":\"Sid\"},{\"id\":3,\"name\":\"Alice\"}]"
|
||||
mcpMyFailToolWant := "{\"jsonrpc\":\"2.0\",\"id\":\"invoke-fail-tool\",\"result\":{\"content\":[{\"type\":\"text\",\"text\":\"unable to parse rows: line 1:0 no viable alternative at input 'SELEC' ([SELEC]...)\"}],\"isError\":true}}"
|
||||
mcpMyToolIdWant := "{\"jsonrpc\":\"2.0\",\"id\":\"my-tool\",\"result\":{\"content\":[{\"type\":\"text\",\"text\":\"[{\\\"id\\\":3,\\\"name\\\":\\\"Alice\\\"}]\"}]}}"
|
||||
return selectIdNameWant, selectIdNullWant, selectArrayParamWant, mcpMyFailToolWant, "nil", mcpMyToolIdWant
|
||||
}
|
||||
|
||||
func getCassandraTmplWants() (string, string, string) {
|
||||
selectAllWant := "[{\"age\":10,\"created_at\":\"2025-07-25T12:00:00Z\",\"email\":\"sid@gmail.com\",\"id\":1,\"is_active\":true,\"name\":\"Sid\"}]"
|
||||
selectIdWant := "[{\"age\":10,\"id\":1,\"name\":\"Sid\"}]"
|
||||
selectNameWant := "[{\"id\":2,\"name\":\"Alex\"}]"
|
||||
return selectAllWant, selectIdWant, selectNameWant
|
||||
}
|
||||
@@ -110,6 +110,7 @@ func TestMongoDBToolEndpoints(t *testing.T) {
|
||||
tests.RunToolGetTest(t)
|
||||
tests.RunToolInvokeTest(t, select1Want,
|
||||
tests.WithMyToolId3NameAliceWant(myToolId3NameAliceWant),
|
||||
tests.WithMyArrayToolWant(myToolId3NameAliceWant),
|
||||
tests.WithMyToolById4Want(myToolById4Want),
|
||||
)
|
||||
tests.RunMCPToolCallMethod(t, mcpMyFailToolWant, select1Want,
|
||||
|
||||
@@ -21,9 +21,12 @@ type InvokeTestConfig struct {
|
||||
myToolId3NameAliceWant string
|
||||
myToolById4Want string
|
||||
nullWant string
|
||||
myArrayToolWant string
|
||||
supportSelect1Want bool
|
||||
supportOptionalNullParam bool
|
||||
supportArrayParam bool
|
||||
supportClientAuth bool
|
||||
supportSelect1Auth bool
|
||||
}
|
||||
|
||||
type InvokeTestOption func(*InvokeTestConfig)
|
||||
@@ -36,6 +39,14 @@ func WithMyToolId3NameAliceWant(s string) InvokeTestOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithMyArrayToolWant represents the response value for my-array-tool.
|
||||
// e.g. tests.RunToolInvokeTest(t, select1Want, tests.WithMyArrayToolWant("custom"))
|
||||
func WithMyArrayToolWant(s string) InvokeTestOption {
|
||||
return func(c *InvokeTestConfig) {
|
||||
c.myArrayToolWant = s
|
||||
}
|
||||
}
|
||||
|
||||
// WithMyToolById4Want represents the response value for my-tool-by-id with id=4.
|
||||
// This response includes a null value column.
|
||||
// e.g. tests.RunToolInvokeTest(t, select1Want, tests.WithMyToolById4Want("custom"))
|
||||
@@ -69,6 +80,22 @@ func DisableArrayTest() InvokeTestOption {
|
||||
}
|
||||
}
|
||||
|
||||
// DisableSelect1Test disables tests for sources that do not support SELECT 1 query.
|
||||
// e.g. tests.RunToolInvokeTest(t, "", tests.DisableSelect1Test())
|
||||
func DisableSelect1Test() InvokeTestOption {
|
||||
return func(c *InvokeTestConfig) {
|
||||
c.supportSelect1Want = false
|
||||
}
|
||||
}
|
||||
|
||||
// DisableSelect1AuthTest disables auth tests for sources that do not support SELECT 1 query.
|
||||
// e.g. tests.RunToolInvokeTest(t, "", tests.DisableSelect1AuthTest())
|
||||
func DisableSelect1AuthTest() InvokeTestOption {
|
||||
return func(c *InvokeTestConfig) {
|
||||
c.supportSelect1Auth = false
|
||||
}
|
||||
}
|
||||
|
||||
// EnableClientAuthTest runs the client authorization tests.
|
||||
// Only enable it if your source supports the `useClientOAuth` configuration.
|
||||
// Currently, this should only be used with the BigQuery tests.
|
||||
@@ -84,6 +111,7 @@ func EnableClientAuthTest() InvokeTestOption {
|
||||
type MCPTestConfig struct {
|
||||
myToolId3NameAliceWant string
|
||||
supportClientAuth bool
|
||||
supportSelect1Auth bool
|
||||
}
|
||||
|
||||
type McpTestOption func(*MCPTestConfig)
|
||||
@@ -105,6 +133,13 @@ func EnableMcpClientAuthTest() McpTestOption {
|
||||
}
|
||||
}
|
||||
|
||||
// DisableMcpSelect1AuthTest disables the auth tool tests which use select 1.
|
||||
func DisableMcpSelect1AuthTest() McpTestOption {
|
||||
return func(c *MCPTestConfig) {
|
||||
c.supportSelect1Auth = false
|
||||
}
|
||||
}
|
||||
|
||||
/* Configurations for RunExecuteSqlToolInvokeTest() */
|
||||
|
||||
// ExecuteSqlTestConfig represents the various configuration options for RunExecuteSqlToolInvokeTest()
|
||||
@@ -129,6 +164,7 @@ type TemplateParameterTestConfig struct {
|
||||
ddlWant string
|
||||
selectAllWant string
|
||||
selectId1Want string
|
||||
selectNameWant string
|
||||
selectEmptyWant string
|
||||
insert1Want string
|
||||
|
||||
@@ -136,8 +172,9 @@ type TemplateParameterTestConfig struct {
|
||||
nameColFilter string
|
||||
createColArray string
|
||||
|
||||
supportDdl bool
|
||||
supportInsert bool
|
||||
supportDdl bool
|
||||
supportInsert bool
|
||||
supportSelectFields bool
|
||||
}
|
||||
|
||||
type TemplateParamOption func(*TemplateParameterTestConfig)
|
||||
@@ -166,6 +203,14 @@ func WithTmplSelectId1Want(s string) TemplateParamOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithTmplSelectNameWant represents the response value of select-filter-templateParams-combined-tool with name.
|
||||
// e.g. tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tests.WithTmplSelectNameWant("custom"))
|
||||
func WithTmplSelectNameWant(s string) TemplateParamOption {
|
||||
return func(c *TemplateParameterTestConfig) {
|
||||
c.selectNameWant = s
|
||||
}
|
||||
}
|
||||
|
||||
// WithSelectEmptyWant represents the response value of select-templateParams-combined-tool with no results.
|
||||
// e.g. tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tests.WithSelectEmptyWant("custom"))
|
||||
func WithSelectEmptyWant(s string) TemplateParamOption {
|
||||
@@ -221,3 +266,11 @@ func DisableInsertTest() TemplateParamOption {
|
||||
c.supportInsert = false
|
||||
}
|
||||
}
|
||||
|
||||
// DisableInsertTest disables tests of select-fields-templateParams-tool test.
|
||||
// e.g. tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tests.DisableSelectFilterTest())
|
||||
func DisableSelectFilterTest() TemplateParamOption {
|
||||
return func(c *TemplateParameterTestConfig) {
|
||||
c.supportSelectFields = false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -104,6 +104,7 @@ func TestRedisToolEndpoints(t *testing.T) {
|
||||
tests.RunToolGetTest(t)
|
||||
tests.RunToolInvokeTest(t, select1Want,
|
||||
tests.WithMyToolId3NameAliceWant(invokeParamWant),
|
||||
tests.WithMyArrayToolWant(invokeParamWant),
|
||||
tests.WithMyToolById4Want(invokeIdNullWant),
|
||||
tests.WithNullWant(nullWant),
|
||||
)
|
||||
|
||||
@@ -164,6 +164,7 @@ func TestSpannerToolEndpoints(t *testing.T) {
|
||||
tests.RunToolGetTest(t)
|
||||
tests.RunToolInvokeTest(t, select1Want,
|
||||
tests.WithMyToolId3NameAliceWant(invokeParamWant),
|
||||
tests.WithMyArrayToolWant(invokeParamWant),
|
||||
tests.WithMyToolById4Want(toolInvokeMyToolById4Want),
|
||||
)
|
||||
tests.RunMCPToolCallMethod(t, mcpMyFailToolWant, mcpSelect1Want, tests.WithMcpMyToolId3NameAliceWant(mcpMyToolId3NameAliceWant))
|
||||
|
||||
@@ -257,10 +257,13 @@ func RunToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTestOp
|
||||
configs := &InvokeTestConfig{
|
||||
myToolId3NameAliceWant: "[{\"id\":1,\"name\":\"Alice\"},{\"id\":3,\"name\":\"Sid\"}]",
|
||||
myToolById4Want: "[{\"id\":4,\"name\":null}]",
|
||||
myArrayToolWant: "[{\"id\":1,\"name\":\"Alice\"},{\"id\":3,\"name\":\"Sid\"}]",
|
||||
nullWant: "null",
|
||||
supportOptionalNullParam: true,
|
||||
supportArrayParam: true,
|
||||
supportClientAuth: false,
|
||||
supportSelect1Want: true,
|
||||
supportSelect1Auth: true,
|
||||
}
|
||||
|
||||
// Apply provided options
|
||||
@@ -294,7 +297,7 @@ func RunToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTestOp
|
||||
{
|
||||
name: "invoke my-simple-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/invoke",
|
||||
enabled: true,
|
||||
enabled: configs.supportSelect1Want,
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
wantBody: select1Want,
|
||||
@@ -351,13 +354,13 @@ func RunToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTestOp
|
||||
enabled: configs.supportArrayParam,
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"idArray": [1,2,3], "nameArray": ["Alice", "Sid", "RandomName"], "cmdArray": ["HGETALL", "row3"]}`)),
|
||||
wantBody: configs.myToolId3NameAliceWant,
|
||||
wantBody: configs.myArrayToolWant,
|
||||
wantStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "Invoke my-auth-tool with auth token",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke",
|
||||
enabled: true,
|
||||
enabled: configs.supportSelect1Auth,
|
||||
requestHeader: map[string]string{"my-google-auth_token": idToken},
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
wantBody: "[{\"name\":\"Alice\"}]",
|
||||
@@ -366,7 +369,7 @@ func RunToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTestOp
|
||||
{
|
||||
name: "Invoke my-auth-tool with invalid auth token",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-tool/invoke",
|
||||
enabled: true,
|
||||
enabled: configs.supportSelect1Auth,
|
||||
requestHeader: map[string]string{"my-google-auth_token": "INVALID_TOKEN"},
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
wantStatusCode: http.StatusUnauthorized,
|
||||
@@ -382,7 +385,7 @@ func RunToolInvokeTest(t *testing.T, select1Want string, options ...InvokeTestOp
|
||||
{
|
||||
name: "Invoke my-auth-required-tool with auth token",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-required-tool/invoke",
|
||||
enabled: true,
|
||||
enabled: configs.supportSelect1Auth,
|
||||
requestHeader: map[string]string{"my-google-auth_token": idToken},
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
|
||||
@@ -491,6 +494,7 @@ func RunToolInvokeWithTemplateParameters(t *testing.T, tableName string, options
|
||||
ddlWant: "null",
|
||||
selectAllWant: "[{\"age\":21,\"id\":1,\"name\":\"Alex\"},{\"age\":100,\"id\":2,\"name\":\"Alice\"}]",
|
||||
selectId1Want: "[{\"age\":21,\"id\":1,\"name\":\"Alex\"}]",
|
||||
selectNameWant: "[{\"age\":21,\"id\":1,\"name\":\"Alex\"}]",
|
||||
selectEmptyWant: "null",
|
||||
insert1Want: "null",
|
||||
|
||||
@@ -512,6 +516,7 @@ func RunToolInvokeWithTemplateParameters(t *testing.T, tableName string, options
|
||||
// Test tool invoke endpoint
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
enabled bool
|
||||
ddl bool
|
||||
insert bool
|
||||
api string
|
||||
@@ -573,6 +578,7 @@ func RunToolInvokeWithTemplateParameters(t *testing.T, tableName string, options
|
||||
},
|
||||
{
|
||||
name: "invoke select-fields-templateParams-tool",
|
||||
enabled: configs.supportSelectFields,
|
||||
api: "http://127.0.0.1:5000/api/tool/select-fields-templateParams-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"tableName": "%s", "fields":%s}`, tableName, configs.nameFieldArray))),
|
||||
@@ -584,7 +590,7 @@ func RunToolInvokeWithTemplateParameters(t *testing.T, tableName string, options
|
||||
api: "http://127.0.0.1:5000/api/tool/select-filter-templateParams-combined-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"name": "Alex", "tableName": "%s", "columnFilter": "%s"}`, tableName, configs.nameColFilter))),
|
||||
want: configs.selectId1Want,
|
||||
want: configs.selectNameWant,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
@@ -599,6 +605,9 @@ func RunToolInvokeWithTemplateParameters(t *testing.T, tableName string, options
|
||||
}
|
||||
for _, tc := range invokeTcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if !tc.enabled {
|
||||
return
|
||||
}
|
||||
// if test case is DDL and source support ddl test cases
|
||||
ddlAllow := !tc.ddl || (tc.ddl && configs.supportDdl)
|
||||
// if test case is insert statement and source support insert test cases
|
||||
@@ -834,6 +843,7 @@ func RunMCPToolCallMethod(t *testing.T, myFailToolWant, select1Want string, opti
|
||||
configs := &MCPTestConfig{
|
||||
myToolId3NameAliceWant: `{"jsonrpc":"2.0","id":"my-tool","result":{"content":[{"type":"text","text":"{\"id\":1,\"name\":\"Alice\"}"},{"type":"text","text":"{\"id\":3,\"name\":\"Sid\"}"}]}}`,
|
||||
supportClientAuth: false,
|
||||
supportSelect1Auth: true,
|
||||
}
|
||||
|
||||
// Apply provided options
|
||||
@@ -947,7 +957,7 @@ func RunMCPToolCallMethod(t *testing.T, myFailToolWant, select1Want string, opti
|
||||
{
|
||||
name: "MCP Invoke my-auth-required-tool",
|
||||
api: "http://127.0.0.1:5000/mcp",
|
||||
enabled: true,
|
||||
enabled: configs.supportSelect1Auth,
|
||||
requestHeader: map[string]string{"my-google-auth_token": idToken},
|
||||
requestBody: jsonrpc.JSONRPCRequest{
|
||||
Jsonrpc: "2.0",
|
||||
|
||||
@@ -107,6 +107,7 @@ func TestValkeyToolEndpoints(t *testing.T) {
|
||||
tests.RunToolGetTest(t)
|
||||
tests.RunToolInvokeTest(t, select1Want,
|
||||
tests.WithMyToolId3NameAliceWant(invokeParamWant),
|
||||
tests.WithMyArrayToolWant(invokeParamWant),
|
||||
tests.WithMyToolById4Want(invokeIdNullWant),
|
||||
tests.WithNullWant(nullWant),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user