mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-14 01:48:29 -05:00
Compare commits
6 Commits
update
...
ci/sample-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e843f73079 | ||
|
|
6eaf36ac85 | ||
|
|
051e686476 | ||
|
|
9af55b651d | ||
|
|
5fa1660fc8 | ||
|
|
d9ee17d2c7 |
@@ -486,7 +486,7 @@ steps:
|
||||
"Looker" \
|
||||
looker \
|
||||
looker
|
||||
|
||||
|
||||
- id: "duckdb"
|
||||
name: golang:1
|
||||
waitFor: ["compile-test-binary"]
|
||||
@@ -506,7 +506,6 @@ steps:
|
||||
duckdb \
|
||||
duckdb
|
||||
|
||||
|
||||
- id: "alloydbwaitforoperation"
|
||||
name: golang:1
|
||||
waitFor: ["compile-test-binary"]
|
||||
@@ -525,6 +524,28 @@ steps:
|
||||
"Alloydb Wait for Operation" \
|
||||
utility \
|
||||
utility/alloydbwaitforoperation
|
||||
|
||||
- id: "tidb"
|
||||
name: golang:1
|
||||
waitFor: ["compile-test-binary"]
|
||||
entrypoint: /bin/bash
|
||||
env:
|
||||
- "GOPATH=/gopath"
|
||||
- "TIDB_DATABASE=$_DATABASE_NAME"
|
||||
- "TIDB_HOST=$_TIDB_HOST"
|
||||
- "TIDB_PORT=$_TIDB_PORT"
|
||||
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
|
||||
secretEnv: ["CLIENT_ID", "TIDB_USER", "TIDB_PASS"]
|
||||
volumes:
|
||||
- name: "go"
|
||||
path: "/gopath"
|
||||
args:
|
||||
- -c
|
||||
- |
|
||||
.ci/test_with_coverage.sh \
|
||||
"TiDB" \
|
||||
tidb \
|
||||
tidbsql tidbexecutesql
|
||||
|
||||
availableSecrets:
|
||||
secretManager:
|
||||
@@ -584,7 +605,10 @@ availableSecrets:
|
||||
env: LOOKER_CLIENT_ID
|
||||
- versionName: projects/107716898620/secrets/looker_client_secret/versions/latest
|
||||
env: LOOKER_CLIENT_SECRET
|
||||
|
||||
- versionName: projects/107716898620/secrets/tidb_user/versions/latest
|
||||
env: TIDB_USER
|
||||
- versionName: projects/107716898620/secrets/tidb_pass/versions/latest
|
||||
env: TIDB_PASS
|
||||
|
||||
options:
|
||||
logging: CLOUD_LOGGING_ONLY
|
||||
@@ -616,4 +640,6 @@ substitutions:
|
||||
_DGRAPHURL: "https://play.dgraph.io"
|
||||
_COUCHBASE_BUCKET: "couchbase-bucket"
|
||||
_COUCHBASE_SCOPE: "couchbase-scope"
|
||||
_LOOKER_VERIFY_SSL: "true"
|
||||
_LOOKER_VERIFY_SSL: "true"
|
||||
_TIDB_HOST: 127.0.0.1
|
||||
_TIDB_PORT: "4000"
|
||||
|
||||
20
cmd/BUILD
20
cmd/BUILD
@@ -1,20 +0,0 @@
|
||||
load("//tools/build_defs/go:go_library.bzl", "go_library")
|
||||
load("//tools/build_defs/go:go_test.bzl", "go_test")
|
||||
|
||||
go_library(
|
||||
name = "cmd",
|
||||
srcs = [
|
||||
"options.go",
|
||||
"root.go",
|
||||
],
|
||||
embedsrcs = ["version.txt"],
|
||||
)
|
||||
|
||||
go_test(
|
||||
name = "cmd_test",
|
||||
srcs = [
|
||||
"options_test.go",
|
||||
"root_test.go",
|
||||
],
|
||||
library = ":cmd",
|
||||
)
|
||||
@@ -51,6 +51,7 @@ import (
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerysql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/bigtable"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/couchbase"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/dataplex/dataplexlookupentry"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/dataplex/dataplexsearchentries"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/dgraph"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/duckdbsql"
|
||||
@@ -94,6 +95,8 @@ import (
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/spanner/spannerexecutesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/spanner/spannersql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/sqlitesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/tidb/tidbexecutesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/tidb/tidbsql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/utility/alloydbwaitforoperation"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/utility/wait"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/valkey"
|
||||
@@ -121,6 +124,7 @@ import (
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/redis"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/spanner"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/sqlite"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/tidb"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/sources/valkey"
|
||||
)
|
||||
|
||||
|
||||
@@ -1250,7 +1250,7 @@ func TestPrebuiltTools(t *testing.T) {
|
||||
wantToolset: server.ToolsetConfigs{
|
||||
"dataplex-tools": tools.ToolsetConfig{
|
||||
Name: "dataplex-tools",
|
||||
ToolNames: []string{"dataplex_search_entries"},
|
||||
ToolNames: []string{"dataplex_search_entries", "dataplex_lookup_entry"},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -35,10 +35,213 @@ You can use the following system prompt as "Custom Instructions" in your client
|
||||
application.
|
||||
|
||||
```
|
||||
Whenever you will receive response from dataplex_search_entries tool decide what do to by following these steps:
|
||||
# Objective
|
||||
Your primary objective is to help discover, organize and manage metadata related to data assets.
|
||||
|
||||
# Tone and Style
|
||||
1. Adopt the persona of a senior subject matter expert
|
||||
2. Your communication style must be:
|
||||
1. Concise: Always favor brevity.
|
||||
2. Direct: Avoid greetings (e.g., "Hi there!", "Certainly!"). Get straight to the point.
|
||||
Example (Incorrect): Hi there! I see that you are looking for...
|
||||
Example (Correct): This problem likely stems from...
|
||||
3. Do not reiterate or summarize the question in the answer.
|
||||
4. Crucially, always convey a tone of uncertainty and caution. Since you are interpreting metadata and have no way to externally verify your answers, never express complete confidence. Frame your responses as interpretations based solely on the provided metadata. Use a suggestive tone, not a prescriptive one:
|
||||
Example (Correct): "The entry describes..."
|
||||
Example (Correct): "According to catalog,..."
|
||||
Example (Correct): "Based on the metadata,..."
|
||||
Example (Correct): "Based on the search results,..."
|
||||
5. Do not make assumptions
|
||||
|
||||
# Data Model
|
||||
## Entries
|
||||
Entry represents a specific data asset. Entry acts as a metadata record for something that is managed by Catalog, such as:
|
||||
|
||||
- A BigQuery table or dataset
|
||||
- A Cloud Storage bucket or folder
|
||||
- An on-premises SQL table
|
||||
|
||||
## Aspects
|
||||
While the Entry itself is a container, the rich descriptive information about the asset (e.g., schema, data types, business descriptions, classifications) is stored in associated components called Aspects. Aspects are created based on pre-defined blueprints known as Aspect Types.
|
||||
|
||||
## Aspect Types
|
||||
Aspect Type is a reusable template that defines the schema for a set of metadata fields. Think of an Aspect Type as a structure for the kind of metadata that is organized in the catalog within the Entry.
|
||||
|
||||
Examples:
|
||||
- projects/dataplex-types/locations/global/aspectTypes/analytics-hub-exchange
|
||||
- projects/dataplex-types/locations/global/aspectTypes/analytics-hub
|
||||
- projects/dataplex-types/locations/global/aspectTypes/analytics-hub-listing
|
||||
- projects/dataplex-types/locations/global/aspectTypes/bigquery-connection
|
||||
- projects/dataplex-types/locations/global/aspectTypes/bigquery-data-policy
|
||||
- projects/dataplex-types/locations/global/aspectTypes/bigquery-dataset
|
||||
- projects/dataplex-types/locations/global/aspectTypes/bigquery-model
|
||||
- projects/dataplex-types/locations/global/aspectTypes/bigquery-policy
|
||||
- projects/dataplex-types/locations/global/aspectTypes/bigquery-routine
|
||||
- projects/dataplex-types/locations/global/aspectTypes/bigquery-row-access-policy
|
||||
- projects/dataplex-types/locations/global/aspectTypes/bigquery-table
|
||||
- projects/dataplex-types/locations/global/aspectTypes/bigquery-view
|
||||
- projects/dataplex-types/locations/global/aspectTypes/cloud-bigtable-instance
|
||||
- projects/dataplex-types/locations/global/aspectTypes/cloud-bigtable-table
|
||||
- projects/dataplex-types/locations/global/aspectTypes/cloud-spanner-database
|
||||
- projects/dataplex-types/locations/global/aspectTypes/cloud-spanner-instance
|
||||
- projects/dataplex-types/locations/global/aspectTypes/cloud-spanner-table
|
||||
- projects/dataplex-types/locations/global/aspectTypes/cloud-spanner-view
|
||||
- projects/dataplex-types/locations/global/aspectTypes/cloudsql-database
|
||||
- projects/dataplex-types/locations/global/aspectTypes/cloudsql-instance
|
||||
- projects/dataplex-types/locations/global/aspectTypes/cloudsql-schema
|
||||
- projects/dataplex-types/locations/global/aspectTypes/cloudsql-table
|
||||
- projects/dataplex-types/locations/global/aspectTypes/cloudsql-view
|
||||
- projects/dataplex-types/locations/global/aspectTypes/contacts
|
||||
- projects/dataplex-types/locations/global/aspectTypes/dataform-code-asset
|
||||
- projects/dataplex-types/locations/global/aspectTypes/dataform-repository
|
||||
- projects/dataplex-types/locations/global/aspectTypes/dataform-workspace
|
||||
- projects/dataplex-types/locations/global/aspectTypes/dataproc-metastore-database
|
||||
- projects/dataplex-types/locations/global/aspectTypes/dataproc-metastore-service
|
||||
- projects/dataplex-types/locations/global/aspectTypes/dataproc-metastore-table
|
||||
- projects/dataplex-types/locations/global/aspectTypes/data-product
|
||||
- projects/dataplex-types/locations/global/aspectTypes/data-quality-scorecard
|
||||
- projects/dataplex-types/locations/global/aspectTypes/external-connection
|
||||
- projects/dataplex-types/locations/global/aspectTypes/overview
|
||||
- projects/dataplex-types/locations/global/aspectTypes/pubsub-topic
|
||||
- projects/dataplex-types/locations/global/aspectTypes/schema
|
||||
- projects/dataplex-types/locations/global/aspectTypes/sensitive-data-protection-job-result
|
||||
- projects/dataplex-types/locations/global/aspectTypes/sensitive-data-protection-profile
|
||||
- projects/dataplex-types/locations/global/aspectTypes/sql-access
|
||||
- projects/dataplex-types/locations/global/aspectTypes/storage-bucket
|
||||
- projects/dataplex-types/locations/global/aspectTypes/storage-folder
|
||||
- projects/dataplex-types/locations/global/aspectTypes/storage
|
||||
- projects/dataplex-types/locations/global/aspectTypes/usage
|
||||
|
||||
## Entry Types
|
||||
Every Entry must conform to an Entry Type. The Entry Type acts as a template, defining the structure, required aspects, and constraints for Entries of that type.
|
||||
|
||||
Examples:
|
||||
- projects/dataplex-types/locations/global/entryTypes/analytics-hub-exchange
|
||||
- projects/dataplex-types/locations/global/entryTypes/analytics-hub-listing
|
||||
- projects/dataplex-types/locations/global/entryTypes/bigquery-connection
|
||||
- projects/dataplex-types/locations/global/entryTypes/bigquery-data-policy
|
||||
- projects/dataplex-types/locations/global/entryTypes/bigquery-dataset
|
||||
- projects/dataplex-types/locations/global/entryTypes/bigquery-model
|
||||
- projects/dataplex-types/locations/global/entryTypes/bigquery-routine
|
||||
- projects/dataplex-types/locations/global/entryTypes/bigquery-row-access-policy
|
||||
- projects/dataplex-types/locations/global/entryTypes/bigquery-table
|
||||
- projects/dataplex-types/locations/global/entryTypes/bigquery-view
|
||||
- projects/dataplex-types/locations/global/entryTypes/cloud-bigtable-instance
|
||||
- projects/dataplex-types/locations/global/entryTypes/cloud-bigtable-table
|
||||
- projects/dataplex-types/locations/global/entryTypes/cloud-spanner-database
|
||||
- projects/dataplex-types/locations/global/entryTypes/cloud-spanner-instance
|
||||
- projects/dataplex-types/locations/global/entryTypes/cloud-spanner-table
|
||||
- projects/dataplex-types/locations/global/entryTypes/cloud-spanner-view
|
||||
- projects/dataplex-types/locations/global/entryTypes/cloudsql-mysql-database
|
||||
- projects/dataplex-types/locations/global/entryTypes/cloudsql-mysql-instance
|
||||
- projects/dataplex-types/locations/global/entryTypes/cloudsql-mysql-table
|
||||
- projects/dataplex-types/locations/global/entryTypes/cloudsql-mysql-view
|
||||
- projects/dataplex-types/locations/global/entryTypes/cloudsql-postgresql-database
|
||||
- projects/dataplex-types/locations/global/entryTypes/cloudsql-postgresql-instance
|
||||
- projects/dataplex-types/locations/global/entryTypes/cloudsql-postgresql-schema
|
||||
- projects/dataplex-types/locations/global/entryTypes/cloudsql-postgresql-table
|
||||
- projects/dataplex-types/locations/global/entryTypes/cloudsql-postgresql-view
|
||||
- projects/dataplex-types/locations/global/entryTypes/cloudsql-sqlserver-database
|
||||
- projects/dataplex-types/locations/global/entryTypes/cloudsql-sqlserver-instance
|
||||
- projects/dataplex-types/locations/global/entryTypes/cloudsql-sqlserver-schema
|
||||
- projects/dataplex-types/locations/global/entryTypes/cloudsql-sqlserver-table
|
||||
- projects/dataplex-types/locations/global/entryTypes/cloudsql-sqlserver-view
|
||||
- projects/dataplex-types/locations/global/entryTypes/dataform-code-asset
|
||||
- projects/dataplex-types/locations/global/entryTypes/dataform-repository
|
||||
- projects/dataplex-types/locations/global/entryTypes/dataform-workspace
|
||||
- projects/dataplex-types/locations/global/entryTypes/dataproc-metastore-database
|
||||
- projects/dataplex-types/locations/global/entryTypes/dataproc-metastore-service
|
||||
- projects/dataplex-types/locations/global/entryTypes/dataproc-metastore-table
|
||||
- projects/dataplex-types/locations/global/entryTypes/pubsub-topic
|
||||
- projects/dataplex-types/locations/global/entryTypes/storage-bucket
|
||||
- projects/dataplex-types/locations/global/entryTypes/storage-folder
|
||||
- projects/dataplex-types/locations/global/entryTypes/vertexai-dataset
|
||||
- projects/dataplex-types/locations/global/entryTypes/vertexai-feature-group
|
||||
- projects/dataplex-types/locations/global/entryTypes/vertexai-feature-online-store
|
||||
|
||||
## Entry Groups
|
||||
Entries are organized within Entry Groups, which are logical groupings of Entries. An Entry Group acts as a namespace for its Entries.
|
||||
|
||||
## Entry Links
|
||||
Entries can be linked together using EntryLinks to represent relationships between data assets (e.g. foreign keys).
|
||||
|
||||
# Tool instructions
|
||||
## Tool: dataplex_search_entries
|
||||
## General
|
||||
- Do not try to search within search results on your own.
|
||||
- Do not fetch multiple pages of results unless explicitly asked.
|
||||
|
||||
## Search syntax
|
||||
|
||||
### Simple search
|
||||
In its simplest form, a search query consists of a single predicate. Such a predicate can match several pieces of metadata:
|
||||
|
||||
- A substring of a name, display name, or description of a resource
|
||||
- A substring of the type of a resource
|
||||
- A substring of a column name (or nested column name) in the schema of a resource
|
||||
- A substring of a project ID
|
||||
- A string from an overview description
|
||||
|
||||
For example, the predicate foo matches the following resources:
|
||||
- Resource with the name foo.bar
|
||||
- Resource with the display name Foo Bar
|
||||
- Resource with the description This is the foo script
|
||||
- Resource with the exact type foo
|
||||
- Column foo_bar in the schema of a resource
|
||||
- Nested column foo_bar in the schema of a resource
|
||||
- Project prod-foo-bar
|
||||
- Resource with an overview containing the word foo
|
||||
|
||||
|
||||
### Qualified predicates
|
||||
You can qualify a predicate by prefixing it with a key that restricts the matching to a specific piece of metadata:
|
||||
- An equal sign (=) restricts the search to an exact match.
|
||||
- A colon (:) after the key matches the predicate to either a substring or a token within the value in the search results.
|
||||
|
||||
Tokenization splits the stream of text into a series of tokens, with each token usually corresponding to a single word. For example:
|
||||
- name:foo selects resources with names that contain the foo substring, like foo1 and barfoo.
|
||||
- description:foo selects resources with the foo token in the description, like bar and foo.
|
||||
- location=foo matches resources in a specified location with foo as the location name.
|
||||
|
||||
The predicate keys type, system, location, and orgid support only the exact match (=) qualifier, not the substring qualifier (:). For example, type=foo or orgid=number.
|
||||
|
||||
Search syntax supports the following qualifiers:
|
||||
- "name:x" - Matches x as a substring of the resource ID.
|
||||
- "displayname:x" - Match x as a substring of the resource display name.
|
||||
- "column:x" - Matches x as a substring of the column name (or nested column name) in the schema of the resource.
|
||||
- "description:x" - Matches x as a token in the resource description.
|
||||
- "label:bar" - Matches BigQuery resources that have a label (with some value) and the label key has bar as a substring.
|
||||
- "label=bar" - Matches BigQuery resources that have a label (with some value) and the label key equals bar as a string.
|
||||
- "label:bar:x" - Matches x as a substring in the value of a label with a key bar attached to a BigQuery resource.
|
||||
- "label=foo:bar" - Matches BigQuery resources where the key equals foo and the key value equals bar.
|
||||
- "label.foo=bar" - Matches BigQuery resources where the key equals foo and the key value equals bar.
|
||||
- "label.foo" - Matches BigQuery resources that have a label whose key equals foo as a string.
|
||||
- "type=TYPE" - Matches resources of a specific entry type or its type alias.
|
||||
- "projectid:bar" - Matches resources within Google Cloud projects that match bar as a substring in the ID.
|
||||
- "parent:x" - Matches x as a substring of the hierarchical path of a resource. The parent path is a fully_qualified_name of the parent resource.
|
||||
- "orgid=number" - Matches resources within a Google Cloud organization with the exact ID value of the number.
|
||||
- "system=SYSTEM" - Matches resources from a specified system. For example, system=bigquery matches BigQuery resources.
|
||||
- "location=LOCATION" - Matches resources in a specified location with an exact name. For example, location=us-central1 matches assets hosted in Iowa. BigQuery Omni assets support this qualifier by using the BigQuery Omni location name. For example, location=aws-us-east-1 matches BigQuery Omni assets in Northern Virginia.
|
||||
- "createtime" -
|
||||
Finds resources that were created within, before, or after a given date or time. For example "createtime:2019-01-01" matches resources created on 2019-01-01.
|
||||
- "updatetime" - Finds resources that were updated within, before, or after a given date or time. For example "updatetime>2019-01-01" matches resources updated after 2019-01-01.
|
||||
- "fully_qualified_name:x" - Matches x as a substring of fully_qualified_name.
|
||||
- "fully_qualified_name=x" - Matches x as fully_qualified_name.
|
||||
|
||||
### Logical operators
|
||||
A query can consist of several predicates with logical operators. If you don't specify an operator, logical AND is implied. For example, foo bar returns resources that match both predicate foo and predicate bar.
|
||||
Logical AND and logical OR are supported. For example, foo OR bar.
|
||||
|
||||
You can negate a predicate with a - (hyphen) or NOT prefix. For example, -name:foo returns resources with names that don't match the predicate foo.
|
||||
Logical operators aren't case-sensitive. For example, both or and OR are acceptable.
|
||||
|
||||
### Request
|
||||
1. Always try to rewrite the prompt using search syntax.
|
||||
|
||||
### Response
|
||||
1. If there are multiple search results found
|
||||
1.1. Present the list of search results
|
||||
1.2. Format the output in nested ordered list, for example:
|
||||
1. Present the list of search results
|
||||
2. Format the output in nested ordered list, for example:
|
||||
Given
|
||||
```
|
||||
{
|
||||
@@ -75,14 +278,19 @@ Whenever you will receive response from dataplex_search_entries tool decide what
|
||||
- location: us-central1
|
||||
- description: Table contains list of best customers.
|
||||
```
|
||||
1.3. Ask to select one of the presented search results
|
||||
3. Ask to select one of the presented search results
|
||||
2. If there is only one search result found
|
||||
2.1. Present the search result immediately.
|
||||
1. Present the search result immediately.
|
||||
3. If there are no search result found
|
||||
3.1. Explain that no search result was found
|
||||
3.2. Suggest to provide a more specific search query.
|
||||
1. Explain that no search result was found
|
||||
2. Suggest to provide a more specific search query.
|
||||
|
||||
Do not try to search within search results on your own.
|
||||
## Tool: dataplex_lookup_entry
|
||||
### Request
|
||||
1. Always try to limit the size of the response by specifying `aspect_types` parameter. Make sure to include to select view=CUSTOM when using aspect_types parameter.
|
||||
2. If you do not know the name of the entry, use `dataplex_search_entries` tool
|
||||
### Response
|
||||
1. Unless asked for a specific aspect, respond with all aspects attached to the entry.
|
||||
```
|
||||
|
||||
## Reference
|
||||
@@ -90,4 +298,4 @@ Do not try to search within search results on your own.
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|-----------|:--------:|:------------:|----------------------------------------------------------------------------------|
|
||||
| kind | string | true | Must be "dataplex". |
|
||||
| project | string | true | Id of the GCP project used for quota and billing purposes (e.g. "my-project-id").|
|
||||
| project | string | true | ID of the GCP project used for quota and billing purposes (e.g. "my-project-id").|
|
||||
|
||||
81
docs/en/resources/sources/tidb.md
Normal file
81
docs/en/resources/sources/tidb.md
Normal file
@@ -0,0 +1,81 @@
|
||||
---
|
||||
title: "TiDB"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
TiDB is a distributed SQL database that combines the best of traditional RDBMS and NoSQL databases.
|
||||
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
[TiDB][tidb-docs] is an open-source distributed SQL database that supports Hybrid Transactional and Analytical Processing (HTAP) workloads. It is MySQL-compatible and features horizontal scalability, strong consistency, and high availability.
|
||||
|
||||
[tidb-docs]: https://docs.pingcap.com/tidb/stable
|
||||
|
||||
## Requirements
|
||||
|
||||
### Database User
|
||||
|
||||
This source uses standard MySQL protocol authentication. You will need to [create a TiDB user][tidb-users] to login to the database with.
|
||||
|
||||
For TiDB Cloud users, you can create database users through the TiDB Cloud console.
|
||||
|
||||
[tidb-users]: https://docs.pingcap.com/tidb/stable/user-account-management
|
||||
|
||||
## SSL Configuration
|
||||
|
||||
- TiDB Cloud
|
||||
|
||||
For TiDB Cloud instances, SSL is automatically enabled when the hostname matches the TiDB Cloud pattern (`gateway*.*.*.tidbcloud.com`). You don't need to explicitly set `ssl: true` for TiDB Cloud connections.
|
||||
|
||||
- Self-Hosted TiDB
|
||||
|
||||
For self-hosted TiDB instances, you can optionally enable SSL by setting `ssl: true` in your configuration.
|
||||
|
||||
## Example
|
||||
|
||||
- TiDB Cloud
|
||||
|
||||
```yaml
|
||||
sources:
|
||||
my-tidb-cloud-source:
|
||||
kind: tidb
|
||||
host: gateway01.us-west-2.prod.aws.tidbcloud.com
|
||||
port: 4000
|
||||
database: my_db
|
||||
user: ${TIDB_USERNAME}
|
||||
password: ${TIDB_PASSWORD}
|
||||
# SSL is automatically enabled for TiDB Cloud
|
||||
```
|
||||
|
||||
- Self-Hosted TiDB
|
||||
|
||||
```yaml
|
||||
sources:
|
||||
my-tidb-source:
|
||||
kind: tidb
|
||||
host: 127.0.0.1
|
||||
port: 4000
|
||||
database: my_db
|
||||
user: ${TIDB_USERNAME}
|
||||
password: ${TIDB_PASSWORD}
|
||||
# ssl: true # Optional: enable SSL for secure connections
|
||||
```
|
||||
|
||||
{{< notice tip >}}
|
||||
Use environment variable replacement with the format ${ENV_NAME}
|
||||
instead of hardcoding your secrets into the configuration file.
|
||||
{{< /notice >}}
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|-----------|:--------:|:------------:|--------------------------------------------------------------------------------------------|
|
||||
| kind | string | true | Must be "tidb". |
|
||||
| host | string | true | IP address or hostname to connect to (e.g. "127.0.0.1" or "gateway01.*.tidbcloud.com"). |
|
||||
| port | string | true | Port to connect to (typically "4000" for TiDB). |
|
||||
| database | string | true | Name of the TiDB database to connect to (e.g. "my_db"). |
|
||||
| user | string | true | Name of the TiDB user to connect as (e.g. "my-tidb-user"). |
|
||||
| password | string | true | Password of the TiDB user (e.g. "my-password"). |
|
||||
| ssl | boolean | false | Whether to use SSL/TLS encryption. Automatically enabled for TiDB Cloud instances. |
|
||||
60
docs/en/resources/tools/dataplex/dataplex-lookup-entry.md
Normal file
60
docs/en/resources/tools/dataplex/dataplex-lookup-entry.md
Normal file
@@ -0,0 +1,60 @@
|
||||
---
|
||||
title: "dataplex-lookup-entry"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
A "dataplex-lookup-entry" tool returns details of a particular entry in Dataplex Catalog.
|
||||
aliases:
|
||||
- /resources/tools/dataplex-lookup-entry
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
A `dataplex-lookup-entry` tool returns details of a particular entry in Dataplex Catalog.
|
||||
It's compatible with the following sources:
|
||||
|
||||
- [dataplex](../sources/dataplex.md)
|
||||
|
||||
`dataplex-lookup-entry` takes a required `name` parameter which contains the project and location to which the request should be attributed in the following form: projects/{project}/locations/{location} and also a required `entry` parameter which is the resource name of the entry in the following form: projects/{project}/locations/{location}/entryGroups/{entryGroup}/entries/{entry}. It also optionally accepts following parameters:
|
||||
- `view` - View to control which parts of an entry the service should return. It takes integer values from 1-4 corresponding to type of view - BASIC, FULL, CUSTOM, ALL
|
||||
- `aspectTypes` - Limits the aspects returned to the provided aspect types in the format `projects/{project}/locations/{location}/aspectTypes/{aspectType}`. It only works for CUSTOM view.
|
||||
- `paths` - Limits the aspects returned to those associated with the provided paths within the Entry. It only works for CUSTOM view.
|
||||
|
||||
## Requirements
|
||||
|
||||
### IAM Permissions
|
||||
|
||||
Dataplex uses [Identity and Access Management (IAM)][iam-overview] to control
|
||||
user and group access to Dataplex resources. Toolbox will use your
|
||||
[Application Default Credentials (ADC)][adc] to authorize and authenticate when
|
||||
interacting with [Dataplex][dataplex-docs].
|
||||
|
||||
In addition to [setting the ADC for your server][set-adc], you need to ensure
|
||||
the IAM identity has been given the correct IAM permissions for the tasks you
|
||||
intend to perform. See [Dataplex Universal Catalog IAM permissions][iam-permissions]
|
||||
and [Dataplex Universal Catalog IAM roles][iam-roles] for more information on
|
||||
applying IAM permissions and roles to an identity.
|
||||
|
||||
[iam-overview]: https://cloud.google.com/dataplex/docs/iam-and-access-control
|
||||
[adc]: https://cloud.google.com/docs/authentication#adc
|
||||
[set-adc]: https://cloud.google.com/docs/authentication/provide-credentials-adc
|
||||
[iam-permissions]: https://cloud.google.com/dataplex/docs/iam-permissions
|
||||
[iam-roles]: https://cloud.google.com/dataplex/docs/iam-roles
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
lookup_entry:
|
||||
kind: dataplex-lookup-entry
|
||||
source: my-dataplex-source
|
||||
description: Use this tool to retrieve a specific entry in Dataplex Catalog.
|
||||
```
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|-------------|:------------------------------------------:|:------------:|--------------------------------------------------------------------------------------------------|
|
||||
| kind | string | true | Must be "dataplex-lookup-entry". |
|
||||
| source | string | true | Name of the source the tool should execute on. |
|
||||
| description | string | true | Description of the tool that is passed to the LLM. |
|
||||
7
docs/en/resources/tools/tidb/_index.md
Normal file
7
docs/en/resources/tools/tidb/_index.md
Normal file
@@ -0,0 +1,7 @@
|
||||
---
|
||||
title: "TiDB"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
Tools that work with TiDB Sources, such as TiDB Cloud and self-hosted TiDB.
|
||||
---
|
||||
41
docs/en/resources/tools/tidb/tidb-execute-sql.md
Normal file
41
docs/en/resources/tools/tidb/tidb-execute-sql.md
Normal file
@@ -0,0 +1,41 @@
|
||||
---
|
||||
title: "tidb-execute-sql"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
A "tidb-execute-sql" tool executes a SQL statement against a TiDB
|
||||
database.
|
||||
aliases:
|
||||
- /resources/tools/tidb-execute-sql
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
A `tidb-execute-sql` tool executes a SQL statement against a TiDB
|
||||
database. It's compatible with the following source:
|
||||
|
||||
- [tidb](../sources/tidb.md)
|
||||
|
||||
`tidb-execute-sql` takes one input parameter `sql` and run the sql
|
||||
statement against the `source`.
|
||||
|
||||
> **Note:** This tool is intended for developer assistant workflows with
|
||||
> human-in-the-loop and shouldn't be used for production agents.
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
execute_sql_tool:
|
||||
kind: tidb-execute-sql
|
||||
source: my-tidb-instance
|
||||
description: Use this tool to execute sql statement.
|
||||
```
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|-------------|:------------------------------------------:|:------------:|--------------------------------------------------------------------------------------------------|
|
||||
| kind | string | true | Must be "tidb-execute-sql". |
|
||||
| source | string | true | Name of the source the SQL should execute on. |
|
||||
| description | string | true | Description of the tool that is passed to the LLM. |
|
||||
105
docs/en/resources/tools/tidb/tidb-sql.md
Normal file
105
docs/en/resources/tools/tidb/tidb-sql.md
Normal file
@@ -0,0 +1,105 @@
|
||||
---
|
||||
title: "tidb-sql"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
A "tidb-sql" tool executes a pre-defined SQL statement against a TiDB
|
||||
database.
|
||||
aliases:
|
||||
- /resources/tools/tidb-sql
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
A `tidb-sql` tool executes a pre-defined SQL statement against a TiDB
|
||||
database. It's compatible with the following source:
|
||||
|
||||
- [tidb](../sources/tidb.md)
|
||||
|
||||
The specified SQL statement is executed as a [prepared statement][tidb-prepare],
|
||||
and expects parameters in the SQL query to be in the form of placeholders `?`.
|
||||
|
||||
[tidb-prepare]: https://docs.pingcap.com/tidb/stable/sql-prepared-plan-cache
|
||||
|
||||
## Example
|
||||
|
||||
> **Note:** This tool uses parameterized queries to prevent SQL injections.
|
||||
> Query parameters can be used as substitutes for arbitrary expressions.
|
||||
> Parameters cannot be used as substitutes for identifiers, column names, table
|
||||
> names, or other parts of the query.
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
search_flights_by_number:
|
||||
kind: tidb-sql
|
||||
source: my-tidb-instance
|
||||
statement: |
|
||||
SELECT * FROM flights
|
||||
WHERE airline = ?
|
||||
AND flight_number = ?
|
||||
LIMIT 10
|
||||
description: |
|
||||
Use this tool to get information for a specific flight.
|
||||
Takes an airline code and flight number and returns info on the flight.
|
||||
Do NOT use this tool with a flight id. Do NOT guess an airline code or flight number.
|
||||
A airline code is a code for an airline service consisting of two-character
|
||||
airline designator and followed by flight number, which is 1 to 4 digit number.
|
||||
For example, if given CY 0123, the airline is "CY", and flight_number is "123".
|
||||
Another example for this is DL 1234, the airline is "DL", and flight_number is "1234".
|
||||
If the tool returns more than one option choose the date closes to today.
|
||||
Example:
|
||||
{{
|
||||
"airline": "CY",
|
||||
"flight_number": "888",
|
||||
}}
|
||||
Example:
|
||||
{{
|
||||
"airline": "DL",
|
||||
"flight_number": "1234",
|
||||
}}
|
||||
parameters:
|
||||
- name: airline
|
||||
type: string
|
||||
description: Airline unique 2 letter identifier
|
||||
- name: flight_number
|
||||
type: string
|
||||
description: 1 to 4 digit number
|
||||
```
|
||||
|
||||
### Example with Template Parameters
|
||||
|
||||
> **Note:** This tool allows direct modifications to the SQL statement,
|
||||
> including identifiers, column names, and table names. **This makes it more
|
||||
> vulnerable to SQL injections**. Using basic parameters only (see above) is
|
||||
> recommended for performance and safety reasons. For more details, please check
|
||||
> [templateParameters](_index#template-parameters).
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
list_table:
|
||||
kind: tidb-sql
|
||||
source: my-tidb-instance
|
||||
statement: |
|
||||
SELECT * FROM {{.tableName}};
|
||||
description: |
|
||||
Use this tool to list all information from a specific table.
|
||||
Example:
|
||||
{{
|
||||
"tableName": "flights",
|
||||
}}
|
||||
templateParameters:
|
||||
- name: tableName
|
||||
type: string
|
||||
description: Table to select from
|
||||
```
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|--------------------|:------------------------------------------------:|:------------:|--------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| kind | string | true | Must be "tidb-sql". |
|
||||
| source | string | true | Name of the source the SQL should execute on. |
|
||||
| description | string | true | Description of the tool that is passed to the LLM. |
|
||||
| statement | string | true | SQL statement to execute on. |
|
||||
| parameters | [parameters](_index#specifying-parameters) | false | List of [parameters](_index#specifying-parameters) that will be inserted into the SQL statement. |
|
||||
| templateParameters | [templateParameters](_index#template-parameters) | false | List of [templateParameters](_index#template-parameters) that will be inserted into the SQL statement before executing prepared statement. |
|
||||
@@ -7,9 +7,13 @@ tools:
|
||||
dataplex_search_entries:
|
||||
kind: dataplex-search-entries
|
||||
source: dataplex-source
|
||||
description: |
|
||||
Use this tool to search for entries in Dataplex Catalog that represent data assets (e.g. tables, views, models) based on the provided search query.
|
||||
description: Use this tool to search for entries in Dataplex Catalog based on the provided search query.
|
||||
dataplex_lookup_entry:
|
||||
kind: dataplex-lookup-entry
|
||||
source: dataplex-source
|
||||
description: Use this tool to retrieve a specific entry from Dataplex Catalog.
|
||||
|
||||
toolsets:
|
||||
dataplex-tools:
|
||||
- dataplex_search_entries
|
||||
- dataplex_search_entries
|
||||
- dataplex_lookup_entry
|
||||
@@ -214,7 +214,7 @@ func (s *stdioSession) write(ctx context.Context, response any) error {
|
||||
func mcpRouter(s *Server) (chi.Router, error) {
|
||||
r := chi.NewRouter()
|
||||
|
||||
r.Use(middleware.AllowContentType("application/json"))
|
||||
r.Use(middleware.AllowContentType("application/json", "application/json-rpc", "application/jsonrequest"))
|
||||
r.Use(middleware.StripSlashes)
|
||||
r.Use(render.SetContentType(render.ContentTypeJSON))
|
||||
|
||||
|
||||
@@ -99,7 +99,7 @@ func initDuckDbConnection(ctx context.Context, tracer trace.Tracer, name string,
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
defer span.End()
|
||||
|
||||
var configStr string = getDuckDbConfiguration(dbFilePath, duckdbConfiguration)
|
||||
var configStr = getDuckDbConfiguration(dbFilePath, duckdbConfiguration)
|
||||
|
||||
//Open database connection
|
||||
db, err := sql.Open("duckdb", configStr)
|
||||
|
||||
128
internal/sources/tidb/tidb.go
Normal file
128
internal/sources/tidb/tidb.go
Normal file
@@ -0,0 +1,128 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package tidb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"regexp"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceKind string = "tidb"
|
||||
const TiDBCloudHostPattern string = `gateway\d{2}\.(.+)\.(prod|dev|staging)\.(.+)\.tidbcloud\.com`
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources.SourceConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If the host is a TiDB Cloud instance, force to use SSL
|
||||
if IsTiDBCloudHost(actual.Host) {
|
||||
actual.UseSSL = true
|
||||
}
|
||||
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Host string `yaml:"host" validate:"required"`
|
||||
Port string `yaml:"port" validate:"required"`
|
||||
User string `yaml:"user" validate:"required"`
|
||||
Password string `yaml:"password" validate:"required"`
|
||||
Database string `yaml:"database" validate:"required"`
|
||||
UseSSL bool `yaml:"ssl"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
pool, err := initTiDBConnectionPool(ctx, tracer, r.Name, r.Host, r.Port, r.User, r.Password, r.Database, r.UseSSL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create pool: %w", err)
|
||||
}
|
||||
|
||||
err = pool.PingContext(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to connect successfully: %w", err)
|
||||
}
|
||||
|
||||
s := &Source{
|
||||
Name: r.Name,
|
||||
Kind: SourceKind,
|
||||
Pool: pool,
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
var _ sources.Source = &Source{}
|
||||
|
||||
type Source struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Pool *sql.DB
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) TiDBPool() *sql.DB {
|
||||
return s.Pool
|
||||
}
|
||||
|
||||
func IsTiDBCloudHost(host string) bool {
|
||||
pattern := `gateway\d{2}\.(.+)\.(prod|dev|staging)\.(.+)\.tidbcloud\.com`
|
||||
match, err := regexp.MatchString(pattern, host)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return match
|
||||
}
|
||||
|
||||
func initTiDBConnectionPool(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname string, useSSL bool) (*sql.DB, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
defer span.End()
|
||||
|
||||
// Configure the driver to connect to the database
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?parseTime=true&charset=utf8mb4&tls=%t", user, pass, host, port, dbname, useSSL)
|
||||
|
||||
// Interact with the driver directly as you normally would
|
||||
pool, err := sql.Open("mysql", dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sql.Open: %w", err)
|
||||
}
|
||||
return pool, nil
|
||||
}
|
||||
258
internal/sources/tidb/tidb_test.go
Normal file
258
internal/sources/tidb/tidb_test.go
Normal file
@@ -0,0 +1,258 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package tidb_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/tidb"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
|
||||
func TestParseFromYamlTiDB(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.SourceConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
sources:
|
||||
my-tidb-instance:
|
||||
kind: tidb
|
||||
host: 0.0.0.0
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
`,
|
||||
want: server.SourceConfigs{
|
||||
"my-tidb-instance": tidb.Config{
|
||||
Name: "my-tidb-instance",
|
||||
Kind: tidb.SourceKind,
|
||||
Host: "0.0.0.0",
|
||||
Port: "my-port",
|
||||
Database: "my_db",
|
||||
User: "my_user",
|
||||
Password: "my_pass",
|
||||
UseSSL: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "with SSL enabled",
|
||||
in: `
|
||||
sources:
|
||||
my-tidb-cloud:
|
||||
kind: tidb
|
||||
host: gateway01.us-west-2.prod.aws.tidbcloud.com
|
||||
port: 4000
|
||||
database: test_db
|
||||
user: cloud_user
|
||||
password: cloud_pass
|
||||
ssl: true
|
||||
`,
|
||||
want: server.SourceConfigs{
|
||||
"my-tidb-cloud": tidb.Config{
|
||||
Name: "my-tidb-cloud",
|
||||
Kind: tidb.SourceKind,
|
||||
Host: "gateway01.us-west-2.prod.aws.tidbcloud.com",
|
||||
Port: "4000",
|
||||
Database: "test_db",
|
||||
User: "cloud_user",
|
||||
Password: "cloud_pass",
|
||||
UseSSL: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Change SSL enabled due to TiDB Cloud host",
|
||||
in: `
|
||||
sources:
|
||||
my-tidb-cloud:
|
||||
kind: tidb
|
||||
host: gateway01.us-west-2.prod.aws.tidbcloud.com
|
||||
port: 4000
|
||||
database: test_db
|
||||
user: cloud_user
|
||||
password: cloud_pass
|
||||
`,
|
||||
want: server.SourceConfigs{
|
||||
"my-tidb-cloud": tidb.Config{
|
||||
Name: "my-tidb-cloud",
|
||||
Kind: tidb.SourceKind,
|
||||
Host: "gateway01.us-west-2.prod.aws.tidbcloud.com",
|
||||
Port: "4000",
|
||||
Database: "test_db",
|
||||
User: "cloud_user",
|
||||
Password: "cloud_pass",
|
||||
UseSSL: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestFailParseFromYaml(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
err string
|
||||
}{
|
||||
{
|
||||
desc: "extra field",
|
||||
in: `
|
||||
sources:
|
||||
my-tidb-instance:
|
||||
kind: tidb
|
||||
host: 0.0.0.0
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
ssl: false
|
||||
foo: bar
|
||||
`,
|
||||
err: "unable to parse source \"my-tidb-instance\" as \"tidb\": [2:1] unknown field \"foo\"\n 1 | database: my_db\n> 2 | foo: bar\n ^\n 3 | host: 0.0.0.0\n 4 | kind: tidb\n 5 | password: my_pass\n 6 | ",
|
||||
},
|
||||
{
|
||||
desc: "missing required field",
|
||||
in: `
|
||||
sources:
|
||||
my-tidb-instance:
|
||||
kind: tidb
|
||||
port: my-port
|
||||
database: my_db
|
||||
user: my_user
|
||||
password: my_pass
|
||||
ssl: false
|
||||
`,
|
||||
err: "unable to parse source \"my-tidb-instance\" as \"tidb\": Key: 'Config.Host' Error:Field validation for 'Host' failed on the 'required' tag",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
errStr := err.Error()
|
||||
if errStr != tc.err {
|
||||
t.Fatalf("unexpected error: got %q, want %q", errStr, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsTiDBCloudHost(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
host string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
desc: "valid TiDB Cloud host - ap-southeast-1",
|
||||
host: "gateway01.ap-southeast-1.prod.aws.tidbcloud.com",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
desc: "invalid TiDB Cloud host - wrong domain",
|
||||
host: "gateway01.ap-southeast-1.prod.aws.tdbcloud.com",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
desc: "local IP address",
|
||||
host: "127.0.0.1",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
desc: "valid TiDB Cloud host - us-west-2",
|
||||
host: "gateway01.us-west-2.prod.aws.tidbcloud.com",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
desc: "valid TiDB Cloud host - dev environment",
|
||||
host: "gateway02.eu-west-1.dev.aws.tidbcloud.com",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
desc: "valid TiDB Cloud host - staging environment",
|
||||
host: "gateway03.us-east-1.staging.aws.tidbcloud.com",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
desc: "invalid - wrong gateway format",
|
||||
host: "gateway1.us-west-2.prod.aws.tidbcloud.com",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
desc: "invalid - missing environment",
|
||||
host: "gateway01.us-west-2.aws.tidbcloud.com",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
desc: "invalid - wrong subdomain",
|
||||
host: "gateway01.us-west-2.prod.aws.tidbcloud.org",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
desc: "invalid - localhost",
|
||||
host: "localhost",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
desc: "invalid - private IP",
|
||||
host: "192.168.1.1",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
desc: "invalid - empty string",
|
||||
host: "",
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := tidb.IsTiDBCloudHost(tc.host)
|
||||
if got != tc.want {
|
||||
t.Fatalf("isTiDBCloudHost(%q) = %v, want %v", tc.host, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -135,25 +135,6 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error)
|
||||
query := t.Client.Query(sql)
|
||||
query.Location = t.Client.Location
|
||||
|
||||
// This block handles Data Manipulation Language (DML) and Data Definition Language (DDL) statements.
|
||||
// These statements (e.g., INSERT, UPDATE, CREATE TABLE) do not return a row set.
|
||||
// Instead, we execute them as a job, wait for completion, and return a success
|
||||
// message, including the number of affected rows for DML operations.
|
||||
if statementType != "SELECT" {
|
||||
job, err := query.Run(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to start DML/DDL job: %w", err)
|
||||
}
|
||||
status, err := job.Wait(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to wait for DML/DDL job to complete: %w", err)
|
||||
}
|
||||
if err := status.Err(); err != nil {
|
||||
return nil, fmt.Errorf("DML/DDL job failed with error: %w", err)
|
||||
}
|
||||
return "Operation completed successfully.", nil
|
||||
}
|
||||
|
||||
// This block handles SELECT statements, which return a row set.
|
||||
// We iterate through the results, convert each row into a map of
|
||||
// column names to values, and return the collection of rows.
|
||||
@@ -177,10 +158,21 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error)
|
||||
}
|
||||
out = append(out, vMap)
|
||||
}
|
||||
if out == nil {
|
||||
// If the query returned any rows, return them directly.
|
||||
if len(out) > 0 {
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// This handles the standard case for a SELECT query that successfully
|
||||
// executes but returns zero rows.
|
||||
if statementType == "SELECT" {
|
||||
return "The query returned 0 rows.", nil
|
||||
}
|
||||
return out, nil
|
||||
// This is the fallback for a successful query that doesn't return content.
|
||||
// In most cases, this will be for DML/DDL statements like INSERT, UPDATE, CREATE, etc.
|
||||
// However, it is also possible that this was a query that was expected to return rows
|
||||
// but returned none, a case that we cannot distinguish here.
|
||||
return "Query executed successfully and returned no content.", nil
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
|
||||
@@ -17,6 +17,7 @@ package bigquerysql
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
bigqueryapi "cloud.google.com/go/bigquery"
|
||||
@@ -24,6 +25,7 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
bigqueryds "github.com/googleapis/genai-toolbox/internal/sources/bigquery"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
bigqueryrestapi "google.golang.org/api/bigquery/v2"
|
||||
"google.golang.org/api/iterator"
|
||||
)
|
||||
|
||||
@@ -45,6 +47,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
|
||||
type compatibleSource interface {
|
||||
BigQueryClient() *bigqueryapi.Client
|
||||
BigQueryRestService() *bigqueryrestapi.Service
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
@@ -101,6 +104,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
Statement: cfg.Statement,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Client: s.BigQueryClient(),
|
||||
RestService: s.BigQueryRestService(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
@@ -117,15 +121,17 @@ type Tool struct {
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
TemplateParameters tools.Parameters `yaml:"templateParameters"`
|
||||
AllParams tools.Parameters `yaml:"allParams"`
|
||||
|
||||
Client *bigqueryapi.Client
|
||||
Statement string
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Statement string
|
||||
Client *bigqueryapi.Client
|
||||
RestService *bigqueryrestapi.Service
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
namedArgs := make([]bigqueryapi.QueryParameter, 0, len(params))
|
||||
highLevelParams := make([]bigqueryapi.QueryParameter, 0, len(t.Parameters))
|
||||
lowLevelParams := make([]*bigqueryrestapi.QueryParameter, 0, len(t.Parameters))
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
newStatement, err := tools.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
|
||||
if err != nil {
|
||||
@@ -136,14 +142,11 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error)
|
||||
name := p.GetName()
|
||||
value := paramsMap[name]
|
||||
|
||||
// BigQuery's QueryParameter only accepts typed slices as input
|
||||
// This checks if the param is an array.
|
||||
// If yes, convert []any to typed slice (e.g []string, []int)
|
||||
switch arrayParam := p.(type) {
|
||||
case *tools.ArrayParameter:
|
||||
// This block for converting []any to typed slices is still necessary and correct.
|
||||
if arrayParam, ok := p.(*tools.ArrayParameter); ok {
|
||||
arrayParamValue, ok := value.([]any)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unable to convert parameter `%s` to []any %w", name, err)
|
||||
return nil, fmt.Errorf("unable to convert parameter `%s` to []any", name)
|
||||
}
|
||||
itemType := arrayParam.GetItems().GetType()
|
||||
var err error
|
||||
@@ -153,22 +156,69 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error)
|
||||
}
|
||||
}
|
||||
|
||||
if strings.Contains(t.Statement, "@"+name) {
|
||||
namedArgs = append(namedArgs, bigqueryapi.QueryParameter{
|
||||
Name: name,
|
||||
Value: value,
|
||||
})
|
||||
} else {
|
||||
namedArgs = append(namedArgs, bigqueryapi.QueryParameter{
|
||||
Value: value,
|
||||
})
|
||||
// Determine if the parameter is named or positional for the high-level client.
|
||||
var paramNameForHighLevel string
|
||||
if strings.Contains(newStatement, "@"+name) {
|
||||
paramNameForHighLevel = name
|
||||
}
|
||||
|
||||
// 1. Create the high-level parameter for the final query execution.
|
||||
highLevelParams = append(highLevelParams, bigqueryapi.QueryParameter{
|
||||
Name: paramNameForHighLevel,
|
||||
Value: value,
|
||||
})
|
||||
|
||||
// 2. Create the low-level parameter for the dry run, using the defined type from `p`.
|
||||
lowLevelParam := &bigqueryrestapi.QueryParameter{
|
||||
Name: paramNameForHighLevel,
|
||||
ParameterType: &bigqueryrestapi.QueryParameterType{},
|
||||
ParameterValue: &bigqueryrestapi.QueryParameterValue{},
|
||||
}
|
||||
|
||||
if arrayParam, ok := p.(*tools.ArrayParameter); ok {
|
||||
// Handle array types based on their defined item type.
|
||||
lowLevelParam.ParameterType.Type = "ARRAY"
|
||||
itemType, err := BQTypeStringFromToolType(arrayParam.GetItems().GetType())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
lowLevelParam.ParameterType.ArrayType = &bigqueryrestapi.QueryParameterType{Type: itemType}
|
||||
|
||||
// Build the array values.
|
||||
sliceVal := reflect.ValueOf(value)
|
||||
arrayValues := make([]*bigqueryrestapi.QueryParameterValue, sliceVal.Len())
|
||||
for i := 0; i < sliceVal.Len(); i++ {
|
||||
arrayValues[i] = &bigqueryrestapi.QueryParameterValue{
|
||||
Value: fmt.Sprintf("%v", sliceVal.Index(i).Interface()),
|
||||
}
|
||||
}
|
||||
lowLevelParam.ParameterValue.ArrayValues = arrayValues
|
||||
} else {
|
||||
// Handle scalar types based on their defined type.
|
||||
bqType, err := BQTypeStringFromToolType(p.GetType())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
lowLevelParam.ParameterType.Type = bqType
|
||||
lowLevelParam.ParameterValue.Value = fmt.Sprintf("%v", value)
|
||||
}
|
||||
lowLevelParams = append(lowLevelParams, lowLevelParam)
|
||||
}
|
||||
|
||||
query := t.Client.Query(newStatement)
|
||||
query.Parameters = namedArgs
|
||||
query.Parameters = highLevelParams
|
||||
query.Location = t.Client.Location
|
||||
|
||||
dryRunJob, err := dryRunQuery(ctx, t.RestService, t.Client.Project(), t.Client.Location, newStatement, lowLevelParams, query.ConnectionProperties)
|
||||
if err != nil {
|
||||
// This is a fallback check in case the switch logic was bypassed.
|
||||
return nil, fmt.Errorf("final query validation failed: %w", err)
|
||||
}
|
||||
statementType := dryRunJob.Statistics.Query.StatementType
|
||||
|
||||
// This block handles SELECT statements, which return a row set.
|
||||
// We iterate through the results, convert each row into a map of
|
||||
// column names to values, and return the collection of rows.
|
||||
it, err := query.Read(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
@@ -177,7 +227,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error)
|
||||
var out []any
|
||||
for {
|
||||
var row map[string]bigqueryapi.Value
|
||||
err := it.Next(&row)
|
||||
err = it.Next(&row)
|
||||
if err == iterator.Done {
|
||||
break
|
||||
}
|
||||
@@ -190,8 +240,21 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error)
|
||||
}
|
||||
out = append(out, vMap)
|
||||
}
|
||||
// If the query returned any rows, return them directly.
|
||||
if len(out) > 0 {
|
||||
return out, nil
|
||||
}
|
||||
|
||||
return out, nil
|
||||
// This handles the standard case for a SELECT query that successfully
|
||||
// executes but returns zero rows.
|
||||
if statementType == "SELECT" {
|
||||
return "The query returned 0 rows.", nil
|
||||
}
|
||||
// This is the fallback for a successful query that doesn't return content.
|
||||
// In most cases, this will be for DML/DDL statements like INSERT, UPDATE, CREATE, etc.
|
||||
// However, it is also possible that this was a query that was expected to return rows
|
||||
// but returned none, a case that we cannot distinguish here.
|
||||
return "Query executed successfully and returned no content.", nil
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
@@ -209,3 +272,58 @@ func (t Tool) McpManifest() tools.McpManifest {
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func BQTypeStringFromToolType(toolType string) (string, error) {
|
||||
switch toolType {
|
||||
case "string":
|
||||
return "STRING", nil
|
||||
case "integer":
|
||||
return "INT64", nil
|
||||
case "float":
|
||||
return "FLOAT64", nil
|
||||
case "boolean":
|
||||
return "BOOL", nil
|
||||
// Note: 'array' is handled separately as it has a nested item type.
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported tool parameter type for BigQuery: %s", toolType)
|
||||
}
|
||||
}
|
||||
|
||||
func dryRunQuery(
|
||||
ctx context.Context,
|
||||
restService *bigqueryrestapi.Service,
|
||||
projectID string,
|
||||
location string,
|
||||
sql string,
|
||||
params []*bigqueryrestapi.QueryParameter,
|
||||
connProps []*bigqueryapi.ConnectionProperty,
|
||||
) (*bigqueryrestapi.Job, error) {
|
||||
useLegacySql := false
|
||||
|
||||
restConnProps := make([]*bigqueryrestapi.ConnectionProperty, len(connProps))
|
||||
for i, prop := range connProps {
|
||||
restConnProps[i] = &bigqueryrestapi.ConnectionProperty{Key: prop.Key, Value: prop.Value}
|
||||
}
|
||||
|
||||
jobToInsert := &bigqueryrestapi.Job{
|
||||
JobReference: &bigqueryrestapi.JobReference{
|
||||
ProjectId: projectID,
|
||||
Location: location,
|
||||
},
|
||||
Configuration: &bigqueryrestapi.JobConfiguration{
|
||||
DryRun: true,
|
||||
Query: &bigqueryrestapi.JobConfigurationQuery{
|
||||
Query: sql,
|
||||
UseLegacySql: &useLegacySql,
|
||||
ConnectionProperties: restConnProps,
|
||||
QueryParameters: params,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
insertResponse, err := restService.Jobs.Insert(projectID, jobToInsert).Context(ctx).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to insert dry run job: %w", err)
|
||||
}
|
||||
return insertResponse, nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,183 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package dataplexlookupentry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
dataplexapi "cloud.google.com/go/dataplex/apiv1"
|
||||
dataplexpb "cloud.google.com/go/dataplex/apiv1/dataplexpb"
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
dataplexds "github.com/googleapis/genai-toolbox/internal/sources/dataplex"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
)
|
||||
|
||||
const kind string = "dataplex-lookup-entry"
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
// compatibleSource is the narrow interface this tool needs from its
// source: access to a Dataplex catalog client.
type compatibleSource interface {
	CatalogClient() *dataplexapi.CatalogClient
}

// validate compatible sources are still compatible
var _ compatibleSource = &dataplexds.Source{}

// compatibleSources lists the source kinds accepted by this tool; it is
// only used to format error messages in Initialize.
var compatibleSources = [...]string{dataplexds.SourceKind}

// Config is the YAML configuration for a dataplex-lookup-entry tool.
type Config struct {
	Name         string           `yaml:"name" validate:"required"`
	Kind         string           `yaml:"kind" validate:"required"`
	Source       string           `yaml:"source" validate:"required"`
	Description  string           `yaml:"description"`
	AuthRequired []string         `yaml:"authRequired"`
	Parameters   tools.Parameters `yaml:"parameters"`
}

// validate interface
var _ tools.ToolConfig = Config{}

// ToolConfigKind returns the tool kind string for this config.
func (cfg Config) ToolConfigKind() string {
	return kind
}
|
||||
|
||||
// Initialize resolves the configured source, verifies it provides a
// Dataplex catalog client, and builds the tool with its fixed set of
// lookup parameters (name, view, aspectTypes, entry). Note the parameters
// are defined here rather than taken from cfg.Parameters.
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
	// Initialize the search configuration with the provided sources
	rawS, ok := srcs[cfg.Source]
	if !ok {
		return nil, fmt.Errorf("no source named %q configured", cfg.Source)
	}
	// verify the source is compatible
	s, ok := rawS.(compatibleSource)
	if !ok {
		return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
	}

	// viewDesc is passed verbatim to the model as the "view" parameter
	// description; keep it in sync with the viewMap in Invoke.
	viewDesc := `
## Argument: view

**Type:** Integer

**Description:** Specifies the parts of the entry and its aspects to return.

**Possible Values:**

* 1 (BASIC): Returns entry without aspects.
* 2 (FULL): Return all required aspects and the keys of non-required aspects. (Default)
* 3 (CUSTOM): Return the entry and aspects requested in aspect_types field (at most 100 aspects). Always use this view when aspect_types is not empty.
* 4 (ALL): Return the entry and both required and optional aspects (at most 100 aspects)
`

	name := tools.NewStringParameter("name", "The project to which the request should be attributed in the following form: projects/{project}/locations/{location}.")
	view := tools.NewIntParameterWithDefault("view", 2, viewDesc)
	aspectTypes := tools.NewArrayParameterWithDefault("aspectTypes", []any{}, "Limits the aspects returned to the provided aspect types. It only works when used together with CUSTOM view.", tools.NewStringParameter("aspectType", "The types of aspects to be included in the response in the format `projects/{project}/locations/{location}/aspectTypes/{aspectType}`."))
	entry := tools.NewStringParameter("entry", "The resource name of the Entry in the following form: projects/{project}/locations/{location}/entryGroups/{entryGroup}/entries/{entry}.")
	parameters := tools.Parameters{name, view, aspectTypes, entry}

	mcpManifest := tools.McpManifest{
		Name:        cfg.Name,
		Description: cfg.Description,
		InputSchema: parameters.McpManifest(),
	}

	t := &Tool{
		Name:          cfg.Name,
		Kind:          kind,
		Parameters:    parameters,
		AuthRequired:  cfg.AuthRequired,
		CatalogClient: s.CatalogClient(),
		manifest: tools.Manifest{
			Description:  cfg.Description,
			Parameters:   parameters.Manifest(),
			AuthRequired: cfg.AuthRequired,
		},
		mcpManifest: mcpManifest,
	}
	return t, nil
}
|
||||
|
||||
// Tool is an initialized dataplex-lookup-entry tool bound to a catalog client.
type Tool struct {
	Name          string
	Kind          string
	Parameters    tools.Parameters
	AuthRequired  []string
	CatalogClient *dataplexapi.CatalogClient
	manifest      tools.Manifest
	mcpManifest   tools.McpManifest
}

// Authorized reports whether the verified auth services satisfy the
// tool's authRequired list.
func (t *Tool) Authorized(verifiedAuthServices []string) bool {
	return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
|
||||
|
||||
func (t *Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
paramsMap := params.AsMap()
|
||||
viewMap := map[int]dataplexpb.EntryView{
|
||||
1: dataplexpb.EntryView_BASIC,
|
||||
2: dataplexpb.EntryView_FULL,
|
||||
3: dataplexpb.EntryView_CUSTOM,
|
||||
4: dataplexpb.EntryView_ALL,
|
||||
}
|
||||
name, _ := paramsMap["name"].(string)
|
||||
entry, _ := paramsMap["entry"].(string)
|
||||
view, _ := paramsMap["view"].(int)
|
||||
aspectTypeSlice, err := tools.ConvertAnySliceToTyped(paramsMap["aspectTypes"].([]any), "string")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("can't convert aspectTypes to array of strings: %s", err)
|
||||
}
|
||||
aspectTypes := aspectTypeSlice.([]string)
|
||||
|
||||
req := &dataplexpb.LookupEntryRequest{
|
||||
Name: name,
|
||||
View: viewMap[view],
|
||||
AspectTypes: aspectTypes,
|
||||
Entry: entry,
|
||||
}
|
||||
|
||||
result, err := t.CatalogClient.LookupEntry(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ParseParams validates raw arguments (plus auth claims) against the
// tool's declared parameters.
func (t *Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
	// Parse parameters from the provided data
	return tools.ParseParams(t.Parameters, data, claims)
}

// Manifest returns the tool manifest served to clients.
func (t *Tool) Manifest() tools.Manifest {
	// Returns the tool manifest
	return t.manifest
}

// McpManifest returns the MCP manifest for this tool.
func (t *Tool) McpManifest() tools.McpManifest {
	// Returns the tool MCP manifest
	return t.mcpManifest
}
|
||||
@@ -0,0 +1,117 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package dataplexlookupentry_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/dataplex/dataplexlookupentry"
|
||||
)
|
||||
|
||||
// TestParseFromYamlDataplexLookupEntry verifies that YAML tool configs of
// kind dataplex-lookup-entry decode into the expected Config values, with
// and without explicit parameters.
// NOTE(review): leading indentation inside the YAML raw strings was lost in
// extraction; it is reconstructed here assuming testutils.FormatYaml
// normalizes it — confirm against the repository's test conventions.
func TestParseFromYamlDataplexLookupEntry(t *testing.T) {
	ctx, err := testutils.ContextWithNewLogger()
	if err != nil {
		t.Fatalf("unexpected error: %s", err)
	}
	tcs := []struct {
		desc string
		in   string
		want server.ToolConfigs
	}{
		{
			desc: "basic example",
			in: `
			tools:
				example_tool:
					kind: dataplex-lookup-entry
					source: my-instance
					description: some description
			`,
			want: server.ToolConfigs{
				"example_tool": dataplexlookupentry.Config{
					Name:         "example_tool",
					Kind:         "dataplex-lookup-entry",
					Source:       "my-instance",
					Description:  "some description",
					AuthRequired: []string{},
				},
			},
		},
		{
			desc: "advanced example",
			in: `
			tools:
				example_tool:
					kind: dataplex-lookup-entry
					source: my-instance
					description: some description
					parameters:
						- name: name
						  type: string
						  description: some name description
						- name: view
						  type: string
						  description: some view description
						- name: aspectTypes
						  type: array
						  description: some aspect types description
						  default: []
						  items:
							name: aspectType
							type: string
							description: some aspect type description
						- name: entry
						  type: string
						  description: some entry description
			`,
			want: server.ToolConfigs{
				"example_tool": dataplexlookupentry.Config{
					Name:         "example_tool",
					Kind:         "dataplex-lookup-entry",
					Source:       "my-instance",
					Description:  "some description",
					AuthRequired: []string{},
					Parameters: []tools.Parameter{
						tools.NewStringParameter("name", "some name description"),
						tools.NewStringParameter("view", "some view description"),
						tools.NewArrayParameterWithDefault("aspectTypes", []any{}, "some aspect types description", tools.NewStringParameter("aspectType", "some aspect type description")),
						tools.NewStringParameter("entry", "some entry description"),
					},
				},
			},
		},
	}
	for _, tc := range tcs {
		t.Run(tc.desc, func(t *testing.T) {
			got := struct {
				Tools server.ToolConfigs `yaml:"tools"`
			}{}
			// Parse contents
			err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
			if err != nil {
				t.Fatalf("unable to unmarshal: %s", err)
			}
			if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
				t.Fatalf("incorrect parse: diff %v", diff)
			}
		})
	}

}
|
||||
@@ -80,12 +80,11 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
}
|
||||
|
||||
query := tools.NewStringParameter("query", "The query against which entries in scope should be matched.")
|
||||
name := tools.NewStringParameterWithDefault("name", fmt.Sprintf("projects/%s/locations/global", s.ProjectID()), "The project to which the request should be attributed in the following form: projects/{project}/locations/global")
|
||||
pageSize := tools.NewIntParameterWithDefault("pageSize", 5, "Number of results in the search page.")
|
||||
pageToken := tools.NewStringParameterWithDefault("pageToken", "", "Page token received from a previous locations.searchEntries call. Provide this to retrieve the subsequent page.")
|
||||
orderBy := tools.NewStringParameterWithDefault("orderBy", "relevance", "Specifies the ordering of results. Supported values are: relevance, last_modified_timestamp, last_modified_timestamp asc")
|
||||
semanticSearch := tools.NewBooleanParameterWithDefault("semanticSearch", true, "Whether to use semantic search for the query. If true, the query will be processed using semantic search capabilities.")
|
||||
parameters := tools.Parameters{query, name, pageSize, pageToken, orderBy, semanticSearch}
|
||||
parameters := tools.Parameters{query, pageSize, pageToken, orderBy, semanticSearch}
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
@@ -93,7 +92,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
InputSchema: parameters.McpManifest(),
|
||||
}
|
||||
|
||||
t := &SearchTool{
|
||||
t := &Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Parameters: parameters,
|
||||
@@ -110,7 +109,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
return t, nil
|
||||
}
|
||||
|
||||
type SearchTool struct {
|
||||
type Tool struct {
|
||||
Name string
|
||||
Kind string
|
||||
Parameters tools.Parameters
|
||||
@@ -121,14 +120,13 @@ type SearchTool struct {
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t *SearchTool) Authorized(verifiedAuthServices []string) bool {
|
||||
func (t *Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t *SearchTool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
func (t *Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
paramsMap := params.AsMap()
|
||||
query, _ := paramsMap["query"].(string)
|
||||
name, _ := paramsMap["name"].(string)
|
||||
pageSize, _ := paramsMap["pageSize"].(int32)
|
||||
pageToken, _ := paramsMap["pageToken"].(string)
|
||||
orderBy, _ := paramsMap["orderBy"].(string)
|
||||
@@ -136,7 +134,7 @@ func (t *SearchTool) Invoke(ctx context.Context, params tools.ParamValues) (any,
|
||||
|
||||
req := &dataplexpb.SearchEntriesRequest{
|
||||
Query: query,
|
||||
Name: name,
|
||||
Name: fmt.Sprintf("projects/%s/locations/global", t.ProjectID),
|
||||
PageSize: pageSize,
|
||||
PageToken: pageToken,
|
||||
OrderBy: orderBy,
|
||||
@@ -159,17 +157,17 @@ func (t *SearchTool) Invoke(ctx context.Context, params tools.ParamValues) (any,
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (t *SearchTool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
func (t *Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
// Parse parameters from the provided data
|
||||
return tools.ParseParams(t.Parameters, data, claims)
|
||||
}
|
||||
|
||||
func (t *SearchTool) Manifest() tools.Manifest {
|
||||
func (t *Tool) Manifest() tools.Manifest {
|
||||
// Returns the tool manifest
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t *SearchTool) McpManifest() tools.McpManifest {
|
||||
func (t *Tool) McpManifest() tools.McpManifest {
|
||||
// Returns the tool MCP manifest
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
194
internal/tools/tidb/tidbexecutesql/tidbexecutesql.go
Normal file
194
internal/tools/tidb/tidbexecutesql/tidbexecutesql.go
Normal file
@@ -0,0 +1,194 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package tidbexecutesql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/tidb"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
)
|
||||
|
||||
// kind is the unique tool-kind identifier used in YAML configs.
const kind string = "tidb-execute-sql"

// init registers this tool kind with the global tools registry. It panics
// on a duplicate registration, which indicates a programming error at startup.
func init() {
	if !tools.Register(kind, newConfig) {
		panic(fmt.Sprintf("tool kind %q already registered", kind))
	}
}

// newConfig decodes a single tool's YAML node into a Config, pre-populating
// Name from the tool's map key.
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
	actual := Config{Name: name}
	if err := decoder.DecodeContext(ctx, &actual); err != nil {
		return nil, err
	}
	return actual, nil
}
|
||||
|
||||
// compatibleSource is the narrow interface this tool needs from its
// source: access to the shared TiDB connection pool.
type compatibleSource interface {
	TiDBPool() *sql.DB
}

// validate compatible sources are still compatible
var _ compatibleSource = &tidb.Source{}

// compatibleSources lists the source kinds accepted by this tool; it is
// only used to format error messages in Initialize.
var compatibleSources = [...]string{tidb.SourceKind}

// Config is the YAML configuration for a tidb-execute-sql tool.
type Config struct {
	Name         string   `yaml:"name" validate:"required"`
	Kind         string   `yaml:"kind" validate:"required"`
	Source       string   `yaml:"source" validate:"required"`
	Description  string   `yaml:"description" validate:"required"`
	AuthRequired []string `yaml:"authRequired"`
}

// validate interface
var _ tools.ToolConfig = Config{}

// ToolConfigKind returns the tool kind string for this config.
func (cfg Config) ToolConfigKind() string {
	return kind
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
sqlParameter := tools.NewStringParameter("sql", "The sql to execute.")
|
||||
parameters := tools.Parameters{sqlParameter}
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
Description: cfg.Description,
|
||||
InputSchema: parameters.McpManifest(),
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Parameters: parameters,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Pool: s.TiDBPool(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: parameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// validate interface
var _ tools.Tool = Tool{}

// Tool is an initialized tidb-execute-sql tool bound to a TiDB pool.
type Tool struct {
	Name         string           `yaml:"name"`
	Kind         string           `yaml:"kind"`
	AuthRequired []string         `yaml:"authRequired"`
	Parameters   tools.Parameters `yaml:"parameters"`

	// Pool is the shared connection pool taken from the source.
	Pool        *sql.DB
	manifest    tools.Manifest
	mcpManifest tools.McpManifest
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
sliceParams := params.AsSlice()
|
||||
sql, ok := sliceParams[0].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unable to get cast %s", sliceParams[0])
|
||||
}
|
||||
|
||||
results, err := t.Pool.QueryContext(ctx, sql)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
defer results.Close()
|
||||
|
||||
cols, err := results.Columns()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to retrieve rows column name: %w", err)
|
||||
}
|
||||
|
||||
// create an array of values for each column, which can be re-used to scan each row
|
||||
rawValues := make([]any, len(cols))
|
||||
values := make([]any, len(cols))
|
||||
for i := range rawValues {
|
||||
values[i] = &rawValues[i]
|
||||
}
|
||||
|
||||
colTypes, err := results.ColumnTypes()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to get column types: %w", err)
|
||||
}
|
||||
|
||||
var out []any
|
||||
for results.Next() {
|
||||
err := results.Scan(values...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse row: %w", err)
|
||||
}
|
||||
vMap := make(map[string]any)
|
||||
for i, name := range cols {
|
||||
val := rawValues[i]
|
||||
if val == nil {
|
||||
vMap[name] = nil
|
||||
continue
|
||||
}
|
||||
|
||||
// mysql driver return []uint8 type for "TEXT", "VARCHAR", and "NVARCHAR"
|
||||
// we'll need to cast it back to string
|
||||
switch colTypes[i].DatabaseTypeName() {
|
||||
case "TEXT", "VARCHAR", "NVARCHAR":
|
||||
vMap[name] = string(val.([]byte))
|
||||
default:
|
||||
vMap[name] = val
|
||||
}
|
||||
}
|
||||
out = append(out, vMap)
|
||||
}
|
||||
|
||||
if err := results.Err(); err != nil {
|
||||
return nil, fmt.Errorf("errors encountered during row iteration: %w", err)
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// ParseParams validates raw arguments (plus auth claims) against the
// tool's declared parameters.
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
	return tools.ParseParams(t.Parameters, data, claims)
}

// Manifest returns the tool manifest served to clients.
func (t Tool) Manifest() tools.Manifest {
	return t.manifest
}

// McpManifest returns the MCP manifest for this tool.
func (t Tool) McpManifest() tools.McpManifest {
	return t.mcpManifest
}

// Authorized reports whether the verified auth services satisfy the
// tool's authRequired list.
func (t Tool) Authorized(verifiedAuthServices []string) bool {
	return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
|
||||
76
internal/tools/tidb/tidbexecutesql/tidbexecutesql_test.go
Normal file
76
internal/tools/tidb/tidbexecutesql/tidbexecutesql_test.go
Normal file
@@ -0,0 +1,76 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package tidbexecutesql_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/tidb/tidbexecutesql"
|
||||
)
|
||||
|
||||
// TestParseFromYamlExecuteSql verifies that a tidb-execute-sql YAML tool
// config decodes into the expected Config, including the authRequired list.
// NOTE(review): leading indentation inside the YAML raw string was lost in
// extraction; it is reconstructed here assuming testutils.FormatYaml
// normalizes it — confirm against the repository's test conventions.
func TestParseFromYamlExecuteSql(t *testing.T) {
	ctx, err := testutils.ContextWithNewLogger()
	if err != nil {
		t.Fatalf("unexpected error: %s", err)
	}
	tcs := []struct {
		desc string
		in   string
		want server.ToolConfigs
	}{
		{
			desc: "basic example",
			in: `
			tools:
				example_tool:
					kind: tidb-execute-sql
					source: my-instance
					description: some description
					authRequired:
						- my-google-auth-service
						- other-auth-service
			`,
			want: server.ToolConfigs{
				"example_tool": tidbexecutesql.Config{
					Name:         "example_tool",
					Kind:         "tidb-execute-sql",
					Source:       "my-instance",
					Description:  "some description",
					AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
				},
			},
		},
	}
	for _, tc := range tcs {
		t.Run(tc.desc, func(t *testing.T) {
			got := struct {
				Tools server.ToolConfigs `yaml:"tools"`
			}{}
			// Parse contents
			err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
			if err != nil {
				t.Fatalf("unable to unmarshal: %s", err)
			}
			if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
				t.Fatalf("incorrect parse: diff %v", diff)
			}
		})
	}

}
|
||||
217
internal/tools/tidb/tidbsql/tidbsql.go
Normal file
217
internal/tools/tidb/tidbsql/tidbsql.go
Normal file
@@ -0,0 +1,217 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package tidbsql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/tidb"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
)
|
||||
|
||||
// kind is the unique tool-kind identifier used in YAML configs.
const kind string = "tidb-sql"

// init registers this tool kind with the global tools registry. It panics
// on a duplicate registration, which indicates a programming error at startup.
func init() {
	if !tools.Register(kind, newConfig) {
		panic(fmt.Sprintf("tool kind %q already registered", kind))
	}
}

// newConfig decodes a single tool's YAML node into a Config, pre-populating
// Name from the tool's map key.
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
	actual := Config{Name: name}
	if err := decoder.DecodeContext(ctx, &actual); err != nil {
		return nil, err
	}
	return actual, nil
}
|
||||
|
||||
// compatibleSource is the narrow interface this tool needs from its
// source: access to the shared TiDB connection pool.
type compatibleSource interface {
	TiDBPool() *sql.DB
}

// validate compatible sources are still compatible
var _ compatibleSource = &tidb.Source{}

// compatibleSources lists the source kinds accepted by this tool; it is
// only used to format error messages in Initialize.
var compatibleSources = [...]string{tidb.SourceKind}

// Config is the YAML configuration for a tidb-sql tool: a fixed SQL
// statement with optional standard and template parameters.
type Config struct {
	Name               string           `yaml:"name" validate:"required"`
	Kind               string           `yaml:"kind" validate:"required"`
	Source             string           `yaml:"source" validate:"required"`
	Description        string           `yaml:"description" validate:"required"`
	Statement          string           `yaml:"statement" validate:"required"`
	AuthRequired       []string         `yaml:"authRequired"`
	Parameters         tools.Parameters `yaml:"parameters"`
	TemplateParameters tools.Parameters `yaml:"templateParameters"`
}

// validate interface
var _ tools.ToolConfig = Config{}

// ToolConfigKind returns the tool kind string for this config.
func (cfg Config) ToolConfigKind() string {
	return kind
}
|
||||
|
||||
// Initialize looks up the configured source, verifies it exposes a TiDB
// connection pool, and assembles the tool with its parameter manifests.
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
	// verify source exists
	rawS, ok := srcs[cfg.Source]
	if !ok {
		return nil, fmt.Errorf("no source named %q configured", cfg.Source)
	}

	// verify the source is compatible
	s, ok := rawS.(compatibleSource)
	if !ok {
		return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
	}

	// Merge template parameters and standard parameters into the combined
	// parameter list plus the manifests used for the tool/MCP schemas.
	allParameters, paramManifest, paramMcpManifest := tools.ProcessParameters(cfg.TemplateParameters, cfg.Parameters)

	mcpManifest := tools.McpManifest{
		Name:        cfg.Name,
		Description: cfg.Description,
		InputSchema: paramMcpManifest,
	}

	// finish tool setup
	t := Tool{
		Name:               cfg.Name,
		Kind:               kind,
		Parameters:         cfg.Parameters,
		TemplateParameters: cfg.TemplateParameters,
		AllParams:          allParameters,
		Statement:          cfg.Statement,
		AuthRequired:       cfg.AuthRequired,
		Pool:               s.TiDBPool(),
		manifest:           tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
		mcpManifest:        mcpManifest,
	}
	return t, nil
}
|
||||
|
||||
// validate interface
var _ tools.Tool = Tool{}

// Tool is an initialized tidb-sql tool: a templated SQL statement bound
// to a TiDB connection pool.
type Tool struct {
	Name         string           `yaml:"name"`
	Kind         string           `yaml:"kind"`
	AuthRequired []string         `yaml:"authRequired"`
	Parameters   tools.Parameters `yaml:"parameters"`
	// TemplateParameters are substituted into the statement text;
	// AllParams is the merged list used when parsing incoming arguments.
	TemplateParameters tools.Parameters `yaml:"templateParameters"`
	AllParams          tools.Parameters `yaml:"allParams"`

	// Pool is the shared connection pool taken from the source.
	Pool        *sql.DB
	Statement   string
	manifest    tools.Manifest
	mcpManifest tools.McpManifest
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
|
||||
paramsMap := params.AsMap()
|
||||
newStatement, err := tools.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to extract template params %w", err)
|
||||
}
|
||||
|
||||
newParams, err := tools.GetParams(t.Parameters, paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to extract standard params %w", err)
|
||||
}
|
||||
|
||||
sliceParams := newParams.AsSlice()
|
||||
results, err := t.Pool.QueryContext(ctx, newStatement, sliceParams...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
|
||||
cols, err := results.Columns()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to retrieve rows column name: %w", err)
|
||||
}
|
||||
|
||||
// create an array of values for each column, which can be re-used to scan each row
|
||||
rawValues := make([]any, len(cols))
|
||||
values := make([]any, len(cols))
|
||||
for i := range rawValues {
|
||||
values[i] = &rawValues[i]
|
||||
}
|
||||
defer results.Close()
|
||||
|
||||
colTypes, err := results.ColumnTypes()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to get column types: %w", err)
|
||||
}
|
||||
|
||||
var out []any
|
||||
for results.Next() {
|
||||
err := results.Scan(values...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse row: %w", err)
|
||||
}
|
||||
vMap := make(map[string]any)
|
||||
for i, name := range cols {
|
||||
val := rawValues[i]
|
||||
if val == nil {
|
||||
vMap[name] = nil
|
||||
continue
|
||||
}
|
||||
|
||||
// mysql driver return []uint8 type for "TEXT", "VARCHAR", and "NVARCHAR"
|
||||
// we'll need to cast it back to string
|
||||
switch colTypes[i].DatabaseTypeName() {
|
||||
case "JSON":
|
||||
// unmarshal JSON data before storing to prevent double marshaling
|
||||
var unmarshaledData any
|
||||
err := json.Unmarshal(val.([]byte), &unmarshaledData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to unmarshal json data %s", val)
|
||||
}
|
||||
vMap[name] = unmarshaledData
|
||||
case "TEXT", "VARCHAR", "NVARCHAR":
|
||||
vMap[name] = string(val.([]byte))
|
||||
default:
|
||||
vMap[name] = val
|
||||
}
|
||||
}
|
||||
out = append(out, vMap)
|
||||
}
|
||||
|
||||
if err := results.Err(); err != nil {
|
||||
return nil, fmt.Errorf("errors encountered during row iteration: %w", err)
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// ParseParams validates raw arguments (plus auth claims) against the
// combined standard + template parameter list.
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
	return tools.ParseParams(t.AllParams, data, claims)
}

// Manifest returns the tool manifest served to clients.
func (t Tool) Manifest() tools.Manifest {
	return t.manifest
}

// McpManifest returns the MCP manifest for this tool.
func (t Tool) McpManifest() tools.McpManifest {
	return t.mcpManifest
}

// Authorized reports whether the verified auth services satisfy the
// tool's authRequired list.
func (t Tool) Authorized(verifiedAuthServices []string) bool {
	return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
|
||||
175
internal/tools/tidb/tidbsql/tidbsql_test.go
Normal file
175
internal/tools/tidb/tidbsql/tidbsql_test.go
Normal file
@@ -0,0 +1,175 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package tidbsql_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/tidb/tidbsql"
|
||||
)
|
||||
|
||||
// TestParseFromYamlTiDB verifies that a tidb-sql YAML tool config decodes
// into the expected Config, including statement, authRequired, and
// auth-backed parameters.
// NOTE(review): leading indentation inside the YAML raw string was lost in
// extraction; it is reconstructed here assuming testutils.FormatYaml
// normalizes it — confirm against the repository's test conventions.
func TestParseFromYamlTiDB(t *testing.T) {
	ctx, err := testutils.ContextWithNewLogger()
	if err != nil {
		t.Fatalf("unexpected error: %s", err)
	}
	tcs := []struct {
		desc string
		in   string
		want server.ToolConfigs
	}{
		{
			desc: "basic example",
			in: `
			tools:
				example_tool:
					kind: tidb-sql
					source: my-tidb-instance
					description: some description
					statement: |
						SELECT * FROM SQL_STATEMENT;
					authRequired:
						- my-google-auth-service
						- other-auth-service
					parameters:
						- name: country
						  type: string
						  description: some description
						  authServices:
							- name: my-google-auth-service
							  field: user_id
							- name: other-auth-service
							  field: user_id
			`,
			want: server.ToolConfigs{
				"example_tool": tidbsql.Config{
					Name:         "example_tool",
					Kind:         "tidb-sql",
					Source:       "my-tidb-instance",
					Description:  "some description",
					Statement:    "SELECT * FROM SQL_STATEMENT;\n",
					AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
					Parameters: []tools.Parameter{
						tools.NewStringParameterWithAuth("country", "some description",
							[]tools.ParamAuthService{{Name: "my-google-auth-service", Field: "user_id"},
								{Name: "other-auth-service", Field: "user_id"}}),
					},
				},
			},
		},
	}
	for _, tc := range tcs {
		t.Run(tc.desc, func(t *testing.T) {
			got := struct {
				Tools server.ToolConfigs `yaml:"tools"`
			}{}
			// Parse contents
			err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
			if err != nil {
				t.Fatalf("unable to unmarshal: %s", err)
			}
			if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
				t.Fatalf("incorrect parse: diff %v", diff)
			}
		})
	}
}
|
||||
|
||||
func TestParseFromYamlWithTemplateParamsTiDB(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: tidb-sql
|
||||
source: my-tidb-instance
|
||||
description: some description
|
||||
statement: |
|
||||
SELECT * FROM SQL_STATEMENT;
|
||||
authRequired:
|
||||
- my-google-auth-service
|
||||
- other-auth-service
|
||||
parameters:
|
||||
- name: country
|
||||
type: string
|
||||
description: some description
|
||||
authServices:
|
||||
- name: my-google-auth-service
|
||||
field: user_id
|
||||
- name: other-auth-service
|
||||
field: user_id
|
||||
templateParameters:
|
||||
- name: tableName
|
||||
type: string
|
||||
description: The table to select hotels from.
|
||||
- name: fieldArray
|
||||
type: array
|
||||
description: The columns to return for the query.
|
||||
items:
|
||||
name: column
|
||||
type: string
|
||||
description: A column name that will be returned from the query.
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": tidbsql.Config{
|
||||
Name: "example_tool",
|
||||
Kind: "tidb-sql",
|
||||
Source: "my-tidb-instance",
|
||||
Description: "some description",
|
||||
Statement: "SELECT * FROM SQL_STATEMENT;\n",
|
||||
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
|
||||
Parameters: []tools.Parameter{
|
||||
tools.NewStringParameterWithAuth("country", "some description",
|
||||
[]tools.ParamAuthService{{Name: "my-google-auth-service", Field: "user_id"},
|
||||
{Name: "other-auth-service", Field: "user_id"}}),
|
||||
},
|
||||
TemplateParameters: []tools.Parameter{
|
||||
tools.NewStringParameter("tableName", "The table to select hotels from."),
|
||||
tools.NewArrayParameter("fieldArray", "The columns to return for the query.", tools.NewStringParameter("column", "A column name that will be returned from the query.")),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -72,7 +72,7 @@ func initBigQueryConnection(project string) (*bigqueryapi.Client, error) {
|
||||
|
||||
func TestBigQueryToolEndpoints(t *testing.T) {
|
||||
sourceConfig := getBigQueryVars(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
var args []string
|
||||
@@ -100,6 +100,11 @@ func TestBigQueryToolEndpoints(t *testing.T) {
|
||||
datasetName,
|
||||
strings.ReplaceAll(uuid.New().String(), "-", ""),
|
||||
)
|
||||
tableNameDataType := fmt.Sprintf("`%s.%s.datatype_table_%s`",
|
||||
BigqueryProject,
|
||||
datasetName,
|
||||
strings.ReplaceAll(uuid.New().String(), "-", ""),
|
||||
)
|
||||
|
||||
// set up data for param tool
|
||||
createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := getBigQueryParamToolInfo(tableNameParam)
|
||||
@@ -111,8 +116,14 @@ func TestBigQueryToolEndpoints(t *testing.T) {
|
||||
teardownTable2 := setupBigQueryTable(t, ctx, client, createAuthTableStmt, insertAuthTableStmt, datasetName, tableNameAuth, authTestParams)
|
||||
defer teardownTable2(t)
|
||||
|
||||
// set up data for data type test tool
|
||||
createDataTypeTableStmt, insertDataTypeTableStmt, dataTypeToolStmt, arrayDataTypeToolStmt, dataTypeTestParams := getBigQueryDataTypeTestInfo(tableNameDataType)
|
||||
teardownTable3 := setupBigQueryTable(t, ctx, client, createDataTypeTableStmt, insertDataTypeTableStmt, datasetName, tableNameDataType, dataTypeTestParams)
|
||||
defer teardownTable3(t)
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, BigqueryToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
|
||||
toolsFile = addBigQuerySqlToolConfig(t, toolsFile, dataTypeToolStmt, arrayDataTypeToolStmt)
|
||||
toolsFile = addBigQueryPrebuiltToolsConfig(t, toolsFile)
|
||||
tmplSelectCombined, tmplSelectFilterCombined := getBigQueryTmplToolStatement()
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, BigqueryToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||
@@ -135,18 +146,23 @@ func TestBigQueryToolEndpoints(t *testing.T) {
|
||||
|
||||
select1Want := "[{\"f0_\":1}]"
|
||||
// Partial message; the full error message is too long.
|
||||
failInvocationWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: googleapi: Error 400: Syntax error: Unexpected identifier \"SELEC\" at [1:1]`
|
||||
failInvocationWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"final query validation failed: failed to insert dry run job: googleapi: Error 400: Syntax error: Unexpected identifier \"SELEC\" at [1:1]`
|
||||
datasetInfoWant := "\"Location\":\"US\",\"DefaultTableExpiration\":0,\"Labels\":null,\"Access\":"
|
||||
tableInfoWant := "{\"Name\":\"\",\"Location\":\"US\",\"Description\":\"\",\"Schema\":[{\"Name\":\"id\""
|
||||
ddlWant := `"Query executed successfully and returned no content."`
|
||||
invokeParamWant, invokeIdNullWant, nullWant, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeIdNullWant, nullWant, false, true)
|
||||
tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failInvocationWant)
|
||||
templateParamTestConfig := tests.NewTemplateParameterTestConfig(
|
||||
tests.WithCreateColArray(`["id INT64", "name STRING", "age INT64"]`),
|
||||
tests.WithDdlWant(ddlWant),
|
||||
tests.WithSelectEmptyWant(`"The query returned 0 rows."`),
|
||||
tests.WithInsert1Want(ddlWant),
|
||||
)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, templateParamTestConfig)
|
||||
|
||||
runBigQueryExecuteSqlToolInvokeTest(t, select1Want, invokeParamWant, tableNameParam)
|
||||
runBigQueryExecuteSqlToolInvokeTest(t, select1Want, invokeParamWant, tableNameParam, ddlWant)
|
||||
runBigQueryDataTypeTests(t)
|
||||
runBigQueryListDatasetToolInvokeTest(t, datasetName)
|
||||
runBigQueryGetDatasetInfoToolInvokeTest(t, datasetName, datasetInfoWant)
|
||||
runBigQueryListTableIdsToolInvokeTest(t, datasetName, tableName)
|
||||
@@ -187,6 +203,22 @@ func getBigQueryAuthToolInfo(tableName string) (string, string, string, []bigque
|
||||
return createStatement, insertStatement, toolStatement, params
|
||||
}
|
||||
|
||||
// getBigQueryDataTypeTestInfo returns statements and params for data type tests.
|
||||
func getBigQueryDataTypeTestInfo(tableName string) (string, string, string, string, []bigqueryapi.QueryParameter) {
|
||||
createStatement := fmt.Sprintf(`
|
||||
CREATE TABLE IF NOT EXISTS %s (id INT64, int_val INT64, string_val STRING, float_val FLOAT64, bool_val BOOL);`, tableName)
|
||||
insertStatement := fmt.Sprintf(`
|
||||
INSERT INTO %s (id, int_val, string_val, float_val, bool_val) VALUES (?, ?, ?, ?, ?), (?, ?, ?, ?, ?), (?, ?, ?, ?, ?);`, tableName)
|
||||
toolStatement := fmt.Sprintf(`SELECT * FROM %s WHERE int_val = ? AND string_val = ? AND float_val = ? AND bool_val = ?;`, tableName)
|
||||
arrayToolStatement := fmt.Sprintf(`SELECT * FROM %s WHERE int_val IN UNNEST(@int_array) AND string_val IN UNNEST(@string_array) AND float_val IN UNNEST(@float_array) AND bool_val IN UNNEST(@bool_array) ORDER BY id;`, tableName)
|
||||
params := []bigqueryapi.QueryParameter{
|
||||
{Value: int64(1)}, {Value: int64(123)}, {Value: "hello"}, {Value: 3.14}, {Value: true},
|
||||
{Value: int64(2)}, {Value: int64(-456)}, {Value: "world"}, {Value: -0.55}, {Value: false},
|
||||
{Value: int64(3)}, {Value: int64(789)}, {Value: "test"}, {Value: 100.1}, {Value: true},
|
||||
}
|
||||
return createStatement, insertStatement, toolStatement, arrayToolStatement, params
|
||||
}
|
||||
|
||||
// getBigQueryTmplToolStatement returns statements for template parameter test cases for bigquery kind
|
||||
func getBigQueryTmplToolStatement() (string, string) {
|
||||
tmplSelectCombined := "SELECT * FROM {{.tableName}} WHERE id = ? ORDER BY id"
|
||||
@@ -345,7 +377,41 @@ func addBigQueryPrebuiltToolsConfig(t *testing.T, config map[string]any) map[str
|
||||
return config
|
||||
}
|
||||
|
||||
func runBigQueryExecuteSqlToolInvokeTest(t *testing.T, select1Want, invokeParamWant, tableNameParam string) {
|
||||
func addBigQuerySqlToolConfig(t *testing.T, config map[string]any, toolStatement, arrayToolStatement string) map[string]any {
|
||||
tools, ok := config["tools"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("unable to get tools from config")
|
||||
}
|
||||
tools["my-scalar-datatype-tool"] = map[string]any{
|
||||
"kind": "bigquery-sql",
|
||||
"source": "my-instance",
|
||||
"description": "Tool to test various scalar data types.",
|
||||
"statement": toolStatement,
|
||||
"parameters": []any{
|
||||
map[string]any{"name": "int_val", "type": "integer", "description": "an integer value"},
|
||||
map[string]any{"name": "string_val", "type": "string", "description": "a string value"},
|
||||
map[string]any{"name": "float_val", "type": "float", "description": "a float value"},
|
||||
map[string]any{"name": "bool_val", "type": "boolean", "description": "a boolean value"},
|
||||
},
|
||||
}
|
||||
tools["my-array-datatype-tool"] = map[string]any{
|
||||
"kind": "bigquery-sql",
|
||||
"source": "my-instance",
|
||||
"description": "Tool to test various array data types.",
|
||||
"statement": arrayToolStatement,
|
||||
"parameters": []any{
|
||||
map[string]any{"name": "int_array", "type": "array", "description": "an array of integer values", "items": map[string]any{"name": "item", "type": "integer", "description": "desc"}},
|
||||
map[string]any{"name": "string_array", "type": "array", "description": "an array of string values", "items": map[string]any{"name": "item", "type": "string", "description": "desc"}},
|
||||
map[string]any{"name": "float_array", "type": "array", "description": "an array of float values", "items": map[string]any{"name": "item", "type": "float", "description": "desc"}},
|
||||
map[string]any{"name": "bool_array", "type": "array", "description": "an array of boolean values", "items": map[string]any{"name": "item", "type": "boolean", "description": "desc"}},
|
||||
},
|
||||
}
|
||||
|
||||
config["tools"] = tools
|
||||
return config
|
||||
}
|
||||
|
||||
func runBigQueryExecuteSqlToolInvokeTest(t *testing.T, select1Want, invokeParamWant, tableNameParam, ddlWant string) {
|
||||
// Get ID token
|
||||
idToken, err := tests.GetGoogleIdToken(tests.ClientId)
|
||||
if err != nil {
|
||||
@@ -381,7 +447,7 @@ func runBigQueryExecuteSqlToolInvokeTest(t *testing.T, select1Want, invokeParamW
|
||||
api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"sql":"CREATE TABLE t (id SERIAL PRIMARY KEY, name TEXT)"}`)),
|
||||
want: `"Operation completed successfully."`,
|
||||
want: ddlWant,
|
||||
isErr: true,
|
||||
},
|
||||
{
|
||||
@@ -405,7 +471,7 @@ func runBigQueryExecuteSqlToolInvokeTest(t *testing.T, select1Want, invokeParamW
|
||||
api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"sql":"DROP TABLE t"}`)),
|
||||
want: `"Operation completed successfully."`,
|
||||
want: ddlWant,
|
||||
isErr: true,
|
||||
},
|
||||
{
|
||||
@@ -413,7 +479,7 @@ func runBigQueryExecuteSqlToolInvokeTest(t *testing.T, select1Want, invokeParamW
|
||||
api: "http://127.0.0.1:5000/api/tool/my-exec-sql-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"sql\":\"INSERT INTO %s (id, name) VALUES (4, 'test_name')\"}", tableNameParam))),
|
||||
want: `"Operation completed successfully."`,
|
||||
want: ddlWant,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
@@ -490,6 +556,84 @@ func runBigQueryExecuteSqlToolInvokeTest(t *testing.T, select1Want, invokeParamW
|
||||
}
|
||||
}
|
||||
|
||||
func runBigQueryDataTypeTests(t *testing.T) {
|
||||
// Test tool invoke endpoint
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
api string
|
||||
requestHeader map[string]string
|
||||
requestBody io.Reader
|
||||
want string
|
||||
isErr bool
|
||||
}{
|
||||
{
|
||||
name: "invoke my-scalar-datatype-tool with values",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-scalar-datatype-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"int_val": 123, "string_val": "hello", "float_val": 3.14, "bool_val": true}`)),
|
||||
want: `[{"bool_val":true,"float_val":3.14,"id":1,"int_val":123,"string_val":"hello"}]`,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "invoke my-scalar-datatype-tool with missing params",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-scalar-datatype-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"int_val": 123}`)),
|
||||
isErr: true,
|
||||
},
|
||||
{
|
||||
name: "invoke my-array-datatype-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-array-datatype-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"int_array": [123, 789], "string_array": ["hello", "test"], "float_array": [3.14, 100.1], "bool_array": [true]}`)),
|
||||
want: `[{"bool_val":true,"float_val":3.14,"id":1,"int_val":123,"string_val":"hello"},{"bool_val":true,"float_val":100.1,"id":3,"int_val":789,"string_val":"test"}]`,
|
||||
isErr: false,
|
||||
},
|
||||
}
|
||||
for _, tc := range invokeTcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Send Tool invocation request
|
||||
req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create request: %s", err)
|
||||
}
|
||||
req.Header.Add("Content-type", "application/json")
|
||||
for k, v := range tc.requestHeader {
|
||||
req.Header.Add(k, v)
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if tc.isErr {
|
||||
return
|
||||
}
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
// Check response body
|
||||
var body map[string]interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing response body")
|
||||
}
|
||||
|
||||
got, ok := body["result"].(string)
|
||||
if !ok {
|
||||
t.Fatalf("unable to find result in response body")
|
||||
}
|
||||
|
||||
if got != tc.want {
|
||||
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func runBigQueryListDatasetToolInvokeTest(t *testing.T, datasetWant string) {
|
||||
// Get ID token
|
||||
idToken, err := tests.GetGoogleIdToken(tests.ClientId)
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
@@ -39,6 +40,7 @@ import (
|
||||
var (
|
||||
DataplexSourceKind = "dataplex"
|
||||
DataplexSearchEntriesToolKind = "dataplex-search-entries"
|
||||
DataplexLookupEntryToolKind = "dataplex-lookup-entry"
|
||||
DataplexProject = os.Getenv("DATAPLEX_PROJECT")
|
||||
)
|
||||
|
||||
@@ -69,7 +71,7 @@ func initBigQueryConnection(ctx context.Context, project string) (*bigqueryapi.C
|
||||
|
||||
func TestDataplexToolEndpoints(t *testing.T) {
|
||||
sourceConfig := getDataplexVars(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
var args []string
|
||||
@@ -94,7 +96,7 @@ func TestDataplexToolEndpoints(t *testing.T) {
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
waitCtx, cancel := context.WithTimeout(ctx, 5*time.Minute)
|
||||
waitCtx, cancel := context.WithTimeout(ctx, 3*time.Minute)
|
||||
defer cancel()
|
||||
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||
if err != nil {
|
||||
@@ -102,8 +104,9 @@ func TestDataplexToolEndpoints(t *testing.T) {
|
||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||
}
|
||||
|
||||
runDataplexSearchEntriesToolGetTest(t)
|
||||
runDataplexToolGetTest(t)
|
||||
runDataplexSearchEntriesToolInvokeTest(t, tableName, datasetName)
|
||||
runDataplexLookupEntryToolInvokeTest(t, tableName, datasetName)
|
||||
}
|
||||
|
||||
func setupBigQueryTable(t *testing.T, ctx context.Context, client *bigqueryapi.Client, datasetName string, tableName string) func(*testing.T) {
|
||||
@@ -169,92 +172,169 @@ func getDataplexToolsConfig(sourceConfig map[string]any) map[string]any {
|
||||
"sources": map[string]any{
|
||||
"my-dataplex-instance": sourceConfig,
|
||||
},
|
||||
"authServices": map[string]any{
|
||||
"my-google-auth": map[string]any{
|
||||
"kind": "google",
|
||||
"clientId": tests.ClientId,
|
||||
},
|
||||
},
|
||||
"tools": map[string]any{
|
||||
"my-search-entries-tool": map[string]any{
|
||||
"my-dataplex-search-entries-tool": map[string]any{
|
||||
"kind": DataplexSearchEntriesToolKind,
|
||||
"source": "my-dataplex-instance",
|
||||
"description": "Simple tool to test end to end functionality.",
|
||||
},
|
||||
"my-auth-dataplex-search-entries-tool": map[string]any{
|
||||
"kind": DataplexSearchEntriesToolKind,
|
||||
"source": "my-dataplex-instance",
|
||||
"description": "Simple tool to test end to end functionality.",
|
||||
"authRequired": []string{"my-google-auth"},
|
||||
},
|
||||
"my-dataplex-lookup-entry-tool": map[string]any{
|
||||
"kind": DataplexLookupEntryToolKind,
|
||||
"source": "my-dataplex-instance",
|
||||
"description": "Simple dataplex lookup entry tool to test end to end functionality.",
|
||||
},
|
||||
"my-auth-dataplex-lookup-entry-tool": map[string]any{
|
||||
"kind": DataplexLookupEntryToolKind,
|
||||
"source": "my-dataplex-instance",
|
||||
"description": "Simple dataplex lookup entry tool to test end to end functionality.",
|
||||
"authRequired": []string{"my-google-auth"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return toolsFile
|
||||
}
|
||||
|
||||
func runDataplexSearchEntriesToolGetTest(t *testing.T) {
|
||||
resp, err := http.Get("http://127.0.0.1:5000/api/tool/my-search-entries-tool/")
|
||||
if err != nil {
|
||||
t.Fatalf("error making GET request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != 200 {
|
||||
t.Fatalf("expected status code 200, got %d", resp.StatusCode)
|
||||
}
|
||||
var body map[string]interface{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
|
||||
t.Fatalf("error decoding response body: %s", err)
|
||||
}
|
||||
got, ok := body["tools"]
|
||||
if !ok {
|
||||
t.Fatalf("unable to find 'tools' key in response body")
|
||||
func runDataplexToolGetTest(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
toolName string
|
||||
expectedParams []string
|
||||
}{
|
||||
{
|
||||
name: "get my-dataplex-search-entries-tool",
|
||||
toolName: "my-dataplex-search-entries-tool",
|
||||
expectedParams: []string{"pageSize", "pageToken", "query", "orderBy", "semanticSearch"},
|
||||
},
|
||||
{
|
||||
name: "get my-dataplex-lookup-entry-tool",
|
||||
toolName: "my-dataplex-lookup-entry-tool",
|
||||
expectedParams: []string{"name", "view", "aspectTypes", "entry"},
|
||||
},
|
||||
}
|
||||
|
||||
toolsMap, ok := got.(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("tools is not a map")
|
||||
}
|
||||
tool, ok := toolsMap["my-search-entries-tool"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("tool not found in manifest")
|
||||
}
|
||||
params, ok := tool["parameters"].([]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("parameters not found")
|
||||
}
|
||||
paramNames := []string{}
|
||||
for _, param := range params {
|
||||
paramMap, ok := param.(map[string]interface{})
|
||||
if ok {
|
||||
paramNames = append(paramNames, paramMap["name"].(string))
|
||||
}
|
||||
}
|
||||
expected := []string{"name", "pageSize", "pageToken", "orderBy", "query"}
|
||||
for _, want := range expected {
|
||||
found := false
|
||||
for _, got := range paramNames {
|
||||
if got == want {
|
||||
found = true
|
||||
break
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:5000/api/tool/%s/", tc.toolName))
|
||||
if err != nil {
|
||||
t.Fatalf("error when sending a request: %s", err)
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatalf("expected parameter %q not found in tool parameters", want)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != 200 {
|
||||
t.Fatalf("response status code is not 200")
|
||||
}
|
||||
var body map[string]interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing response body")
|
||||
}
|
||||
got, ok := body["tools"]
|
||||
if !ok {
|
||||
t.Fatalf("unable to find tools in response body")
|
||||
}
|
||||
|
||||
toolsMap, ok := got.(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("expected 'tools' to be a map, got %T", got)
|
||||
}
|
||||
tool, ok := toolsMap[tc.toolName].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("expected tool %q to be a map, got %T", tc.toolName, toolsMap[tc.toolName])
|
||||
}
|
||||
params, ok := tool["parameters"].([]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("expected 'parameters' to be a slice, got %T", tool["parameters"])
|
||||
}
|
||||
paramSet := make(map[string]struct{})
|
||||
for _, param := range params {
|
||||
paramMap, ok := param.(map[string]interface{})
|
||||
if ok {
|
||||
if name, ok := paramMap["name"].(string); ok {
|
||||
paramSet[name] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
var missing []string
|
||||
for _, want := range tc.expectedParams {
|
||||
if _, found := paramSet[want]; !found {
|
||||
missing = append(missing, want)
|
||||
}
|
||||
}
|
||||
if len(missing) > 0 {
|
||||
t.Fatalf("missing parameters for tool %q: %v", tc.toolName, missing)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func runDataplexSearchEntriesToolInvokeTest(t *testing.T, tableName string, datasetName string) {
|
||||
idToken, err := tests.GetGoogleIdToken(tests.ClientId)
|
||||
if err != nil {
|
||||
t.Fatalf("error getting Google ID token: %s", err)
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
tableName string
|
||||
datasetName string
|
||||
api string
|
||||
requestHeader map[string]string
|
||||
requestBody io.Reader
|
||||
wantStatusCode int
|
||||
expectResult bool
|
||||
wantContentKey string
|
||||
}{
|
||||
{
|
||||
name: "Success - Entry Found",
|
||||
tableName: tableName,
|
||||
datasetName: datasetName,
|
||||
api: "http://127.0.0.1:5000/api/tool/my-dataplex-search-entries-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"query\":\"displayname=%s system=bigquery parent=%s\"}", tableName, datasetName))),
|
||||
wantStatusCode: 200,
|
||||
expectResult: true,
|
||||
wantContentKey: "dataplex_entry",
|
||||
},
|
||||
{
|
||||
name: "Success with Authorization - Entry Found",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-dataplex-search-entries-tool/invoke",
|
||||
requestHeader: map[string]string{"my-google-auth_token": idToken},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"query\":\"displayname=%s system=bigquery parent=%s\"}", tableName, datasetName))),
|
||||
wantStatusCode: 200,
|
||||
expectResult: true,
|
||||
wantContentKey: "dataplex_entry",
|
||||
},
|
||||
{
|
||||
name: "Failure - Invalid Authorization Token",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-dataplex-search-entries-tool/invoke",
|
||||
requestHeader: map[string]string{"my-google-auth_token": "invalid_token"},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"query\":\"displayname=%s system=bigquery parent=%s\"}", tableName, datasetName))),
|
||||
wantStatusCode: 401,
|
||||
expectResult: false,
|
||||
wantContentKey: "dataplex_entry",
|
||||
},
|
||||
{
|
||||
name: "Failure - Without Authorization Token",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-dataplex-search-entries-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"query\":\"displayname=%s system=bigquery parent=%s\"}", tableName, datasetName))),
|
||||
wantStatusCode: 401,
|
||||
expectResult: false,
|
||||
wantContentKey: "dataplex_entry",
|
||||
},
|
||||
{
|
||||
name: "Failure - Entry Not Found",
|
||||
tableName: "",
|
||||
datasetName: "",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-dataplex-search-entries-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(`{"query":"displayname=\"\" system=bigquery parent=\"\""}`)),
|
||||
wantStatusCode: 200,
|
||||
expectResult: false,
|
||||
wantContentKey: "",
|
||||
@@ -263,19 +343,23 @@ func runDataplexSearchEntriesToolInvokeTest(t *testing.T, tableName string, data
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
query := fmt.Sprintf("displayname=\"%s\" system=bigquery parent:\"%s\"", tc.tableName, tc.datasetName)
|
||||
reqBodyMap := map[string]string{"query": query}
|
||||
reqBodyBytes, err := json.Marshal(reqBodyMap)
|
||||
req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("error marshalling request body: %s", err)
|
||||
t.Fatalf("unable to create request: %s", err)
|
||||
}
|
||||
resp, err := http.Post("http://127.0.0.1:5000/api/tool/my-search-entries-tool/invoke", "application/json", bytes.NewBuffer(reqBodyBytes))
|
||||
req.Header.Add("Content-type", "application/json")
|
||||
for k, v := range tc.requestHeader {
|
||||
req.Header.Add(k, v)
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("error making POST request: %s", err)
|
||||
t.Fatalf("unable to send request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != tc.wantStatusCode {
|
||||
t.Fatalf("response status code is not %d.", tc.wantStatusCode)
|
||||
t.Fatalf("response status code is not %d. It is %d", tc.wantStatusCode, resp.StatusCode)
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("Response body: %s", string(bodyBytes))
|
||||
}
|
||||
var result map[string]interface{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
@@ -297,8 +381,8 @@ func runDataplexSearchEntriesToolInvokeTest(t *testing.T, tableName string, data
|
||||
}
|
||||
|
||||
if tc.expectResult {
|
||||
if len(entries) == 0 {
|
||||
t.Fatal("expected at least one entry, but got 0")
|
||||
if len(entries) != 1 {
|
||||
t.Fatalf("expected exactly one entry, but got %d", len(entries))
|
||||
}
|
||||
entry, ok := entries[0].(map[string]interface{})
|
||||
if !ok {
|
||||
@@ -315,3 +399,163 @@ func runDataplexSearchEntriesToolInvokeTest(t *testing.T, tableName string, data
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func runDataplexLookupEntryToolInvokeTest(t *testing.T, tableName string, datasetName string) {
|
||||
idToken, err := tests.GetGoogleIdToken(tests.ClientId)
|
||||
if err != nil {
|
||||
t.Fatalf("error getting Google ID token: %s", err)
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
wantStatusCode int
|
||||
api string
|
||||
requestHeader map[string]string
|
||||
requestBody io.Reader
|
||||
expectResult bool
|
||||
wantContentKey string
|
||||
dontWantContentKey string
|
||||
aspectCheck bool
|
||||
reqBodyMap map[string]any
|
||||
}{
|
||||
{
|
||||
name: "Success - Entry Found",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-dataplex-lookup-entry-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"name\":\"projects/%s/locations/us\", \"entry\":\"projects/%s/locations/us/entryGroups/@bigquery/entries/bigquery.googleapis.com/projects/%s/datasets/%s\"}", DataplexProject, DataplexProject, DataplexProject, datasetName))),
|
||||
wantStatusCode: 200,
|
||||
expectResult: true,
|
||||
wantContentKey: "name",
|
||||
},
|
||||
{
|
||||
name: "Success - Entry Found with Authorization",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-dataplex-lookup-entry-tool/invoke",
|
||||
requestHeader: map[string]string{"my-google-auth_token": idToken},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"name\":\"projects/%s/locations/us\", \"entry\":\"projects/%s/locations/us/entryGroups/@bigquery/entries/bigquery.googleapis.com/projects/%s/datasets/%s\"}", DataplexProject, DataplexProject, DataplexProject, datasetName))),
|
||||
wantStatusCode: 200,
|
||||
expectResult: true,
|
||||
wantContentKey: "name",
|
||||
},
|
||||
{
|
||||
name: "Failure - Invalid Authorization Token",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-dataplex-lookup-entry-tool/invoke",
|
||||
requestHeader: map[string]string{"my-google-auth_token": "invalid_token"},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"name\":\"projects/%s/locations/us\", \"entry\":\"projects/%s/locations/us/entryGroups/@bigquery/entries/bigquery.googleapis.com/projects/%s/datasets/%s\"}", DataplexProject, DataplexProject, DataplexProject, datasetName))),
|
||||
wantStatusCode: 401,
|
||||
expectResult: false,
|
||||
wantContentKey: "name",
|
||||
},
|
||||
{
|
||||
name: "Failure - Without Authorization Token",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-auth-dataplex-lookup-entry-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"name\":\"projects/%s/locations/us\", \"entry\":\"projects/%s/locations/us/entryGroups/@bigquery/entries/bigquery.googleapis.com/projects/%s/datasets/%s\"}", DataplexProject, DataplexProject, DataplexProject, datasetName))),
|
||||
wantStatusCode: 401,
|
||||
expectResult: false,
|
||||
wantContentKey: "name",
|
||||
},
|
||||
{
|
||||
name: "Failure - Entry Not Found or Permission Denied",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-dataplex-lookup-entry-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"name\":\"projects/%s/locations/us\", \"entry\":\"projects/%s/locations/us/entryGroups/@bigquery/entries/bigquery.googleapis.com/projects/%s/datasets/%s\"}", DataplexProject, DataplexProject, DataplexProject, "non-existent-dataset"))),
|
||||
wantStatusCode: 400,
|
||||
expectResult: false,
|
||||
},
|
||||
{
|
||||
name: "Success - Entry Found with Basic View",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-dataplex-lookup-entry-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"name\":\"projects/%s/locations/us\", \"entry\":\"projects/%s/locations/us/entryGroups/@bigquery/entries/bigquery.googleapis.com/projects/%s/datasets/%s/tables/%s\", \"view\": %d}", DataplexProject, DataplexProject, DataplexProject, datasetName, tableName, 1))),
|
||||
wantStatusCode: 200,
|
||||
expectResult: true,
|
||||
wantContentKey: "name",
|
||||
dontWantContentKey: "aspects",
|
||||
},
|
||||
{
|
||||
name: "Failure - Entry with Custom View without Aspect Types",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-dataplex-lookup-entry-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"name\":\"projects/%s/locations/us\", \"entry\":\"projects/%s/locations/us/entryGroups/@bigquery/entries/bigquery.googleapis.com/projects/%s/datasets/%s/tables/%s\", \"view\": %d}", DataplexProject, DataplexProject, DataplexProject, datasetName, tableName, 3))),
|
||||
wantStatusCode: 400,
|
||||
expectResult: false,
|
||||
},
|
||||
{
|
||||
name: "Success - Entry Found with only Schema Aspect",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-dataplex-lookup-entry-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf("{\"name\":\"projects/%s/locations/us\", \"entry\":\"projects/%s/locations/us/entryGroups/@bigquery/entries/bigquery.googleapis.com/projects/%s/datasets/%s/tables/%s\", \"aspectTypes\":[\"projects/dataplex-types/locations/global/aspectTypes/schema\"], \"view\": %d}", DataplexProject, DataplexProject, DataplexProject, datasetName, tableName, 3))),
|
||||
wantStatusCode: 200,
|
||||
expectResult: true,
|
||||
wantContentKey: "aspects",
|
||||
aspectCheck: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodPost, tc.api, tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create request: %s", err)
|
||||
}
|
||||
req.Header.Add("Content-type", "application/json")
|
||||
for k, v := range tc.requestHeader {
|
||||
req.Header.Add(k, v)
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != tc.wantStatusCode {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("Response status code got %d, want %d\nResponse body: %s", resp.StatusCode, tc.wantStatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
t.Fatalf("Error parsing response body: %v", err)
|
||||
}
|
||||
|
||||
if tc.expectResult {
|
||||
resultStr, ok := result["result"].(string)
|
||||
if !ok {
|
||||
t.Fatalf("Expected 'result' field to be a string on success, got %T", result["result"])
|
||||
}
|
||||
if resultStr == "" || resultStr == "{}" || resultStr == "null" {
|
||||
t.Fatal("Expected an entry, but got empty result")
|
||||
}
|
||||
|
||||
var entry map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(resultStr), &entry); err != nil {
|
||||
t.Fatalf("Error unmarshalling result string into entry map: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := entry[tc.wantContentKey]; !ok {
|
||||
t.Fatalf("Expected entry to have key '%s', but it was not found in %v", tc.wantContentKey, entry)
|
||||
}
|
||||
|
||||
if _, ok := entry[tc.dontWantContentKey]; ok {
|
||||
t.Fatalf("Expected entry to not have key '%s', but it was found in %v", tc.dontWantContentKey, entry)
|
||||
}
|
||||
|
||||
if tc.aspectCheck {
|
||||
// Check length of aspects
|
||||
aspects, ok := entry["aspects"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("Expected 'aspects' to be a map, got %T", aspects)
|
||||
}
|
||||
if len(aspects) != 1 {
|
||||
t.Fatalf("Expected exactly one aspect, but got %d", len(aspects))
|
||||
}
|
||||
}
|
||||
} else { // Handle expected error response
|
||||
_, ok := result["error"]
|
||||
if !ok {
|
||||
t.Fatalf("Expected 'error' field in response, got %v", result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
164
tests/tidb/tidb_integration_test.go
Normal file
164
tests/tidb/tidb_integration_test.go
Normal file
@@ -0,0 +1,164 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package tidb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/tests"
|
||||
)
|
||||
|
||||
var (
|
||||
TiDBSourceKind = "tidb"
|
||||
TiDBToolKind = "tidb-sql"
|
||||
TiDBDatabase = os.Getenv("TIDB_DATABASE")
|
||||
TiDBHost = os.Getenv("TIDB_HOST")
|
||||
TiDBPort = os.Getenv("TIDB_PORT")
|
||||
TiDBUser = os.Getenv("TIDB_USER")
|
||||
TiDBPass = os.Getenv("TIDB_PASS")
|
||||
)
|
||||
|
||||
func getTiDBVars(t *testing.T) map[string]any {
|
||||
switch "" {
|
||||
case TiDBDatabase:
|
||||
t.Fatal("'TIDB_DATABASE' not set")
|
||||
case TiDBHost:
|
||||
t.Fatal("'TIDB_HOST' not set")
|
||||
case TiDBPort:
|
||||
t.Fatal("'TIDB_PORT' not set")
|
||||
case TiDBUser:
|
||||
t.Fatal("'TIDB_USER' not set")
|
||||
case TiDBPass:
|
||||
t.Fatal("'TIDB_PASS' not set")
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"kind": TiDBSourceKind,
|
||||
"host": TiDBHost,
|
||||
"port": TiDBPort,
|
||||
"database": TiDBDatabase,
|
||||
"user": TiDBUser,
|
||||
"password": TiDBPass,
|
||||
}
|
||||
}
|
||||
|
||||
// Copied over from tidb.go
|
||||
func initTiDBConnectionPool(host, port, user, pass, dbname string, useSSL bool) (*sql.DB, error) {
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?parseTime=true&charset=utf8mb4&tls=%t", user, pass, host, port, dbname, useSSL)
|
||||
|
||||
// Interact with the driver directly as you normally would
|
||||
pool, err := sql.Open("mysql", dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sql.Open: %w", err)
|
||||
}
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
// getTiDBWants return the expected wants for tidb
|
||||
func getTiDBWants() (string, string, string) {
|
||||
select1Want := "[{\"1\":1}]"
|
||||
failInvocationWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: Error 1064 (42000): You have an error in your SQL syntax; check the manual that corresponds to your TiDB version for the right syntax to use line 1 column 5 near \"SELEC 1;\" "}],"isError":true}}`
|
||||
createTableStatement := `"CREATE TABLE t (id SERIAL PRIMARY KEY, name TEXT)"`
|
||||
return select1Want, failInvocationWant, createTableStatement
|
||||
}
|
||||
|
||||
// addTiDBExecuteSqlConfig gets the tools config for `tidb-execute-sql`
|
||||
func addTiDBExecuteSqlConfig(t *testing.T, config map[string]any) map[string]any {
|
||||
tools, ok := config["tools"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("unable to get tools from config")
|
||||
}
|
||||
tools["my-exec-sql-tool"] = map[string]any{
|
||||
"kind": "tidb-execute-sql",
|
||||
"source": "my-instance",
|
||||
"description": "Tool to execute sql",
|
||||
}
|
||||
tools["my-auth-exec-sql-tool"] = map[string]any{
|
||||
"kind": "tidb-execute-sql",
|
||||
"source": "my-instance",
|
||||
"description": "Tool to execute sql",
|
||||
"authRequired": []string{
|
||||
"my-google-auth",
|
||||
},
|
||||
}
|
||||
config["tools"] = tools
|
||||
return config
|
||||
}
|
||||
|
||||
func TestTiDBToolEndpoints(t *testing.T) {
|
||||
sourceConfig := getTiDBVars(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
|
||||
var args []string
|
||||
|
||||
pool, err := initTiDBConnectionPool(TiDBHost, TiDBPort, TiDBUser, TiDBPass, TiDBDatabase, false)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create TiDB connection pool: %s", err)
|
||||
}
|
||||
|
||||
// create table name with UUID
|
||||
tableNameParam := "param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
tableNameAuth := "auth_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
|
||||
// set up data for param tool
|
||||
createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := tests.GetMySQLParamToolInfo(tableNameParam)
|
||||
teardownTable1 := tests.SetupMySQLTable(t, ctx, pool, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams)
|
||||
defer teardownTable1(t)
|
||||
|
||||
// set up data for auth tool
|
||||
createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := tests.GetMySQLAuthToolInfo(tableNameAuth)
|
||||
teardownTable2 := tests.SetupMySQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams)
|
||||
defer teardownTable2(t)
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := tests.GetToolsConfig(sourceConfig, TiDBToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt)
|
||||
toolsFile = addTiDBExecuteSqlConfig(t, toolsFile)
|
||||
tmplSelectCombined, tmplSelectFilterCombined := tests.GetMySQLTmplToolStatement()
|
||||
toolsFile = tests.AddTemplateParamConfig(t, toolsFile, TiDBToolKind, tmplSelectCombined, tmplSelectFilterCombined, "")
|
||||
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
t.Fatalf("command initialization returned an error: %s", err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||
if err != nil {
|
||||
t.Logf("toolbox command logs: \n%s", out)
|
||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||
}
|
||||
|
||||
tests.RunToolGetTest(t)
|
||||
|
||||
select1Want, failInvocationWant, createTableStatement := getTiDBWants()
|
||||
invokeParamWant, invokeIdNullWant, nullWant, mcpInvokeParamWant := tests.GetNonSpannerInvokeParamWant()
|
||||
tests.RunToolInvokeTest(t, select1Want, invokeParamWant, invokeIdNullWant, nullWant, true, false)
|
||||
tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want)
|
||||
tests.RunMCPToolCallMethod(t, mcpInvokeParamWant, failInvocationWant)
|
||||
tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tests.NewTemplateParameterTestConfig())
|
||||
}
|
||||
@@ -405,14 +405,16 @@ func RunToolInvokeTest(t *testing.T, select1Want, invokeParamWant, invokeIdNullW
|
||||
|
||||
// TemplateParameterTestConfig represents the various configuration options for template parameter tests.
|
||||
type TemplateParameterTestConfig struct {
|
||||
ignoreDdl bool
|
||||
ignoreInsert bool
|
||||
selectAllWant string
|
||||
select1Want string
|
||||
nameFieldArray string
|
||||
nameColFilter string
|
||||
createColArray string
|
||||
insert1Want string
|
||||
ignoreDdl bool
|
||||
ignoreInsert bool
|
||||
ddlWant string
|
||||
selectAllWant string
|
||||
select1Want string
|
||||
selectEmptyWant string
|
||||
nameFieldArray string
|
||||
nameColFilter string
|
||||
createColArray string
|
||||
insert1Want string
|
||||
}
|
||||
|
||||
type Option func(*TemplateParameterTestConfig)
|
||||
@@ -431,6 +433,13 @@ func WithIgnoreInsert() Option {
|
||||
}
|
||||
}
|
||||
|
||||
// WithDdlWant is the option function to configure ddlWant.
|
||||
func WithDdlWant(s string) Option {
|
||||
return func(c *TemplateParameterTestConfig) {
|
||||
c.ddlWant = s
|
||||
}
|
||||
}
|
||||
|
||||
// WithSelectAllWant is the option function to configure selectAllWant.
|
||||
func WithSelectAllWant(s string) Option {
|
||||
return func(c *TemplateParameterTestConfig) {
|
||||
@@ -445,6 +454,13 @@ func WithSelect1Want(s string) Option {
|
||||
}
|
||||
}
|
||||
|
||||
// WithSelectEmptyWant is the option function to configure selectEmptyWant.
|
||||
func WithSelectEmptyWant(s string) Option {
|
||||
return func(c *TemplateParameterTestConfig) {
|
||||
c.selectEmptyWant = s
|
||||
}
|
||||
}
|
||||
|
||||
// WithReplaceNameFieldArray is the option function to configure replaceNameFieldArray.
|
||||
func WithReplaceNameFieldArray(s string) Option {
|
||||
return func(c *TemplateParameterTestConfig) {
|
||||
@@ -475,14 +491,16 @@ func WithInsert1Want(s string) Option {
|
||||
// NewTemplateParameterTestConfig creates a new TemplateParameterTestConfig instances with options.
|
||||
func NewTemplateParameterTestConfig(options ...Option) *TemplateParameterTestConfig {
|
||||
templateParamTestOption := &TemplateParameterTestConfig{
|
||||
ignoreDdl: false,
|
||||
ignoreInsert: false,
|
||||
selectAllWant: "[{\"age\":21,\"id\":1,\"name\":\"Alex\"},{\"age\":100,\"id\":2,\"name\":\"Alice\"}]",
|
||||
select1Want: "[{\"age\":21,\"id\":1,\"name\":\"Alex\"}]",
|
||||
nameFieldArray: `["name"]`,
|
||||
nameColFilter: "name",
|
||||
createColArray: `["id INT","name VARCHAR(20)","age INT"]`,
|
||||
insert1Want: "null",
|
||||
ignoreDdl: false,
|
||||
ignoreInsert: false,
|
||||
ddlWant: "null",
|
||||
selectAllWant: "[{\"age\":21,\"id\":1,\"name\":\"Alex\"},{\"age\":100,\"id\":2,\"name\":\"Alice\"}]",
|
||||
select1Want: "[{\"age\":21,\"id\":1,\"name\":\"Alex\"}]",
|
||||
selectEmptyWant: "null",
|
||||
nameFieldArray: `["name"]`,
|
||||
nameColFilter: "name",
|
||||
createColArray: `["id INT","name VARCHAR(20)","age INT"]`,
|
||||
insert1Want: "null",
|
||||
}
|
||||
|
||||
// Apply provided options
|
||||
@@ -514,7 +532,7 @@ func RunToolInvokeWithTemplateParameters(t *testing.T, tableName string, config
|
||||
api: "http://127.0.0.1:5000/api/tool/create-table-templateParams-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"tableName": "%s", "columns":%s}`, tableName, config.createColArray))),
|
||||
want: "null",
|
||||
want: config.ddlWant,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
@@ -551,6 +569,14 @@ func RunToolInvokeWithTemplateParameters(t *testing.T, tableName string, config
|
||||
want: config.select1Want,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "invoke select-templateParams-combined-tool with no results",
|
||||
api: "http://127.0.0.1:5000/api/tool/select-templateParams-combined-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"id": 999, "tableName": "%s"}`, tableName))),
|
||||
want: config.selectEmptyWant,
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "invoke select-fields-templateParams-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/select-fields-templateParams-tool/invoke",
|
||||
@@ -573,7 +599,7 @@ func RunToolInvokeWithTemplateParameters(t *testing.T, tableName string, config
|
||||
api: "http://127.0.0.1:5000/api/tool/drop-table-templateParams-tool/invoke",
|
||||
requestHeader: map[string]string{},
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"tableName": "%s"}`, tableName))),
|
||||
want: "null",
|
||||
want: config.ddlWant,
|
||||
isErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user