feat: Add Cloud SQL for SQL Server Source and Tool (#223)

1. `sql/database` provides a `Scan()`interface to scan query results
into typed variables. Therefore we have to create a slice of typed
variables (types retrieved from rows.ColumnTypes()) to pass them into
`Scan()`. Using []byte works but makes the printing result different
from other tools (e.g [1] instead of %!s(int32=1)]
2. MS SQL supports both named (e.g @name) and positional args (e.g @p2),
so we have to check if the name is contained in the original statement
before passing them into `db.Query()` as either named arg or as values.
This commit is contained in:
Wenxin Du
2025-01-23 21:21:12 +08:00
committed by GitHub
parent 1de3853006
commit 9bad952060
15 changed files with 1093 additions and 9 deletions

View File

@@ -118,6 +118,27 @@ steps:
- |
go test -race -v -tags=integration,neo4j ./tests
- id: "cloud-sql-mssql"
name: golang:1
waitFor: ["install-dependencies"]
entrypoint: /bin/bash
env:
- "GOPATH=/gopath"
- "CLOUD_SQL_MSSQL_PROJECT=$PROJECT_ID"
- "CLOUD_SQL_MSSQL_INSTANCE=$_CLOUD_SQL_MSSQL_INSTANCE"
- "CLOUD_SQL_MSSQL_IP=$_CLOUD_SQL_MSSQL_IP"
- "CLOUD_SQL_MSSQL_DATABASE=$_DATABASE_NAME"
- "CLOUD_SQL_MSSQL_REGION=$_REGION"
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
secretEnv: ["CLOUD_SQL_MSSQL_USER", "CLOUD_SQL_MSSQL_PASS", "CLIENT_ID"]
volumes:
- name: "go"
path: "/gopath"
args:
- -c
- |
go test -race -v -tags=integration,cloudsqlmssql ./tests
availableSecrets:
secretManager:
- versionName: projects/$PROJECT_ID/secrets/cloud_sql_pg_user/versions/latest
@@ -138,6 +159,10 @@ availableSecrets:
env: NEO4J_USER
- versionName: projects/$PROJECT_ID/secrets/neo4j_pass/versions/latest
env: NEO4J_PASS
- versionName: projects/$PROJECT_ID/secrets/cloud_sql_mssql_user/versions/latest
env: CLOUD_SQL_MSSQL_USER
- versionName: projects/$PROJECT_ID/secrets/cloud_sql_mssql_pass/versions/latest
env: CLOUD_SQL_MSSQL_PASS
options:
logging: CLOUD_LOGGING_ONLY
@@ -157,3 +182,4 @@ substitutions:
_POSTGRES_PORT: "5432"
_SPANNER_INSTANCE: "spanner-testing"
_NEO4J_DATABASE: "neo4j"
_CLOUD_SQL_MSSQL_INSTANCE: "cloud-sql-mssql-testing"

View File

@@ -0,0 +1,73 @@
# Cloud SQL for SQL Server Source
[Cloud SQL for SQL Server][csql-mssql-docs] is a managed database service that
helps you set up, maintain, manage, and administer your SQL Server databases on
Google Cloud.
If you are new to Cloud SQL for SQL Server, you can try [creating and connecting
to a database by following these instructions][csql-mssql-connect].
[csql-mssql-docs]: https://cloud.google.com/sql/docs/sqlserver
[csql-mssql-connect]: https://cloud.google.com/sql/docs/sqlserver/connect-overview
## Requirements
### IAM Identity
By default, this source uses the [Cloud SQL Go Connector][csql-go-conn] to
authorize and establish mTLS connections to your Cloud SQL instance. The Go
connector uses your [Application Default Credentials (ADC)][adc] to authorize
your connection to Cloud SQL.
In addition to [setting the ADC for your server][set-adc], you need to ensure the
IAM identity has been given the following IAM roles:
- `roles/cloudsql.client`
[csql-go-conn]: https://github.com/GoogleCloudPlatform/cloud-sql-go-connector
[adc]: https://cloud.google.com/docs/authentication#adc
[set-adc]: https://cloud.google.com/docs/authentication/provide-credentials-adc
### Network Path
Currently, Cloud SQL for SQL Server supports connection over both [private IP][private-ip] and
[public IP][public-ip]. Set the `ipType` parameter in your source
configuration to `public` or `private`.
[private-ip]: https://cloud.google.com/sql/docs/sqlserver/configure-private-ip
[public-ip]: https://cloud.google.com/sql/docs/sqlserver/configure-ip
### Database User
Currently, this source only uses standard authentication. You will need to [create a
SQL Server user][cloud-sql-users] to login to the database with.
[cloud-sql-users]: https://cloud.google.com/sql/docs/sqlserver/create-manage-users
## Example
```yaml
sources:
my-cloud-sql-mssql-instance:
kind: cloud-sql-mssql
project: my-project
region: my-region
instance: my-instance
database: my_db
ipAddress: localhost
ipType: public
```
## Reference
| **field** | **type** | **required** | **description** |
|-----------|:--------:|:------------:|------------------------------------------------------------------------------|
| kind | string | true | Must be "cloud-sql-postgres". |
| project | string | true | Id of the GCP project that the cluster was created in (e.g. "my-project-id"). |
| region | string | true | Name of the GCP region that the cluster was created in (e.g. "us-central1"). |
| instance | string | true | Name of the Cloud SQL instance within the cluser (e.g. "my-instance"). |
| database | string | true | Name of the Cloud SQL database to connect to (e.g. "my_db"). |
| ipAddress | string | true | IP address of the Cloud SQL instance to connect to.|
| ipType | string | true | IP Type of the Cloud SQL instance, must be either `public` or `private`. Default: `public`. |
| user | string | true | Name of the Postgres user to connect as (e.g. "my-pg-user"). |
| password | string | true | Password of the Postgres user (e.g. "my-password").

66
docs/tools/mssql.md Normal file
View File

@@ -0,0 +1,66 @@
# Cloud SQL Mssql Tool
A "mssql" tool executes a pre-defined SQL statement against a Cloud SQL for SQL Server
database. It's compatible with any of the following sources:
- [cloud-sql-mssql](../sources/cloud-sql-mssql.md)
Toolbox supports the [prepare statement syntax][prepare-statement] of MS SQL
Server and expects parameters in the SQL query to be in the form of either @Name
or @p1 to @pN (ordinal position).
```sql
db.QueryContext(ctx, `select * from t where ID = @ID and Name = @p2;`, sql.Named("ID", 6), "Bob")
```
[prepare-statement]: https://learn.microsoft.com/sql/relational-databases/system-stored-procedures/sp-prepare-transact-sql?view=sql-server-ver16
## Example
```yaml
tools:
search_flights_by_number:
kind: mssql
source: my-instance
statement: |
SELECT * FROM flights
WHERE airline = @airline
AND flight_number = @number
LIMIT 10
description: |
Use this tool to get information for a specific flight.
Takes an airline code and flight number and returns info on the flight.
Do NOT use this tool with a flight id. Do NOT guess an airline code or flight number.
A airline code is a code for an airline service consisting of two-character
airline designator and followed by flight number, which is 1 to 4 digit number.
For example, if given CY 0123, the airline is "CY", and flight_number is "123".
Another example for this is DL 1234, the airline is "DL", and flight_number is "1234".
If the tool returns more than one option choose the date closes to today.
Example:
{{
"airline": "CY",
"flight_number": "888",
}}
Example:
{{
"airline": "DL",
"flight_number": "1234",
}}
parameters:
- name: airline
type: string
description: Airline unique 2 letter identifier
- name: number
type: string
description: 1 to 4 digit number
```
## Reference
| **field** | **type** | **required** | **description** |
|-------------|:--------------------------------------------:|:------------:|-----------------------------------------------------------------------------------------------------|
| kind | string | true | Must be "mssql". |
| source | string | true | Name of the source the T-SQL statement should execute on.|
| description | string | true | Description of the tool that is passed to the LLM|
| statement | string | true | SQL statement to execute. |
| parameters | [parameter](README.md#specifying-parameters) | true | List of [parameters](README.md#specifying-parameters) that will be inserted into the SQL statement. |

3
go.mod
View File

@@ -56,6 +56,8 @@ require (
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/go-logr/logr v1.4.2 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
github.com/golang-sql/sqlexp v0.1.0 // indirect
github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect
github.com/google/s2a-go v0.1.8 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect
@@ -65,6 +67,7 @@ require (
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/microsoft/go-mssqldb v1.8.0 // indirect
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect
github.com/spf13/pflag v1.0.5 // indirect
go.opencensus.io v0.24.0 // indirect

18
go.sum
View File

@@ -625,6 +625,18 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
gioui.org v0.0.0-20210308172011-57750fc8a0a6/go.mod h1:RSH6KIUZ0p2xy5zHDxgAM4zumjgTw83q2ge/PI+yyw8=
git.sr.ht/~sbinet/gg v0.3.1/go.mod h1:KGYtlADtqsqANL9ueOFkWymvzUvLMQllU5Ixo+8v3pc=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1 h1:E+OJmp2tPvt1W+amx48v1eqbjDYsgN+RzP4q16yV5eM=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1/go.mod h1:a6xsAQUZg+VsS3TJ05SRp524Hs4pZ/AeFSr5ENf0Yjo=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.6.0 h1:U2rTu3Ef+7w9FHKIAXM6ZyqF3UOWJZ12zIm8zECAFfg=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.6.0/go.mod h1:9kIvujWAA58nmPmWB1m23fyWic1kYZMxD9CxaWn4Qpg=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.8.0 h1:jBQA3cKT4L2rWMpgE7Yt3Hwh2aUj8KXjIGLxjHeYNNo=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.8.0/go.mod h1:4OG6tQ9EOP/MT0NMjDlRzWoVFxfu9rN9B2X+tlSVktg=
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.0.1 h1:MyVTgWR8qd/Jw1Le0NZebGBUCLbtak3bJ3z1OlqZBpw=
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.0.1/go.mod h1:GpPjLhVR9dnUoJMyHWSPy71xY9/lcmpzIPZXmF0FCVY=
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.0.0 h1:D3occbWoio4EBLkbkevetNMAVX197GkzbUMtqjGWn80=
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.0.0/go.mod h1:bTSOgj05NGRuHHhQwAdPnYr9TOdNmKlZTgGLL6nyAdI=
github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 h1:XHOnouVk1mxXfQidrMEnLlPk9UMeRtyBTnEFtxkV0kU=
github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
github.com/GoogleCloudPlatform/grpc-gcp-go/grpcgcp v1.5.2 h1:DBjmt6/otSdULyJdVg2BlG0qGZO5tKL4VzOs0jpvw5Q=
@@ -743,6 +755,8 @@ github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqw
github.com/goccy/go-json v0.9.11/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/goccy/go-yaml v1.15.13 h1:Xd87Yddmr2rC1SLLTm2MNDcTjeO/GYo0JGiww6gSTDg=
github.com/goccy/go-yaml v1.15.13/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA=
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0=
github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A=
@@ -910,6 +924,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/lyft/protoc-gen-star v0.6.0/go.mod h1:TGAoBVkt8w7MPG72TrKIu85MIdXwDuzJYeZuUPFPNwA=
github.com/lyft/protoc-gen-star v0.6.1/go.mod h1:TGAoBVkt8w7MPG72TrKIu85MIdXwDuzJYeZuUPFPNwA=
github.com/lyft/protoc-gen-star/v2 v2.0.1/go.mod h1:RcCdONR2ScXaYnQC5tUzxzlpA3WVYF7/opLeUgcQs/o=
@@ -926,6 +942,8 @@ github.com/phpdave11/gofpdf v1.4.2/go.mod h1:zpO6xFn9yxo3YLyMvW8HcKWVdbNqgIfOOp2
github.com/phpdave11/gofpdi v1.0.12/go.mod h1:vBmVV0Do6hSBHC8uKUQ71JGW+ZGQq74llk/7bXwjDoI=
github.com/phpdave11/gofpdi v1.0.13/go.mod h1:vBmVV0Do6hSBHC8uKUQ71JGW+ZGQq74llk/7bXwjDoI=
github.com/pierrec/lz4/v4 v4.1.15/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=

View File

@@ -22,12 +22,14 @@ import (
"github.com/googleapis/genai-toolbox/internal/auth/google"
"github.com/googleapis/genai-toolbox/internal/sources"
alloydbpgsrc "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg"
cloudsqlmssqlsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmssql"
cloudsqlmysqlsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmysql"
cloudsqlpgsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg"
neo4jrc "github.com/googleapis/genai-toolbox/internal/sources/neo4j"
postgressrc "github.com/googleapis/genai-toolbox/internal/sources/postgres"
spannersrc "github.com/googleapis/genai-toolbox/internal/sources/spanner"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/mssql"
"github.com/googleapis/genai-toolbox/internal/tools/mysql"
neo4jtool "github.com/googleapis/genai-toolbox/internal/tools/neo4j"
"github.com/googleapis/genai-toolbox/internal/tools/postgressql"
@@ -173,6 +175,12 @@ func (c *SourceConfigs) UnmarshalYAML(unmarshal func(interface{}) error) error {
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
}
(*c)[name] = actual
case cloudsqlmssqlsrc.SourceKind:
actual := cloudsqlmssqlsrc.Config{Name: name}
if err := u.Unmarshal(&actual); err != nil {
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
}
(*c)[name] = actual
default:
return fmt.Errorf("%q is not a valid kind of data source", k.Kind)
}
@@ -264,6 +272,12 @@ func (c *ToolConfigs) UnmarshalYAML(unmarshal func(interface{}) error) error {
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
}
(*c)[name] = actual
case mssql.ToolKind:
actual := mssql.Config{Name: name}
if err := u.Unmarshal(&actual); err != nil {
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
}
(*c)[name] = actual
default:
return fmt.Errorf("%q is not a valid kind of tool", k.Kind)
}

View File

@@ -0,0 +1,136 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cloudsqlmssql
import (
"context"
"database/sql"
"fmt"
"slices"
"strings"
"cloud.google.com/go/cloudsqlconn"
"cloud.google.com/go/cloudsqlconn/sqlserver/mssql"
"github.com/googleapis/genai-toolbox/internal/sources"
"go.opentelemetry.io/otel/trace"
)
const SourceKind string = "cloud-sql-mssql"
// validate interface
var _ sources.SourceConfig = Config{}
type Config struct {
// Cloud SQL MSSQL configs
Name string `yaml:"name"`
Kind string `yaml:"kind"`
Project string `yaml:"project"`
Region string `yaml:"region"`
Instance string `yaml:"instance"`
IPAddress string `yaml:"ipAddress"`
IPType string `yaml:"ipType"`
User string `yaml:"user"`
Password string `yaml:"password"`
Database string `yaml:"database"`
}
func (r Config) SourceConfigKind() string {
// Returns Cloud SQL MSSQL source kind
return SourceKind
}
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
// Initializes a Cloud SQL MSSQL source
db, err := initCloudSQLMssqlConnection(ctx, tracer, r.Name, r.Project, r.Region, r.Instance, r.IPAddress, r.IPType, r.User, r.Password, r.Database)
if err != nil {
return nil, fmt.Errorf("unable to create db connection: %w", err)
}
// Verify db connection
err = db.Ping()
if err != nil {
return nil, fmt.Errorf("unable to connect successfully: %w", err)
}
s := &Source{
Name: r.Name,
Kind: SourceKind,
Db: db,
}
return s, nil
}
var _ sources.Source = &Source{}
type Source struct {
// Cloud SQL MSSQL struct with connection pool
Name string `yaml:"name"`
Kind string `yaml:"kind"`
Db *sql.DB
}
func (s *Source) SourceKind() string {
// Returns Cloud SQL MSSQL source kind
return SourceKind
}
func (s *Source) MSSQLDB() *sql.DB {
// Returns a Cloud SQL MSSQL database connection pool
return s.Db
}
func getDialOpts(ipType string) ([]cloudsqlconn.DialOption, error) {
switch strings.ToLower(ipType) {
case "private":
return []cloudsqlconn.DialOption{cloudsqlconn.WithPrivateIP()}, nil
case "public":
return []cloudsqlconn.DialOption{cloudsqlconn.WithPublicIP()}, nil
default:
return nil, fmt.Errorf("invalid ipType %s", ipType)
}
}
func initCloudSQLMssqlConnection(ctx context.Context, tracer trace.Tracer, name, project, region, instance, ipAddress, ipType, user, pass, dbname string) (*sql.DB, error) {
//nolint:all // Reassigned ctx
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
defer span.End()
// Create dsn
dsn := fmt.Sprintf("sqlserver://%s:%s@%s?database=%s&cloudsql=%s:%s:%s", user, pass, ipAddress, dbname, project, region, instance)
// Get dial options
dialOpts, err := getDialOpts(ipType)
if err != nil {
return nil, err
}
// Register sql server driver
if !slices.Contains(sql.Drivers(), "cloudsql-sqlserver-driver") {
_, err := mssql.RegisterDriver("cloudsql-sqlserver-driver", cloudsqlconn.WithDefaultDialOptions(dialOpts...))
if err != nil {
return nil, err
}
}
// Open database connection
db, err := sql.Open(
"cloudsql-sqlserver-driver",
dsn,
)
if err != nil {
return nil, err
}
return db, nil
}

View File

@@ -0,0 +1,76 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cloudsqlmssql_test
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmssql"
"github.com/googleapis/genai-toolbox/internal/testutils"
"gopkg.in/yaml.v3"
)
func TestParseFromYamlCloudSQLMssql(t *testing.T) {
tcs := []struct {
desc string
in string
want server.SourceConfigs
}{
{
desc: "basic example",
in: `
sources:
my-instance:
kind: cloud-sql-mssql
project: my-project
region: my-region
instance: my-instance
database: my_db
ipAddress: localhost
ipType: public
`,
want: server.SourceConfigs{
"my-instance": cloudsqlmssql.Config{
Name: "my-instance",
Kind: cloudsqlmssql.SourceKind,
Project: "my-project",
Region: "my-region",
Instance: "my-instance",
IPAddress: "localhost",
IPType: "public",
Database: "my_db",
},
},
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect psarse: want %v, got %v", tc.want, got.Sources)
}
})
}
}

View File

@@ -0,0 +1,168 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mssql
import (
"database/sql"
"fmt"
"strings"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmssql"
"github.com/googleapis/genai-toolbox/internal/tools"
)
const ToolKind string = "mssql"
type compatibleSource interface {
MSSQLDB() *sql.DB
}
// validate compatible sources are still compatible
var _ compatibleSource = &cloudsqlmssql.Source{}
var compatibleSources = [...]string{cloudsqlmssql.SourceKind}
type Config struct {
Name string `yaml:"name"`
Kind string `yaml:"kind"`
Source string `yaml:"source"`
Description string `yaml:"description"`
Statement string `yaml:"statement"`
AuthRequired []string `yaml:"authRequired"`
Parameters tools.Parameters `yaml:"parameters"`
}
// validate interface
var _ tools.ToolConfig = Config{}
func (cfg Config) ToolConfigKind() string {
return ToolKind
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", ToolKind, compatibleSources)
}
// finish tool setup
t := Tool{
Name: cfg.Name,
Kind: ToolKind,
Parameters: cfg.Parameters,
Statement: cfg.Statement,
AuthRequired: cfg.AuthRequired,
Db: s.MSSQLDB(),
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()},
}
return t, nil
}
func NewGenericTool(name string, stmt string, authRequired []string, desc string, Db *sql.DB, parameters tools.Parameters) Tool {
return Tool{
Name: name,
Kind: ToolKind,
Statement: stmt,
AuthRequired: authRequired,
Db: Db,
manifest: tools.Manifest{Description: desc, Parameters: parameters.Manifest()},
Parameters: parameters,
}
}
// validate interface
var _ tools.Tool = Tool{}
type Tool struct {
Name string `yaml:"name"`
Kind string `yaml:"kind"`
AuthRequired []string `yaml:"authRequired"`
Parameters tools.Parameters `yaml:"parameters"`
Db *sql.DB
Statement string
manifest tools.Manifest
}
func (t Tool) Invoke(params tools.ParamValues) (string, error) {
fmt.Printf("Invoked tool %s\n", t.Name)
namedArgs := make([]any, 0, len(params))
paramsMap := params.AsReversedMap()
// To support both named args (e.g @id) and positional args (e.g @p1), check if arg name is contained in the statement.
for _, v := range params.AsSlice() {
paramName := paramsMap[v]
if strings.Contains(t.Statement, "@"+paramName) {
namedArgs = append(namedArgs, sql.Named(paramName, v))
} else {
namedArgs = append(namedArgs, v)
}
}
rows, err := t.Db.Query(t.Statement, namedArgs...)
if err != nil {
return "", fmt.Errorf("unable to execute query: %w", err)
}
types, err := rows.ColumnTypes()
if err != nil {
return "", fmt.Errorf("unable to fetch column types: %w", err)
}
v := make([]any, len(types))
pointers := make([]any, len(types))
for i := range types {
pointers[i] = &v[i]
}
// fetch result into a string
var out strings.Builder
for rows.Next() {
err = rows.Scan(pointers...)
if err != nil {
return "", fmt.Errorf("unable to parse row: %w", err)
}
out.WriteString(fmt.Sprintf("%s", v))
}
err = rows.Close()
if err != nil {
return "", fmt.Errorf("unable to close rows: %w", err)
}
// Check if error occured during iteration
if err := rows.Err(); err != nil {
return "", err
}
return fmt.Sprintf("Stub tool call for %q! Parameters parsed: %q \n Output: %s", t.Name, params, out.String()), nil
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
return tools.ParseParams(t.Parameters, data, claims)
}
func (t Tool) Manifest() tools.Manifest {
return t.manifest
}
func (t Tool) Authorized(verifiedAuthSources []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthSources)
}

View File

@@ -0,0 +1,90 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mssql_test
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/testutils"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/mssql"
"gopkg.in/yaml.v3"
)
func TestParseFromYamlMssql(t *testing.T) {
tcs := []struct {
desc string
in string
want server.ToolConfigs
}{
{
desc: "basic example",
in: `
tools:
example_tool:
kind: mssql
source: my-instance
description: some description
statement: |
SELECT * FROM SQL_STATEMENT;
authRequired:
- my-google-auth-service
- other-auth-service
parameters:
- name: country
type: string
description: some description
authSources:
- name: my-google-auth-service
field: user_id
- name: other-auth-service
field: user_id
`,
want: server.ToolConfigs{
"example_tool": mssql.Config{
Name: "example_tool",
Kind: mssql.ToolKind,
Source: "my-instance",
Description: "some description",
Statement: "SELECT * FROM SQL_STATEMENT;\n",
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
Parameters: []tools.Parameter{
tools.NewStringParameterWithAuth("country", "some description",
[]tools.ParamAuthSource{{Name: "my-google-auth-service", Field: "user_id"},
{Name: "other-auth-service", Field: "user_id"}}),
},
},
},
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Tools server.ToolConfigs `yaml:"tools"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
t.Fatalf("incorrect parse: diff %v", diff)
}
})
}
}

View File

@@ -57,6 +57,15 @@ func (p ParamValues) AsMap() map[string]interface{} {
return params
}
// AsReversedMap returns a map of ParamValue's values to names.
func (p ParamValues) AsReversedMap() map[any]string {
params := make(map[any]string)
for _, p := range p {
params[p.Value] = p.Name
}
return params
}
// AsMapByOrderedKeys returns a map of a key's position to it's value, as neccesary for Spanner PSQL.
// Example { $1 -> "value1", $2 -> "value2" }
func (p ParamValues) AsMapByOrderedKeys() map[string]interface{} {

View File

@@ -63,6 +63,8 @@ func RunGoogleAuthenticatedParameterTest(t *testing.T, sourceConfig map[string]a
switch {
case strings.EqualFold(toolKind, "postgres-sql"):
statement = fmt.Sprintf("SELECT * FROM %s WHERE email = $1;", tableName)
case strings.EqualFold(toolKind, "mssql"):
statement = fmt.Sprintf("SELECT * FROM %s WHERE email = @email;", tableName)
default:
t.Fatalf("invalid tool kind: %s", toolKind)
}
@@ -127,8 +129,14 @@ func RunGoogleAuthenticatedParameterTest(t *testing.T, sourceConfig map[string]a
t.Fatalf("error getting Google ID token: %s", err)
}
// Create wanted string
wantResult := fmt.Sprintf("Stub tool call for \"my-auth-tool\"! Parameters parsed: [{\"email\" \"%s\"}] \n Output: [%%!s(int32=1) Alice %s]", SERVICE_ACCOUNT_EMAIL, SERVICE_ACCOUNT_EMAIL)
// Tools using database/sql interface only outputs `int64` instead of `int32`
var wantString string
switch toolKind {
case "mssql":
wantString = fmt.Sprintf("Stub tool call for \"my-auth-tool\"! Parameters parsed: [{\"email\" \"%s\"}] \n Output: [%%!s(int64=1) Alice %s]", SERVICE_ACCOUNT_EMAIL, SERVICE_ACCOUNT_EMAIL)
default:
wantString = fmt.Sprintf("Stub tool call for \"my-auth-tool\"! Parameters parsed: [{\"email\" \"%s\"}] \n Output: [%%!s(int32=1) Alice %s]", SERVICE_ACCOUNT_EMAIL, SERVICE_ACCOUNT_EMAIL)
}
// Test tool invocation with authenticated parameters
invokeTcs := []struct {
@@ -145,7 +153,7 @@ func RunGoogleAuthenticatedParameterTest(t *testing.T, sourceConfig map[string]a
requestHeader: map[string]string{"my-google-auth_token": idToken},
requestBody: bytes.NewBuffer([]byte(`{}`)),
isErr: false,
want: wantResult,
want: wantString,
},
{
name: "Invoke my-auth-tool with invalid auth token",
@@ -205,6 +213,15 @@ func RunGoogleAuthenticatedParameterTest(t *testing.T, sourceConfig map[string]a
}
func RunAuthRequiredToolInvocationTest(t *testing.T, sourceConfig map[string]any, toolKind string) {
// Tools using database/sql interface only outputs `int64` instead of `int32`
var wantString string
switch toolKind {
case "mssql":
wantString = "Stub tool call for \"my-auth-tool\"! Parameters parsed: [] \n Output: [%!s(int64=1)]"
default:
wantString = "Stub tool call for \"my-auth-tool\"! Parameters parsed: [] \n Output: [%!s(int32=1)]"
}
// Write config into a file and pass it to command
toolsFile := map[string]any{
"sources": map[string]any{
@@ -271,7 +288,7 @@ func RunAuthRequiredToolInvocationTest(t *testing.T, sourceConfig map[string]any
requestHeader: map[string]string{"my-google-auth_token": idToken},
requestBody: bytes.NewBuffer([]byte(`{}`)),
isErr: false,
want: "Stub tool call for \"my-auth-tool\"! Parameters parsed: [] \n Output: [%!s(int32=1)]",
want: wantString,
},
{
name: "Invoke my-auth-tool with invalid auth token",

View File

@@ -0,0 +1,376 @@
//go:build integration && cloudsqlmssql
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tests
import (
"bytes"
"context"
"database/sql"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"reflect"
"regexp"
"slices"
"strings"
"testing"
"time"
"cloud.google.com/go/cloudsqlconn"
"cloud.google.com/go/cloudsqlconn/sqlserver/mssql"
"github.com/google/uuid"
)
var (
CLOUD_SQL_MSSQL_PROJECT = os.Getenv("CLOUD_SQL_MSSQL_PROJECT")
CLOUD_SQL_MSSQL_REGION = os.Getenv("CLOUD_SQL_MSSQL_REGION")
CLOUD_SQL_MSSQL_INSTANCE = os.Getenv("CLOUD_SQL_MSSQL_INSTANCE")
CLOUD_SQL_MSSQL_DATABASE = os.Getenv("CLOUD_SQL_MSSQL_DATABASE")
CLOUD_SQL_MSSQL_IP = os.Getenv("CLOUD_SQL_MSSQL_IP")
CLOUD_SQL_MSSQL_USER = os.Getenv("CLOUD_SQL_MSSQL_USER")
CLOUD_SQL_MSSQL_PASS = os.Getenv("CLOUD_SQL_MSSQL_PASS")
)
func requireCloudSQLMssqlVars(t *testing.T) map[string]any {
switch "" {
case CLOUD_SQL_MSSQL_PROJECT:
t.Fatal("'CLOUD_SQL_MSSQL_PROJECT' not set")
case CLOUD_SQL_MSSQL_REGION:
t.Fatal("'CLOUD_SQL_MSSQL_REGION' not set")
case CLOUD_SQL_MSSQL_INSTANCE:
t.Fatal("'CLOUD_SQL_MSSQL_INSTANCE' not set")
case CLOUD_SQL_MSSQL_IP:
t.Fatal("'CLOUD_SQL_MSSQL_IP' not set")
case CLOUD_SQL_MSSQL_DATABASE:
t.Fatal("'CLOUD_SQL_MSSQL_DATABASE' not set")
case CLOUD_SQL_MSSQL_USER:
t.Fatal("'CLOUD_SQL_MSSQL_USER' not set")
case CLOUD_SQL_MSSQL_PASS:
t.Fatal("'CLOUD_SQL_MSSQL_PASS' not set")
}
return map[string]any{
"kind": "cloud-sql-mssql",
"project": CLOUD_SQL_MSSQL_PROJECT,
"instance": CLOUD_SQL_MSSQL_INSTANCE,
"ipType": "public",
"ipAddress": CLOUD_SQL_MSSQL_IP,
"region": CLOUD_SQL_MSSQL_REGION,
"database": CLOUD_SQL_MSSQL_DATABASE,
"user": CLOUD_SQL_MSSQL_USER,
"password": CLOUD_SQL_MSSQL_PASS,
}
}
func getDialOpts(ipType string) ([]cloudsqlconn.DialOption, error) {
switch strings.ToLower(ipType) {
case "private":
return []cloudsqlconn.DialOption{cloudsqlconn.WithPrivateIP()}, nil
case "public":
return []cloudsqlconn.DialOption{cloudsqlconn.WithPublicIP()}, nil
default:
return nil, fmt.Errorf("invalid ipType %s", ipType)
}
}
// Copied over from cloud_sql_mssql.go
func initCloudSQLMssqlConnection(project, region, instance, ipAddress, ipType, user, pass, dbname string) (*sql.DB, error) {
// Create dsn
dsn := fmt.Sprintf("sqlserver://%s:%s@%s?database=%s&cloudsql=%s:%s:%s", user, pass, ipAddress, dbname, project, region, instance)
// Get dial options
dialOpts, err := getDialOpts(ipType)
if err != nil {
return nil, err
}
// Register sql server driver
if !slices.Contains(sql.Drivers(), "cloudsql-sqlserver-driver") {
_, err := mssql.RegisterDriver("cloudsql-sqlserver-driver", cloudsqlconn.WithDefaultDialOptions(dialOpts...))
if err != nil {
return nil, err
}
}
// Open database connection
db, err := sql.Open(
"cloudsql-sqlserver-driver",
dsn,
)
if err != nil {
return nil, err
}
return db, nil
}
func TestCloudSQLMssql(t *testing.T) {
sourceConfig := requireCloudSQLMssqlVars(t)
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
var args []string
// Write config into a file and pass it to command
toolsFile := map[string]any{
"sources": map[string]any{
"my-instance": sourceConfig,
},
"tools": map[string]any{
"my-simple-tool": map[string]any{
"kind": "mssql",
"source": "my-instance",
"description": "Simple tool to test end to end functionality.",
"statement": "SELECT 1;",
},
},
}
cmd, cleanup, err := StartCmd(ctx, toolsFile, args...)
if err != nil {
t.Fatalf("command initialization returned an error: %s", err)
}
defer cleanup()
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
if err != nil {
t.Logf("toolbox command logs: \n%s", out)
t.Fatalf("toolbox didn't start successfully: %s", err)
}
// Test tool get endpoint
tcs := []struct {
name string
api string
want map[string]any
}{
{
name: "get my-simple-tool",
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/",
want: map[string]any{
"my-simple-tool": map[string]any{
"description": "Simple tool to test end to end functionality.",
"parameters": []any{},
},
},
},
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
resp, err := http.Get(tc.api)
if err != nil {
t.Fatalf("error when sending a request: %s", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
}
var body map[string]interface{}
err = json.NewDecoder(resp.Body).Decode(&body)
if err != nil {
t.Fatalf("error parsing response body")
}
got, ok := body["tools"]
if !ok {
t.Fatalf("unable to find tools in response body")
}
if !reflect.DeepEqual(got, tc.want) {
t.Fatalf("got %q, want %q", got, tc.want)
}
})
}
// Test tool invoke endpoint
invokeTcs := []struct {
name string
api string
requestBody io.Reader
want string
}{
{
name: "invoke my-simple-tool",
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/invoke",
requestBody: bytes.NewBuffer([]byte(`{}`)),
want: "Stub tool call for \"my-simple-tool\"! Parameters parsed: [] \n Output: [%!s(int64=1)]",
},
}
for _, tc := range invokeTcs {
t.Run(tc.name, func(t *testing.T) {
resp, err := http.Post(tc.api, "application/json", tc.requestBody)
if err != nil {
t.Fatalf("error when sending a request: %s", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
bodyBytes, _ := io.ReadAll(resp.Body)
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
}
var body map[string]interface{}
err = json.NewDecoder(resp.Body).Decode(&body)
if err != nil {
t.Fatalf("error parsing response body")
}
got, ok := body["result"].(string)
if !ok {
t.Fatalf("unable to find result in response body")
}
if got != tc.want {
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
}
})
}
}
// Set up tool calling with parameters test table
func setupParamTest(t *testing.T, tableName string) (func(*testing.T), error) {
// Set up Tool invocation with parameters test
db, err := initCloudSQLMssqlConnection(CLOUD_SQL_MSSQL_PROJECT, CLOUD_SQL_MSSQL_REGION, CLOUD_SQL_MSSQL_INSTANCE, CLOUD_SQL_MSSQL_IP, "public", CLOUD_SQL_MSSQL_USER, CLOUD_SQL_MSSQL_PASS, CLOUD_SQL_MSSQL_DATABASE)
if err != nil {
return nil, err
}
err = db.Ping()
if err != nil {
return nil, err
}
_, err = db.Query(fmt.Sprintf(`
CREATE TABLE %s (
id INT IDENTITY(1,1) PRIMARY KEY,
name VARCHAR(255),
);
`, tableName))
if err != nil {
return nil, err
}
// Insert test data
statement := fmt.Sprintf(`
INSERT INTO %s (name)
VALUES (@alice), (@jane), (@sid);
`, tableName)
params := []any{sql.Named("alice", "Alice"), sql.Named("jane", "Jane"), sql.Named("sid", "Sid")}
_, err = db.Query(statement, params...)
if err != nil {
return nil, err
}
return func(t *testing.T) {
// tear down test
_, err := db.Exec(fmt.Sprintf(`DROP TABLE %s;`, tableName))
if err != nil {
t.Errorf("Teardown failed: %s", err)
}
}, nil
}
func TestToolInvocationWithParams(t *testing.T) {
// create source config
sourceConfig := requireCloudSQLMssqlVars(t)
// create table name with UUID
tableName := "param_test_table_" + strings.Replace(uuid.New().String(), "-", "", -1)
// test setup function reterns teardown function
teardownTest, err := setupParamTest(t, tableName)
if err != nil {
t.Fatalf("Unable to set up auth test: %s", err)
}
defer teardownTest(t)
// call generic invocation test helper
RunToolInvocationWithParamsTest(t, sourceConfig, "mssql", tableName)
}
// Set up auth test database table
func setupCloudSQLMssqlAuthTest(t *testing.T, tableName string) (func(*testing.T), error) {
// set up testt
db, err := initCloudSQLMssqlConnection(CLOUD_SQL_MSSQL_PROJECT, CLOUD_SQL_MSSQL_REGION, CLOUD_SQL_MSSQL_INSTANCE, CLOUD_SQL_MSSQL_IP, "public", CLOUD_SQL_MSSQL_USER, CLOUD_SQL_MSSQL_PASS, CLOUD_SQL_MSSQL_DATABASE)
if err != nil {
return nil, err
}
err = db.Ping()
if err != nil {
return nil, err
}
_, err = db.Query(fmt.Sprintf(`
CREATE TABLE %s (
id INT IDENTITY(1,1) PRIMARY KEY,
name VARCHAR(255),
email VARCHAR(255)
);
`, tableName))
if err != nil {
return nil, err
}
// Insert test data
statement := fmt.Sprintf(`
INSERT INTO %s (name, email)
VALUES (@alice, @aliceemail), (@jane, @janeemail);
`, tableName)
params := []any{sql.Named("alice", "Alice"), sql.Named("aliceemail", SERVICE_ACCOUNT_EMAIL), sql.Named("jane", "Jane"), sql.Named("janeemail", "janedoe@gmail.com")}
_, err = db.Query(statement, params...)
if err != nil {
return nil, err
}
return func(t *testing.T) {
// tear down test
_, err := db.Exec(fmt.Sprintf(`DROP TABLE %s;`, tableName))
if err != nil {
t.Errorf("Teardown failed: %s", err)
}
}, nil
}
func TestCloudSQLMssqlGoogleAuthenticatedParameter(t *testing.T) {
// create test configs
sourceConfig := requireCloudSQLMssqlVars(t)
// create table name with UUID
tableName := "auth_table_" + strings.Replace(uuid.New().String(), "-", "", -1)
// test setup function reterns teardown function
teardownTest, err := setupCloudSQLMssqlAuthTest(t, tableName)
if err != nil {
t.Fatalf("Unable to set up auth test: %s", err)
}
defer teardownTest(t)
// call generic auth test helper
RunGoogleAuthenticatedParameterTest(t, sourceConfig, "mssql", tableName)
}
func TestCloudSQLMssqlAuthRequiredToolInvocation(t *testing.T) {
// create test configs
sourceConfig := requireCloudSQLMssqlVars(t)
// call generic auth test helper
RunAuthRequiredToolInvocationTest(t, sourceConfig, "mssql")
}

View File

@@ -242,7 +242,7 @@ func TestCloudSQLPostgres(t *testing.T) {
}
// Set up auth test database table
func setupAuthTest(t *testing.T, ctx context.Context, tableName string) func(*testing.T) {
func setupCloudSQLPgAuthTest(t *testing.T, ctx context.Context, tableName string) func(*testing.T) {
// set up testt
pool, err := initCloudSQLPgConnectionPool(CLOUD_SQL_POSTGRES_PROJECT, CLOUD_SQL_POSTGRES_REGION, CLOUD_SQL_POSTGRES_INSTANCE, "public", CLOUD_SQL_POSTGRES_USER, CLOUD_SQL_POSTGRES_PASS, CLOUD_SQL_POSTGRES_DATABASE)
if err != nil {
@@ -285,7 +285,7 @@ func setupAuthTest(t *testing.T, ctx context.Context, tableName string) func(*te
}
}
func TestCloudSQLGoogleAuthenticatedParameter(t *testing.T) {
func TestCloudSQLPgGoogleAuthenticatedParameter(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
@@ -296,7 +296,7 @@ func TestCloudSQLGoogleAuthenticatedParameter(t *testing.T) {
tableName := "auth_table_" + strings.Replace(uuid.New().String(), "-", "", -1)
// test setup function reterns teardown function
teardownTest := setupAuthTest(t, ctx, tableName)
teardownTest := setupCloudSQLPgAuthTest(t, ctx, tableName)
defer teardownTest(t)
// call generic auth test helper

View File

@@ -210,15 +210,27 @@ func (c *CmdExec) WaitForString(ctx context.Context, re *regexp.Regexp) (string,
}
func RunToolInvocationWithParamsTest(t *testing.T, sourceConfig map[string]any, toolKind string, tableName string) {
// Write config into a file and pass it to command
// Specify query statement for different tool kinds
var statement string
switch toolKind {
case "postgres-sql":
statement = fmt.Sprintf("SELECT * FROM %s WHERE id = $1 OR name = $2;", tableName)
case "mssql":
statement = fmt.Sprintf("SELECT * FROM %s WHERE id = @id OR name = @p2;", tableName)
default:
t.Fatalf("invalid tool kind: %s", toolKind)
}
// Tools using database/sql interface only outputs `int64` instead of `int32`
var wantString string
switch toolKind {
case "mssql":
wantString = "Stub tool call for \"my-tool\"! Parameters parsed: [{\"id\" '\\x03'} {\"name\" \"Alice\"}] \n Output: [%!s(int64=1) Alice][%!s(int64=3) Sid]"
default:
wantString = "Stub tool call for \"my-tool\"! Parameters parsed: [{\"id\" '\\x03'} {\"name\" \"Alice\"}] \n Output: [%!s(int32=1) Alice][%!s(int32=3) Sid]"
}
// Write config into a file and pass it to command
toolsFile := map[string]any{
"sources": map[string]any{
"my-instance": sourceConfig,
@@ -280,7 +292,7 @@ func RunToolInvocationWithParamsTest(t *testing.T, sourceConfig map[string]any,
api: "http://127.0.0.1:5000/api/tool/my-tool/invoke",
requestBody: bytes.NewBuffer([]byte(`{"id": 3, "name": "Alice"}`)),
isErr: false,
want: "Stub tool call for \"my-tool\"! Parameters parsed: [{\"id\" '\\x03'} {\"name\" \"Alice\"}] \n Output: [%!s(int32=1) Alice][%!s(int32=3) Sid]",
want: wantString,
},
{
name: "Invoke my-tool without parameters",