mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-12 00:49:08 -05:00
feat: Add Cloud SQL for SQL Server Source and Tool (#223)
1. `sql/database` provides a `Scan()`interface to scan query results into typed variables. Therefore we have to create a slice of typed variables (types retrieved from rows.ColumnTypes()) to pass them into `Scan()`. Using []byte works but makes the printing result different from other tools (e.g [1] instead of %!s(int32=1)] 2. MS SQL supports both named (e.g @name) and positional args (e.g @p2), so we have to check if the name is contained in the original statement before passing them into `db.Query()` as either named arg or as values.
This commit is contained in:
@@ -118,6 +118,27 @@ steps:
|
||||
- |
|
||||
go test -race -v -tags=integration,neo4j ./tests
|
||||
|
||||
- id: "cloud-sql-mssql"
|
||||
name: golang:1
|
||||
waitFor: ["install-dependencies"]
|
||||
entrypoint: /bin/bash
|
||||
env:
|
||||
- "GOPATH=/gopath"
|
||||
- "CLOUD_SQL_MSSQL_PROJECT=$PROJECT_ID"
|
||||
- "CLOUD_SQL_MSSQL_INSTANCE=$_CLOUD_SQL_MSSQL_INSTANCE"
|
||||
- "CLOUD_SQL_MSSQL_IP=$_CLOUD_SQL_MSSQL_IP"
|
||||
- "CLOUD_SQL_MSSQL_DATABASE=$_DATABASE_NAME"
|
||||
- "CLOUD_SQL_MSSQL_REGION=$_REGION"
|
||||
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
|
||||
secretEnv: ["CLOUD_SQL_MSSQL_USER", "CLOUD_SQL_MSSQL_PASS", "CLIENT_ID"]
|
||||
volumes:
|
||||
- name: "go"
|
||||
path: "/gopath"
|
||||
args:
|
||||
- -c
|
||||
- |
|
||||
go test -race -v -tags=integration,cloudsqlmssql ./tests
|
||||
|
||||
availableSecrets:
|
||||
secretManager:
|
||||
- versionName: projects/$PROJECT_ID/secrets/cloud_sql_pg_user/versions/latest
|
||||
@@ -138,6 +159,10 @@ availableSecrets:
|
||||
env: NEO4J_USER
|
||||
- versionName: projects/$PROJECT_ID/secrets/neo4j_pass/versions/latest
|
||||
env: NEO4J_PASS
|
||||
- versionName: projects/$PROJECT_ID/secrets/cloud_sql_mssql_user/versions/latest
|
||||
env: CLOUD_SQL_MSSQL_USER
|
||||
- versionName: projects/$PROJECT_ID/secrets/cloud_sql_mssql_pass/versions/latest
|
||||
env: CLOUD_SQL_MSSQL_PASS
|
||||
|
||||
options:
|
||||
logging: CLOUD_LOGGING_ONLY
|
||||
@@ -157,3 +182,4 @@ substitutions:
|
||||
_POSTGRES_PORT: "5432"
|
||||
_SPANNER_INSTANCE: "spanner-testing"
|
||||
_NEO4J_DATABASE: "neo4j"
|
||||
_CLOUD_SQL_MSSQL_INSTANCE: "cloud-sql-mssql-testing"
|
||||
|
||||
73
docs/en/resources/sources/cloud-sql-mssql.md
Normal file
73
docs/en/resources/sources/cloud-sql-mssql.md
Normal file
@@ -0,0 +1,73 @@
|
||||
# Cloud SQL for SQL Server Source
|
||||
|
||||
[Cloud SQL for SQL Server][csql-mssql-docs] is a managed database service that
|
||||
helps you set up, maintain, manage, and administer your SQL Server databases on
|
||||
Google Cloud.
|
||||
|
||||
If you are new to Cloud SQL for SQL Server, you can try [creating and connecting
|
||||
to a database by following these instructions][csql-mssql-connect].
|
||||
|
||||
[csql-mssql-docs]: https://cloud.google.com/sql/docs/sqlserver
|
||||
[csql-mssql-connect]: https://cloud.google.com/sql/docs/sqlserver/connect-overview
|
||||
|
||||
## Requirements
|
||||
|
||||
### IAM Identity
|
||||
|
||||
By default, this source uses the [Cloud SQL Go Connector][csql-go-conn] to
|
||||
authorize and establish mTLS connections to your Cloud SQL instance. The Go
|
||||
connector uses your [Application Default Credentials (ADC)][adc] to authorize
|
||||
your connection to Cloud SQL.
|
||||
|
||||
In addition to [setting the ADC for your server][set-adc], you need to ensure the
|
||||
IAM identity has been given the following IAM roles:
|
||||
|
||||
- `roles/cloudsql.client`
|
||||
|
||||
[csql-go-conn]: https://github.com/GoogleCloudPlatform/cloud-sql-go-connector
|
||||
[adc]: https://cloud.google.com/docs/authentication#adc
|
||||
[set-adc]: https://cloud.google.com/docs/authentication/provide-credentials-adc
|
||||
|
||||
### Network Path
|
||||
|
||||
Currently, Cloud SQL for SQL Server supports connection over both [private IP][private-ip] and
|
||||
[public IP][public-ip]. Set the `ipType` parameter in your source
|
||||
configuration to `public` or `private`.
|
||||
|
||||
[private-ip]: https://cloud.google.com/sql/docs/sqlserver/configure-private-ip
|
||||
[public-ip]: https://cloud.google.com/sql/docs/sqlserver/configure-ip
|
||||
|
||||
### Database User
|
||||
|
||||
Currently, this source only uses standard authentication. You will need to [create a
|
||||
SQL Server user][cloud-sql-users] to login to the database with.
|
||||
|
||||
[cloud-sql-users]: https://cloud.google.com/sql/docs/sqlserver/create-manage-users
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
sources:
|
||||
my-cloud-sql-mssql-instance:
|
||||
kind: cloud-sql-mssql
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
ipAddress: localhost
|
||||
ipType: public
|
||||
```
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|-----------|:--------:|:------------:|------------------------------------------------------------------------------|
|
||||
| kind | string | true | Must be "cloud-sql-postgres". |
|
||||
| project | string | true | Id of the GCP project that the cluster was created in (e.g. "my-project-id"). |
|
||||
| region | string | true | Name of the GCP region that the cluster was created in (e.g. "us-central1"). |
|
||||
| instance | string | true | Name of the Cloud SQL instance within the cluser (e.g. "my-instance"). |
|
||||
| database | string | true | Name of the Cloud SQL database to connect to (e.g. "my_db"). |
|
||||
| ipAddress | string | true | IP address of the Cloud SQL instance to connect to.|
|
||||
| ipType | string | true | IP Type of the Cloud SQL instance, must be either `public` or `private`. Default: `public`. |
|
||||
| user | string | true | Name of the Postgres user to connect as (e.g. "my-pg-user"). |
|
||||
| password | string | true | Password of the Postgres user (e.g. "my-password").
|
||||
66
docs/tools/mssql.md
Normal file
66
docs/tools/mssql.md
Normal file
@@ -0,0 +1,66 @@
|
||||
# Cloud SQL Mssql Tool
|
||||
|
||||
A "mssql" tool executes a pre-defined SQL statement against a Cloud SQL for SQL Server
|
||||
database. It's compatible with any of the following sources:
|
||||
|
||||
- [cloud-sql-mssql](../sources/cloud-sql-mssql.md)
|
||||
|
||||
Toolbox supports the [prepare statement syntax][prepare-statement] of MS SQL
|
||||
Server and expects parameters in the SQL query to be in the form of either @Name
|
||||
or @p1 to @pN (ordinal position).
|
||||
|
||||
```sql
|
||||
db.QueryContext(ctx, `select * from t where ID = @ID and Name = @p2;`, sql.Named("ID", 6), "Bob")
|
||||
```
|
||||
|
||||
[prepare-statement]: https://learn.microsoft.com/sql/relational-databases/system-stored-procedures/sp-prepare-transact-sql?view=sql-server-ver16
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
search_flights_by_number:
|
||||
kind: mssql
|
||||
source: my-instance
|
||||
statement: |
|
||||
SELECT * FROM flights
|
||||
WHERE airline = @airline
|
||||
AND flight_number = @number
|
||||
LIMIT 10
|
||||
description: |
|
||||
Use this tool to get information for a specific flight.
|
||||
Takes an airline code and flight number and returns info on the flight.
|
||||
Do NOT use this tool with a flight id. Do NOT guess an airline code or flight number.
|
||||
A airline code is a code for an airline service consisting of two-character
|
||||
airline designator and followed by flight number, which is 1 to 4 digit number.
|
||||
For example, if given CY 0123, the airline is "CY", and flight_number is "123".
|
||||
Another example for this is DL 1234, the airline is "DL", and flight_number is "1234".
|
||||
If the tool returns more than one option choose the date closes to today.
|
||||
Example:
|
||||
{{
|
||||
"airline": "CY",
|
||||
"flight_number": "888",
|
||||
}}
|
||||
Example:
|
||||
{{
|
||||
"airline": "DL",
|
||||
"flight_number": "1234",
|
||||
}}
|
||||
parameters:
|
||||
- name: airline
|
||||
type: string
|
||||
description: Airline unique 2 letter identifier
|
||||
- name: number
|
||||
type: string
|
||||
description: 1 to 4 digit number
|
||||
```
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|-------------|:--------------------------------------------:|:------------:|-----------------------------------------------------------------------------------------------------|
|
||||
| kind | string | true | Must be "mssql". |
|
||||
| source | string | true | Name of the source the T-SQL statement should execute on.|
|
||||
| description | string | true | Description of the tool that is passed to the LLM|
|
||||
| statement | string | true | SQL statement to execute. |
|
||||
| parameters | [parameter](README.md#specifying-parameters) | true | List of [parameters](README.md#specifying-parameters) that will be inserted into the SQL statement. |
|
||||
3
go.mod
3
go.mod
@@ -56,6 +56,8 @@ require (
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/go-logr/logr v1.4.2 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
|
||||
github.com/golang-sql/sqlexp v0.1.0 // indirect
|
||||
github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect
|
||||
github.com/google/s2a-go v0.1.8 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect
|
||||
@@ -65,6 +67,7 @@ require (
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||
github.com/microsoft/go-mssqldb v1.8.0 // indirect
|
||||
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect
|
||||
github.com/spf13/pflag v1.0.5 // indirect
|
||||
go.opencensus.io v0.24.0 // indirect
|
||||
|
||||
18
go.sum
18
go.sum
@@ -625,6 +625,18 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
|
||||
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
||||
gioui.org v0.0.0-20210308172011-57750fc8a0a6/go.mod h1:RSH6KIUZ0p2xy5zHDxgAM4zumjgTw83q2ge/PI+yyw8=
|
||||
git.sr.ht/~sbinet/gg v0.3.1/go.mod h1:KGYtlADtqsqANL9ueOFkWymvzUvLMQllU5Ixo+8v3pc=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1 h1:E+OJmp2tPvt1W+amx48v1eqbjDYsgN+RzP4q16yV5eM=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1/go.mod h1:a6xsAQUZg+VsS3TJ05SRp524Hs4pZ/AeFSr5ENf0Yjo=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.6.0 h1:U2rTu3Ef+7w9FHKIAXM6ZyqF3UOWJZ12zIm8zECAFfg=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.6.0/go.mod h1:9kIvujWAA58nmPmWB1m23fyWic1kYZMxD9CxaWn4Qpg=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.8.0 h1:jBQA3cKT4L2rWMpgE7Yt3Hwh2aUj8KXjIGLxjHeYNNo=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.8.0/go.mod h1:4OG6tQ9EOP/MT0NMjDlRzWoVFxfu9rN9B2X+tlSVktg=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.0.1 h1:MyVTgWR8qd/Jw1Le0NZebGBUCLbtak3bJ3z1OlqZBpw=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.0.1/go.mod h1:GpPjLhVR9dnUoJMyHWSPy71xY9/lcmpzIPZXmF0FCVY=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.0.0 h1:D3occbWoio4EBLkbkevetNMAVX197GkzbUMtqjGWn80=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.0.0/go.mod h1:bTSOgj05NGRuHHhQwAdPnYr9TOdNmKlZTgGLL6nyAdI=
|
||||
github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 h1:XHOnouVk1mxXfQidrMEnLlPk9UMeRtyBTnEFtxkV0kU=
|
||||
github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI=
|
||||
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
||||
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
|
||||
github.com/GoogleCloudPlatform/grpc-gcp-go/grpcgcp v1.5.2 h1:DBjmt6/otSdULyJdVg2BlG0qGZO5tKL4VzOs0jpvw5Q=
|
||||
@@ -743,6 +755,8 @@ github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqw
|
||||
github.com/goccy/go-json v0.9.11/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||
github.com/goccy/go-yaml v1.15.13 h1:Xd87Yddmr2rC1SLLTm2MNDcTjeO/GYo0JGiww6gSTDg=
|
||||
github.com/goccy/go-yaml v1.15.13/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA=
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0=
|
||||
github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A=
|
||||
@@ -910,6 +924,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/lyft/protoc-gen-star v0.6.0/go.mod h1:TGAoBVkt8w7MPG72TrKIu85MIdXwDuzJYeZuUPFPNwA=
|
||||
github.com/lyft/protoc-gen-star v0.6.1/go.mod h1:TGAoBVkt8w7MPG72TrKIu85MIdXwDuzJYeZuUPFPNwA=
|
||||
github.com/lyft/protoc-gen-star/v2 v2.0.1/go.mod h1:RcCdONR2ScXaYnQC5tUzxzlpA3WVYF7/opLeUgcQs/o=
|
||||
@@ -926,6 +942,8 @@ github.com/phpdave11/gofpdf v1.4.2/go.mod h1:zpO6xFn9yxo3YLyMvW8HcKWVdbNqgIfOOp2
|
||||
github.com/phpdave11/gofpdi v1.0.12/go.mod h1:vBmVV0Do6hSBHC8uKUQ71JGW+ZGQq74llk/7bXwjDoI=
|
||||
github.com/phpdave11/gofpdi v1.0.13/go.mod h1:vBmVV0Do6hSBHC8uKUQ71JGW+ZGQq74llk/7bXwjDoI=
|
||||
github.com/pierrec/lz4/v4 v4.1.15/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
|
||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
|
||||
@@ -22,12 +22,14 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/auth/google"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
alloydbpgsrc "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg"
|
||||
cloudsqlmssqlsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmssql"
|
||||
cloudsqlmysqlsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmysql"
|
||||
cloudsqlpgsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg"
|
||||
neo4jrc "github.com/googleapis/genai-toolbox/internal/sources/neo4j"
|
||||
postgressrc "github.com/googleapis/genai-toolbox/internal/sources/postgres"
|
||||
spannersrc "github.com/googleapis/genai-toolbox/internal/sources/spanner"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/mssql"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/mysql"
|
||||
neo4jtool "github.com/googleapis/genai-toolbox/internal/tools/neo4j"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/postgressql"
|
||||
@@ -173,6 +175,12 @@ func (c *SourceConfigs) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
|
||||
}
|
||||
(*c)[name] = actual
|
||||
case cloudsqlmssqlsrc.SourceKind:
|
||||
actual := cloudsqlmssqlsrc.Config{Name: name}
|
||||
if err := u.Unmarshal(&actual); err != nil {
|
||||
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
|
||||
}
|
||||
(*c)[name] = actual
|
||||
default:
|
||||
return fmt.Errorf("%q is not a valid kind of data source", k.Kind)
|
||||
}
|
||||
@@ -264,6 +272,12 @@ func (c *ToolConfigs) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
|
||||
}
|
||||
(*c)[name] = actual
|
||||
case mssql.ToolKind:
|
||||
actual := mssql.Config{Name: name}
|
||||
if err := u.Unmarshal(&actual); err != nil {
|
||||
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
|
||||
}
|
||||
(*c)[name] = actual
|
||||
default:
|
||||
return fmt.Errorf("%q is not a valid kind of tool", k.Kind)
|
||||
}
|
||||
|
||||
136
internal/sources/cloudsqlmssql/cloud_sql_mssql.go
Normal file
136
internal/sources/cloudsqlmssql/cloud_sql_mssql.go
Normal file
@@ -0,0 +1,136 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cloudsqlmssql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"cloud.google.com/go/cloudsqlconn"
|
||||
"cloud.google.com/go/cloudsqlconn/sqlserver/mssql"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const SourceKind string = "cloud-sql-mssql"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
type Config struct {
|
||||
// Cloud SQL MSSQL configs
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Project string `yaml:"project"`
|
||||
Region string `yaml:"region"`
|
||||
Instance string `yaml:"instance"`
|
||||
IPAddress string `yaml:"ipAddress"`
|
||||
IPType string `yaml:"ipType"`
|
||||
User string `yaml:"user"`
|
||||
Password string `yaml:"password"`
|
||||
Database string `yaml:"database"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
// Returns Cloud SQL MSSQL source kind
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
// Initializes a Cloud SQL MSSQL source
|
||||
db, err := initCloudSQLMssqlConnection(ctx, tracer, r.Name, r.Project, r.Region, r.Instance, r.IPAddress, r.IPType, r.User, r.Password, r.Database)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create db connection: %w", err)
|
||||
}
|
||||
|
||||
// Verify db connection
|
||||
err = db.Ping()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to connect successfully: %w", err)
|
||||
}
|
||||
|
||||
s := &Source{
|
||||
Name: r.Name,
|
||||
Kind: SourceKind,
|
||||
Db: db,
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
var _ sources.Source = &Source{}
|
||||
|
||||
type Source struct {
|
||||
// Cloud SQL MSSQL struct with connection pool
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Db *sql.DB
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
// Returns Cloud SQL MSSQL source kind
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) MSSQLDB() *sql.DB {
|
||||
// Returns a Cloud SQL MSSQL database connection pool
|
||||
return s.Db
|
||||
}
|
||||
|
||||
func getDialOpts(ipType string) ([]cloudsqlconn.DialOption, error) {
|
||||
switch strings.ToLower(ipType) {
|
||||
case "private":
|
||||
return []cloudsqlconn.DialOption{cloudsqlconn.WithPrivateIP()}, nil
|
||||
case "public":
|
||||
return []cloudsqlconn.DialOption{cloudsqlconn.WithPublicIP()}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid ipType %s", ipType)
|
||||
}
|
||||
}
|
||||
|
||||
func initCloudSQLMssqlConnection(ctx context.Context, tracer trace.Tracer, name, project, region, instance, ipAddress, ipType, user, pass, dbname string) (*sql.DB, error) {
|
||||
//nolint:all // Reassigned ctx
|
||||
ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name)
|
||||
defer span.End()
|
||||
|
||||
// Create dsn
|
||||
dsn := fmt.Sprintf("sqlserver://%s:%s@%s?database=%s&cloudsql=%s:%s:%s", user, pass, ipAddress, dbname, project, region, instance)
|
||||
|
||||
// Get dial options
|
||||
dialOpts, err := getDialOpts(ipType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Register sql server driver
|
||||
if !slices.Contains(sql.Drivers(), "cloudsql-sqlserver-driver") {
|
||||
_, err := mssql.RegisterDriver("cloudsql-sqlserver-driver", cloudsqlconn.WithDefaultDialOptions(dialOpts...))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Open database connection
|
||||
db, err := sql.Open(
|
||||
"cloudsql-sqlserver-driver",
|
||||
dsn,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return db, nil
|
||||
}
|
||||
76
internal/sources/cloudsqlmssql/cloud_sql_mssql_test.go
Normal file
76
internal/sources/cloudsqlmssql/cloud_sql_mssql_test.go
Normal file
@@ -0,0 +1,76 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cloudsqlmssql_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmssql"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func TestParseFromYamlCloudSQLMssql(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.SourceConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
sources:
|
||||
my-instance:
|
||||
kind: cloud-sql-mssql
|
||||
project: my-project
|
||||
region: my-region
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
ipAddress: localhost
|
||||
ipType: public
|
||||
`,
|
||||
want: server.SourceConfigs{
|
||||
"my-instance": cloudsqlmssql.Config{
|
||||
Name: "my-instance",
|
||||
Kind: cloudsqlmssql.SourceKind,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Instance: "my-instance",
|
||||
IPAddress: "localhost",
|
||||
IPType: "public",
|
||||
Database: "my_db",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect psarse: want %v, got %v", tc.want, got.Sources)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
168
internal/tools/mssql/mssql.go
Normal file
168
internal/tools/mssql/mssql.go
Normal file
@@ -0,0 +1,168 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mssql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmssql"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
)
|
||||
|
||||
const ToolKind string = "mssql"
|
||||
|
||||
type compatibleSource interface {
|
||||
MSSQLDB() *sql.DB
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &cloudsqlmssql.Source{}
|
||||
|
||||
var compatibleSources = [...]string{cloudsqlmssql.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Source string `yaml:"source"`
|
||||
Description string `yaml:"description"`
|
||||
Statement string `yaml:"statement"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
func (cfg Config) ToolConfigKind() string {
|
||||
return ToolKind
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", ToolKind, compatibleSources)
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: ToolKind,
|
||||
Parameters: cfg.Parameters,
|
||||
Statement: cfg.Statement,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Db: s.MSSQLDB(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()},
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func NewGenericTool(name string, stmt string, authRequired []string, desc string, Db *sql.DB, parameters tools.Parameters) Tool {
|
||||
return Tool{
|
||||
Name: name,
|
||||
Kind: ToolKind,
|
||||
Statement: stmt,
|
||||
AuthRequired: authRequired,
|
||||
Db: Db,
|
||||
manifest: tools.Manifest{Description: desc, Parameters: parameters.Manifest()},
|
||||
Parameters: parameters,
|
||||
}
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
|
||||
Db *sql.DB
|
||||
Statement string
|
||||
manifest tools.Manifest
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(params tools.ParamValues) (string, error) {
|
||||
fmt.Printf("Invoked tool %s\n", t.Name)
|
||||
|
||||
namedArgs := make([]any, 0, len(params))
|
||||
paramsMap := params.AsReversedMap()
|
||||
// To support both named args (e.g @id) and positional args (e.g @p1), check if arg name is contained in the statement.
|
||||
for _, v := range params.AsSlice() {
|
||||
paramName := paramsMap[v]
|
||||
if strings.Contains(t.Statement, "@"+paramName) {
|
||||
namedArgs = append(namedArgs, sql.Named(paramName, v))
|
||||
} else {
|
||||
namedArgs = append(namedArgs, v)
|
||||
}
|
||||
}
|
||||
rows, err := t.Db.Query(t.Statement, namedArgs...)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
|
||||
types, err := rows.ColumnTypes()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("unable to fetch column types: %w", err)
|
||||
}
|
||||
v := make([]any, len(types))
|
||||
pointers := make([]any, len(types))
|
||||
for i := range types {
|
||||
pointers[i] = &v[i]
|
||||
}
|
||||
|
||||
// fetch result into a string
|
||||
var out strings.Builder
|
||||
|
||||
for rows.Next() {
|
||||
err = rows.Scan(pointers...)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("unable to parse row: %w", err)
|
||||
}
|
||||
out.WriteString(fmt.Sprintf("%s", v))
|
||||
}
|
||||
err = rows.Close()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("unable to close rows: %w", err)
|
||||
}
|
||||
|
||||
// Check if error occured during iteration
|
||||
if err := rows.Err(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fmt.Sprintf("Stub tool call for %q! Parameters parsed: %q \n Output: %s", t.Name, params, out.String()), nil
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
return tools.ParseParams(t.Parameters, data, claims)
|
||||
}
|
||||
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthSources []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthSources)
|
||||
}
|
||||
90
internal/tools/mssql/mssql_test.go
Normal file
90
internal/tools/mssql/mssql_test.go
Normal file
@@ -0,0 +1,90 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package mssql_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/mssql"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func TestParseFromYamlMssql(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: mssql
|
||||
source: my-instance
|
||||
description: some description
|
||||
statement: |
|
||||
SELECT * FROM SQL_STATEMENT;
|
||||
authRequired:
|
||||
- my-google-auth-service
|
||||
- other-auth-service
|
||||
parameters:
|
||||
- name: country
|
||||
type: string
|
||||
description: some description
|
||||
authSources:
|
||||
- name: my-google-auth-service
|
||||
field: user_id
|
||||
- name: other-auth-service
|
||||
field: user_id
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": mssql.Config{
|
||||
Name: "example_tool",
|
||||
Kind: mssql.ToolKind,
|
||||
Source: "my-instance",
|
||||
Description: "some description",
|
||||
Statement: "SELECT * FROM SQL_STATEMENT;\n",
|
||||
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
|
||||
Parameters: []tools.Parameter{
|
||||
tools.NewStringParameterWithAuth("country", "some description",
|
||||
[]tools.ParamAuthSource{{Name: "my-google-auth-service", Field: "user_id"},
|
||||
{Name: "other-auth-service", Field: "user_id"}}),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
@@ -57,6 +57,15 @@ func (p ParamValues) AsMap() map[string]interface{} {
|
||||
return params
|
||||
}
|
||||
|
||||
// AsReversedMap returns a map of ParamValue's values to names.
|
||||
func (p ParamValues) AsReversedMap() map[any]string {
|
||||
params := make(map[any]string)
|
||||
for _, p := range p {
|
||||
params[p.Value] = p.Name
|
||||
}
|
||||
return params
|
||||
}
|
||||
|
||||
// AsMapByOrderedKeys returns a map of a key's position to it's value, as neccesary for Spanner PSQL.
|
||||
// Example { $1 -> "value1", $2 -> "value2" }
|
||||
func (p ParamValues) AsMapByOrderedKeys() map[string]interface{} {
|
||||
|
||||
@@ -63,6 +63,8 @@ func RunGoogleAuthenticatedParameterTest(t *testing.T, sourceConfig map[string]a
|
||||
switch {
|
||||
case strings.EqualFold(toolKind, "postgres-sql"):
|
||||
statement = fmt.Sprintf("SELECT * FROM %s WHERE email = $1;", tableName)
|
||||
case strings.EqualFold(toolKind, "mssql"):
|
||||
statement = fmt.Sprintf("SELECT * FROM %s WHERE email = @email;", tableName)
|
||||
default:
|
||||
t.Fatalf("invalid tool kind: %s", toolKind)
|
||||
}
|
||||
@@ -127,8 +129,14 @@ func RunGoogleAuthenticatedParameterTest(t *testing.T, sourceConfig map[string]a
|
||||
t.Fatalf("error getting Google ID token: %s", err)
|
||||
}
|
||||
|
||||
// Create wanted string
|
||||
wantResult := fmt.Sprintf("Stub tool call for \"my-auth-tool\"! Parameters parsed: [{\"email\" \"%s\"}] \n Output: [%%!s(int32=1) Alice %s]", SERVICE_ACCOUNT_EMAIL, SERVICE_ACCOUNT_EMAIL)
|
||||
// Tools using database/sql interface only outputs `int64` instead of `int32`
|
||||
var wantString string
|
||||
switch toolKind {
|
||||
case "mssql":
|
||||
wantString = fmt.Sprintf("Stub tool call for \"my-auth-tool\"! Parameters parsed: [{\"email\" \"%s\"}] \n Output: [%%!s(int64=1) Alice %s]", SERVICE_ACCOUNT_EMAIL, SERVICE_ACCOUNT_EMAIL)
|
||||
default:
|
||||
wantString = fmt.Sprintf("Stub tool call for \"my-auth-tool\"! Parameters parsed: [{\"email\" \"%s\"}] \n Output: [%%!s(int32=1) Alice %s]", SERVICE_ACCOUNT_EMAIL, SERVICE_ACCOUNT_EMAIL)
|
||||
}
|
||||
|
||||
// Test tool invocation with authenticated parameters
|
||||
invokeTcs := []struct {
|
||||
@@ -145,7 +153,7 @@ func RunGoogleAuthenticatedParameterTest(t *testing.T, sourceConfig map[string]a
|
||||
requestHeader: map[string]string{"my-google-auth_token": idToken},
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
isErr: false,
|
||||
want: wantResult,
|
||||
want: wantString,
|
||||
},
|
||||
{
|
||||
name: "Invoke my-auth-tool with invalid auth token",
|
||||
@@ -205,6 +213,15 @@ func RunGoogleAuthenticatedParameterTest(t *testing.T, sourceConfig map[string]a
|
||||
}
|
||||
|
||||
func RunAuthRequiredToolInvocationTest(t *testing.T, sourceConfig map[string]any, toolKind string) {
|
||||
// Tools using database/sql interface only outputs `int64` instead of `int32`
|
||||
var wantString string
|
||||
switch toolKind {
|
||||
case "mssql":
|
||||
wantString = "Stub tool call for \"my-auth-tool\"! Parameters parsed: [] \n Output: [%!s(int64=1)]"
|
||||
default:
|
||||
wantString = "Stub tool call for \"my-auth-tool\"! Parameters parsed: [] \n Output: [%!s(int32=1)]"
|
||||
}
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := map[string]any{
|
||||
"sources": map[string]any{
|
||||
@@ -271,7 +288,7 @@ func RunAuthRequiredToolInvocationTest(t *testing.T, sourceConfig map[string]any
|
||||
requestHeader: map[string]string{"my-google-auth_token": idToken},
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
isErr: false,
|
||||
want: "Stub tool call for \"my-auth-tool\"! Parameters parsed: [] \n Output: [%!s(int32=1)]",
|
||||
want: wantString,
|
||||
},
|
||||
{
|
||||
name: "Invoke my-auth-tool with invalid auth token",
|
||||
|
||||
376
tests/cloud_sql_mssql_integration_test.go
Normal file
376
tests/cloud_sql_mssql_integration_test.go
Normal file
@@ -0,0 +1,376 @@
|
||||
//go:build integration && cloudsqlmssql
|
||||
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package tests
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"cloud.google.com/go/cloudsqlconn"
|
||||
"cloud.google.com/go/cloudsqlconn/sqlserver/mssql"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
var (
|
||||
CLOUD_SQL_MSSQL_PROJECT = os.Getenv("CLOUD_SQL_MSSQL_PROJECT")
|
||||
CLOUD_SQL_MSSQL_REGION = os.Getenv("CLOUD_SQL_MSSQL_REGION")
|
||||
CLOUD_SQL_MSSQL_INSTANCE = os.Getenv("CLOUD_SQL_MSSQL_INSTANCE")
|
||||
CLOUD_SQL_MSSQL_DATABASE = os.Getenv("CLOUD_SQL_MSSQL_DATABASE")
|
||||
CLOUD_SQL_MSSQL_IP = os.Getenv("CLOUD_SQL_MSSQL_IP")
|
||||
CLOUD_SQL_MSSQL_USER = os.Getenv("CLOUD_SQL_MSSQL_USER")
|
||||
CLOUD_SQL_MSSQL_PASS = os.Getenv("CLOUD_SQL_MSSQL_PASS")
|
||||
)
|
||||
|
||||
func requireCloudSQLMssqlVars(t *testing.T) map[string]any {
|
||||
switch "" {
|
||||
case CLOUD_SQL_MSSQL_PROJECT:
|
||||
t.Fatal("'CLOUD_SQL_MSSQL_PROJECT' not set")
|
||||
case CLOUD_SQL_MSSQL_REGION:
|
||||
t.Fatal("'CLOUD_SQL_MSSQL_REGION' not set")
|
||||
case CLOUD_SQL_MSSQL_INSTANCE:
|
||||
t.Fatal("'CLOUD_SQL_MSSQL_INSTANCE' not set")
|
||||
case CLOUD_SQL_MSSQL_IP:
|
||||
t.Fatal("'CLOUD_SQL_MSSQL_IP' not set")
|
||||
case CLOUD_SQL_MSSQL_DATABASE:
|
||||
t.Fatal("'CLOUD_SQL_MSSQL_DATABASE' not set")
|
||||
case CLOUD_SQL_MSSQL_USER:
|
||||
t.Fatal("'CLOUD_SQL_MSSQL_USER' not set")
|
||||
case CLOUD_SQL_MSSQL_PASS:
|
||||
t.Fatal("'CLOUD_SQL_MSSQL_PASS' not set")
|
||||
}
|
||||
|
||||
return map[string]any{
|
||||
"kind": "cloud-sql-mssql",
|
||||
"project": CLOUD_SQL_MSSQL_PROJECT,
|
||||
"instance": CLOUD_SQL_MSSQL_INSTANCE,
|
||||
"ipType": "public",
|
||||
"ipAddress": CLOUD_SQL_MSSQL_IP,
|
||||
"region": CLOUD_SQL_MSSQL_REGION,
|
||||
"database": CLOUD_SQL_MSSQL_DATABASE,
|
||||
"user": CLOUD_SQL_MSSQL_USER,
|
||||
"password": CLOUD_SQL_MSSQL_PASS,
|
||||
}
|
||||
}
|
||||
|
||||
func getDialOpts(ipType string) ([]cloudsqlconn.DialOption, error) {
|
||||
switch strings.ToLower(ipType) {
|
||||
case "private":
|
||||
return []cloudsqlconn.DialOption{cloudsqlconn.WithPrivateIP()}, nil
|
||||
case "public":
|
||||
return []cloudsqlconn.DialOption{cloudsqlconn.WithPublicIP()}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid ipType %s", ipType)
|
||||
}
|
||||
}
|
||||
|
||||
// Copied over from cloud_sql_mssql.go
|
||||
func initCloudSQLMssqlConnection(project, region, instance, ipAddress, ipType, user, pass, dbname string) (*sql.DB, error) {
|
||||
// Create dsn
|
||||
dsn := fmt.Sprintf("sqlserver://%s:%s@%s?database=%s&cloudsql=%s:%s:%s", user, pass, ipAddress, dbname, project, region, instance)
|
||||
|
||||
// Get dial options
|
||||
dialOpts, err := getDialOpts(ipType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Register sql server driver
|
||||
if !slices.Contains(sql.Drivers(), "cloudsql-sqlserver-driver") {
|
||||
_, err := mssql.RegisterDriver("cloudsql-sqlserver-driver", cloudsqlconn.WithDefaultDialOptions(dialOpts...))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Open database connection
|
||||
db, err := sql.Open(
|
||||
"cloudsql-sqlserver-driver",
|
||||
dsn,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func TestCloudSQLMssql(t *testing.T) {
|
||||
sourceConfig := requireCloudSQLMssqlVars(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
|
||||
var args []string
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := map[string]any{
|
||||
"sources": map[string]any{
|
||||
"my-instance": sourceConfig,
|
||||
},
|
||||
"tools": map[string]any{
|
||||
"my-simple-tool": map[string]any{
|
||||
"kind": "mssql",
|
||||
"source": "my-instance",
|
||||
"description": "Simple tool to test end to end functionality.",
|
||||
"statement": "SELECT 1;",
|
||||
},
|
||||
},
|
||||
}
|
||||
cmd, cleanup, err := StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
t.Fatalf("command initialization returned an error: %s", err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
out, err := cmd.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`))
|
||||
if err != nil {
|
||||
t.Logf("toolbox command logs: \n%s", out)
|
||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||
}
|
||||
|
||||
// Test tool get endpoint
|
||||
tcs := []struct {
|
||||
name string
|
||||
api string
|
||||
want map[string]any
|
||||
}{
|
||||
{
|
||||
name: "get my-simple-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/",
|
||||
want: map[string]any{
|
||||
"my-simple-tool": map[string]any{
|
||||
"description": "Simple tool to test end to end functionality.",
|
||||
"parameters": []any{},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
resp, err := http.Get(tc.api)
|
||||
if err != nil {
|
||||
t.Fatalf("error when sending a request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var body map[string]interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing response body")
|
||||
}
|
||||
|
||||
got, ok := body["tools"]
|
||||
if !ok {
|
||||
t.Fatalf("unable to find tools in response body")
|
||||
}
|
||||
if !reflect.DeepEqual(got, tc.want) {
|
||||
t.Fatalf("got %q, want %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Test tool invoke endpoint
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
api string
|
||||
requestBody io.Reader
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "invoke my-simple-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/invoke",
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
want: "Stub tool call for \"my-simple-tool\"! Parameters parsed: [] \n Output: [%!s(int64=1)]",
|
||||
},
|
||||
}
|
||||
for _, tc := range invokeTcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
resp, err := http.Post(tc.api, "application/json", tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("error when sending a request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var body map[string]interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing response body")
|
||||
}
|
||||
got, ok := body["result"].(string)
|
||||
if !ok {
|
||||
t.Fatalf("unable to find result in response body")
|
||||
}
|
||||
|
||||
if got != tc.want {
|
||||
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Set up tool calling with parameters test table
|
||||
func setupParamTest(t *testing.T, tableName string) (func(*testing.T), error) {
|
||||
// Set up Tool invocation with parameters test
|
||||
db, err := initCloudSQLMssqlConnection(CLOUD_SQL_MSSQL_PROJECT, CLOUD_SQL_MSSQL_REGION, CLOUD_SQL_MSSQL_INSTANCE, CLOUD_SQL_MSSQL_IP, "public", CLOUD_SQL_MSSQL_USER, CLOUD_SQL_MSSQL_PASS, CLOUD_SQL_MSSQL_DATABASE)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = db.Ping()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = db.Query(fmt.Sprintf(`
|
||||
CREATE TABLE %s (
|
||||
id INT IDENTITY(1,1) PRIMARY KEY,
|
||||
name VARCHAR(255),
|
||||
);
|
||||
`, tableName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Insert test data
|
||||
statement := fmt.Sprintf(`
|
||||
INSERT INTO %s (name)
|
||||
VALUES (@alice), (@jane), (@sid);
|
||||
`, tableName)
|
||||
params := []any{sql.Named("alice", "Alice"), sql.Named("jane", "Jane"), sql.Named("sid", "Sid")}
|
||||
_, err = db.Query(statement, params...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return func(t *testing.T) {
|
||||
// tear down test
|
||||
_, err := db.Exec(fmt.Sprintf(`DROP TABLE %s;`, tableName))
|
||||
if err != nil {
|
||||
t.Errorf("Teardown failed: %s", err)
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func TestToolInvocationWithParams(t *testing.T) {
|
||||
// create source config
|
||||
sourceConfig := requireCloudSQLMssqlVars(t)
|
||||
|
||||
// create table name with UUID
|
||||
tableName := "param_test_table_" + strings.Replace(uuid.New().String(), "-", "", -1)
|
||||
|
||||
// test setup function reterns teardown function
|
||||
teardownTest, err := setupParamTest(t, tableName)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to set up auth test: %s", err)
|
||||
}
|
||||
defer teardownTest(t)
|
||||
|
||||
// call generic invocation test helper
|
||||
RunToolInvocationWithParamsTest(t, sourceConfig, "mssql", tableName)
|
||||
}
|
||||
|
||||
// Set up auth test database table
|
||||
func setupCloudSQLMssqlAuthTest(t *testing.T, tableName string) (func(*testing.T), error) {
|
||||
// set up testt
|
||||
db, err := initCloudSQLMssqlConnection(CLOUD_SQL_MSSQL_PROJECT, CLOUD_SQL_MSSQL_REGION, CLOUD_SQL_MSSQL_INSTANCE, CLOUD_SQL_MSSQL_IP, "public", CLOUD_SQL_MSSQL_USER, CLOUD_SQL_MSSQL_PASS, CLOUD_SQL_MSSQL_DATABASE)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = db.Ping()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = db.Query(fmt.Sprintf(`
|
||||
CREATE TABLE %s (
|
||||
id INT IDENTITY(1,1) PRIMARY KEY,
|
||||
name VARCHAR(255),
|
||||
email VARCHAR(255)
|
||||
);
|
||||
`, tableName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Insert test data
|
||||
statement := fmt.Sprintf(`
|
||||
INSERT INTO %s (name, email)
|
||||
VALUES (@alice, @aliceemail), (@jane, @janeemail);
|
||||
`, tableName)
|
||||
params := []any{sql.Named("alice", "Alice"), sql.Named("aliceemail", SERVICE_ACCOUNT_EMAIL), sql.Named("jane", "Jane"), sql.Named("janeemail", "janedoe@gmail.com")}
|
||||
_, err = db.Query(statement, params...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return func(t *testing.T) {
|
||||
// tear down test
|
||||
_, err := db.Exec(fmt.Sprintf(`DROP TABLE %s;`, tableName))
|
||||
if err != nil {
|
||||
t.Errorf("Teardown failed: %s", err)
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
func TestCloudSQLMssqlGoogleAuthenticatedParameter(t *testing.T) {
|
||||
// create test configs
|
||||
sourceConfig := requireCloudSQLMssqlVars(t)
|
||||
|
||||
// create table name with UUID
|
||||
tableName := "auth_table_" + strings.Replace(uuid.New().String(), "-", "", -1)
|
||||
|
||||
// test setup function reterns teardown function
|
||||
teardownTest, err := setupCloudSQLMssqlAuthTest(t, tableName)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to set up auth test: %s", err)
|
||||
}
|
||||
defer teardownTest(t)
|
||||
|
||||
// call generic auth test helper
|
||||
RunGoogleAuthenticatedParameterTest(t, sourceConfig, "mssql", tableName)
|
||||
|
||||
}
|
||||
|
||||
func TestCloudSQLMssqlAuthRequiredToolInvocation(t *testing.T) {
|
||||
// create test configs
|
||||
sourceConfig := requireCloudSQLMssqlVars(t)
|
||||
|
||||
// call generic auth test helper
|
||||
RunAuthRequiredToolInvocationTest(t, sourceConfig, "mssql")
|
||||
|
||||
}
|
||||
@@ -242,7 +242,7 @@ func TestCloudSQLPostgres(t *testing.T) {
|
||||
}
|
||||
|
||||
// Set up auth test database table
|
||||
func setupAuthTest(t *testing.T, ctx context.Context, tableName string) func(*testing.T) {
|
||||
func setupCloudSQLPgAuthTest(t *testing.T, ctx context.Context, tableName string) func(*testing.T) {
|
||||
// set up testt
|
||||
pool, err := initCloudSQLPgConnectionPool(CLOUD_SQL_POSTGRES_PROJECT, CLOUD_SQL_POSTGRES_REGION, CLOUD_SQL_POSTGRES_INSTANCE, "public", CLOUD_SQL_POSTGRES_USER, CLOUD_SQL_POSTGRES_PASS, CLOUD_SQL_POSTGRES_DATABASE)
|
||||
if err != nil {
|
||||
@@ -285,7 +285,7 @@ func setupAuthTest(t *testing.T, ctx context.Context, tableName string) func(*te
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloudSQLGoogleAuthenticatedParameter(t *testing.T) {
|
||||
func TestCloudSQLPgGoogleAuthenticatedParameter(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
|
||||
@@ -296,7 +296,7 @@ func TestCloudSQLGoogleAuthenticatedParameter(t *testing.T) {
|
||||
tableName := "auth_table_" + strings.Replace(uuid.New().String(), "-", "", -1)
|
||||
|
||||
// test setup function reterns teardown function
|
||||
teardownTest := setupAuthTest(t, ctx, tableName)
|
||||
teardownTest := setupCloudSQLPgAuthTest(t, ctx, tableName)
|
||||
defer teardownTest(t)
|
||||
|
||||
// call generic auth test helper
|
||||
|
||||
@@ -210,15 +210,27 @@ func (c *CmdExec) WaitForString(ctx context.Context, re *regexp.Regexp) (string,
|
||||
}
|
||||
|
||||
func RunToolInvocationWithParamsTest(t *testing.T, sourceConfig map[string]any, toolKind string, tableName string) {
|
||||
// Write config into a file and pass it to command
|
||||
// Specify query statement for different tool kinds
|
||||
var statement string
|
||||
switch toolKind {
|
||||
case "postgres-sql":
|
||||
statement = fmt.Sprintf("SELECT * FROM %s WHERE id = $1 OR name = $2;", tableName)
|
||||
case "mssql":
|
||||
statement = fmt.Sprintf("SELECT * FROM %s WHERE id = @id OR name = @p2;", tableName)
|
||||
default:
|
||||
t.Fatalf("invalid tool kind: %s", toolKind)
|
||||
}
|
||||
|
||||
// Tools using database/sql interface only outputs `int64` instead of `int32`
|
||||
var wantString string
|
||||
switch toolKind {
|
||||
case "mssql":
|
||||
wantString = "Stub tool call for \"my-tool\"! Parameters parsed: [{\"id\" '\\x03'} {\"name\" \"Alice\"}] \n Output: [%!s(int64=1) Alice][%!s(int64=3) Sid]"
|
||||
default:
|
||||
wantString = "Stub tool call for \"my-tool\"! Parameters parsed: [{\"id\" '\\x03'} {\"name\" \"Alice\"}] \n Output: [%!s(int32=1) Alice][%!s(int32=3) Sid]"
|
||||
}
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := map[string]any{
|
||||
"sources": map[string]any{
|
||||
"my-instance": sourceConfig,
|
||||
@@ -280,7 +292,7 @@ func RunToolInvocationWithParamsTest(t *testing.T, sourceConfig map[string]any,
|
||||
api: "http://127.0.0.1:5000/api/tool/my-tool/invoke",
|
||||
requestBody: bytes.NewBuffer([]byte(`{"id": 3, "name": "Alice"}`)),
|
||||
isErr: false,
|
||||
want: "Stub tool call for \"my-tool\"! Parameters parsed: [{\"id\" '\\x03'} {\"name\" \"Alice\"}] \n Output: [%!s(int32=1) Alice][%!s(int32=3) Sid]",
|
||||
want: wantString,
|
||||
},
|
||||
{
|
||||
name: "Invoke my-tool without parameters",
|
||||
|
||||
Reference in New Issue
Block a user