mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-10 07:58:12 -05:00
@@ -115,6 +115,8 @@ steps:
|
||||
- "SPANNER_PROJECT=$PROJECT_ID"
|
||||
- "SPANNER_DATABASE=$_DATABASE_NAME"
|
||||
- "SPANNER_INSTANCE=$_SPANNER_INSTANCE"
|
||||
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
|
||||
secretEnv: ["CLIENT_ID"]
|
||||
volumes:
|
||||
- name: "go"
|
||||
path: "/gopath"
|
||||
|
||||
1
go.mod
1
go.mod
@@ -41,6 +41,7 @@ require (
|
||||
cloud.google.com/go/auth v0.15.0 // indirect
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect
|
||||
cloud.google.com/go/compute/metadata v0.6.0 // indirect
|
||||
cloud.google.com/go/iam v1.4.2 // indirect
|
||||
cloud.google.com/go/longrunning v0.6.6 // indirect
|
||||
cloud.google.com/go/monitoring v1.24.1 // indirect
|
||||
cloud.google.com/go/trace v1.11.3 // indirect
|
||||
|
||||
@@ -249,3 +249,26 @@ func GetMysqlLAuthToolInfo(tableName string) (string, string, string, []any) {
|
||||
params := []any{"Alice", SERVICE_ACCOUNT_EMAIL, "Jane", "janedoe@gmail.com"}
|
||||
return create_statement, insert_statement, tool_statement, params
|
||||
}
|
||||
|
||||
// GetSpannerToolInfo returns statements and param for my-param-tool for spanner-sql kind
|
||||
func GetSpannerParamToolInfo(tableName string) (string, string, string, map[string]any) {
|
||||
create_statement := fmt.Sprintf("CREATE TABLE %s (id INT64, name STRING(MAX)) PRIMARY KEY (id)", tableName)
|
||||
insert_statement := fmt.Sprintf("INSERT INTO %s (id, name) VALUES (1, @name1), (2, @name2), (3, @name3)", tableName)
|
||||
tool_statement := fmt.Sprintf("SELECT * FROM %s WHERE id = @id OR name = @name", tableName)
|
||||
params := map[string]any{"name1": "Alice", "name2": "Jane", "name3": "Sid"}
|
||||
return create_statement, insert_statement, tool_statement, params
|
||||
}
|
||||
|
||||
// GetSpannerAuthToolInfo returns statements and param of my-auth-tool for spanner-sql kind
|
||||
func GetSpannerAuthToolInfo(tableName string) (string, string, string, map[string]any) {
|
||||
create_statement := fmt.Sprintf("CREATE TABLE %s (id INT64, name STRING(MAX), email STRING(MAX)) PRIMARY KEY (id)", tableName)
|
||||
insert_statement := fmt.Sprintf("INSERT INTO %s (id, name, email) VALUES (1, @name1, @email1), (2, @name2, @email2)", tableName)
|
||||
tool_statement := fmt.Sprintf("SELECT name FROM %s WHERE email = @email", tableName)
|
||||
params := map[string]any{
|
||||
"name1": "Alice",
|
||||
"email1": SERVICE_ACCOUNT_EMAIL,
|
||||
"name2": "Jane",
|
||||
"email2": "janedoe@gmail.com",
|
||||
}
|
||||
return create_statement, insert_statement, tool_statement, params
|
||||
}
|
||||
|
||||
@@ -17,16 +17,17 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"fmt"
|
||||
"os"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"cloud.google.com/go/spanner"
|
||||
database "cloud.google.com/go/spanner/admin/database/apiv1"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -55,6 +56,33 @@ func getSpannerVars(t *testing.T) map[string]any {
|
||||
}
|
||||
}
|
||||
|
||||
func initSpannerClients(ctx context.Context, project, instance, dbname string) (*spanner.Client, *database.DatabaseAdminClient, error) {
|
||||
// Configure the connection to the database
|
||||
db := fmt.Sprintf("projects/%s/instances/%s/databases/%s", project, instance, dbname)
|
||||
|
||||
// Configure session pool to automatically clean inactive transactions
|
||||
sessionPoolConfig := spanner.SessionPoolConfig{
|
||||
TrackSessionHandles: true,
|
||||
InactiveTransactionRemovalOptions: spanner.InactiveTransactionRemovalOptions{
|
||||
ActionOnInactiveTransaction: spanner.WarnAndClose,
|
||||
},
|
||||
}
|
||||
|
||||
// Create Spanner client (for queries)
|
||||
dataClient, err := spanner.NewClientWithConfig(context.Background(), db, spanner.ClientConfig{SessionPoolConfig: sessionPoolConfig})
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("unable to create new Spanner client: %w", err)
|
||||
}
|
||||
|
||||
// Create Spanner admin client (for creating databases)
|
||||
adminClient, err := database.NewDatabaseAdminClient(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("unable to create new Spanner admin client: %w", err)
|
||||
}
|
||||
|
||||
return dataClient, adminClient, nil
|
||||
}
|
||||
|
||||
func TestSpannerToolEndpoints(t *testing.T) {
|
||||
sourceConfig := getSpannerVars(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
@@ -62,20 +90,35 @@ func TestSpannerToolEndpoints(t *testing.T) {
|
||||
|
||||
var args []string
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := map[string]any{
|
||||
"sources": map[string]any{
|
||||
"my-spanner-instance": sourceConfig,
|
||||
},
|
||||
"tools": map[string]any{
|
||||
"my-simple-tool": map[string]any{
|
||||
"kind": "spanner-sql",
|
||||
"source": "my-spanner-instance",
|
||||
"description": "Simple tool to test end to end functionality.",
|
||||
"statement": "SELECT 1;",
|
||||
},
|
||||
},
|
||||
// Create Spanner client
|
||||
dataClient, adminClient, err := initSpannerClients(ctx, SPANNER_PROJECT, SPANNER_INSTANCE, SPANNER_DATABASE)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create Spanner client: %s", err)
|
||||
}
|
||||
|
||||
// create table name with UUID
|
||||
tableNameParam := "param_table_" + strings.Replace(uuid.New().String(), "-", "", -1)
|
||||
tableNameAuth := "auth_table_" + strings.Replace(uuid.New().String(), "-", "", -1)
|
||||
|
||||
// set up data for param tool
|
||||
create_statement1, insert_statement1, tool_statement1, params1 := GetSpannerParamToolInfo(tableNameParam)
|
||||
dbString := fmt.Sprintf(
|
||||
"projects/%s/instances/%s/databases/%s",
|
||||
SPANNER_PROJECT,
|
||||
SPANNER_INSTANCE,
|
||||
SPANNER_DATABASE,
|
||||
)
|
||||
teardownTable1 := SetupSpannerTable(t, ctx, adminClient, dataClient, create_statement1, insert_statement1, tableNameParam, dbString, params1)
|
||||
defer teardownTable1(t)
|
||||
|
||||
// set up data for auth tool
|
||||
create_statement2, insert_statement2, tool_statement2, params2 := GetSpannerAuthToolInfo(tableNameAuth)
|
||||
teardownTable2 := SetupSpannerTable(t, ctx, adminClient, dataClient, create_statement2, insert_statement2, tableNameAuth, dbString, params2)
|
||||
defer teardownTable2(t)
|
||||
|
||||
// Write config into a file and pass it to command
|
||||
toolsFile := GetToolsConfig(sourceConfig, SPANNER_TOOL_KIND, tool_statement1, tool_statement2)
|
||||
|
||||
cmd, cleanup, err := StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
t.Fatalf("command initialization returned an error: %s", err)
|
||||
@@ -90,88 +133,11 @@ func TestSpannerToolEndpoints(t *testing.T) {
|
||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||
}
|
||||
|
||||
// Test tool get endpoint
|
||||
tcs := []struct {
|
||||
name string
|
||||
api string
|
||||
want map[string]any
|
||||
}{
|
||||
{
|
||||
name: "get my-simple-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/",
|
||||
want: map[string]any{
|
||||
"my-simple-tool": map[string]any{
|
||||
"description": "Simple tool to test end to end functionality.",
|
||||
"parameters": []any{},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
resp, err := http.Get(tc.api)
|
||||
if err != nil {
|
||||
t.Fatalf("error when sending a request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != 200 {
|
||||
t.Fatalf("response status code is not 200")
|
||||
}
|
||||
RunToolGetTest(t)
|
||||
|
||||
var body map[string]interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing response body")
|
||||
}
|
||||
select_1_want := "[{\"\":\"1\"}]"
|
||||
fail_invocation_want := `"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute client: unable to parse row: spanner: code = "InvalidArgument", desc = "Syntax error: Unexpected identifier \"SELEC\" [at 1:1]\nSELEC 1;\n^"`
|
||||
|
||||
got, ok := body["tools"]
|
||||
if !ok {
|
||||
t.Fatalf("unable to find tools in response body")
|
||||
}
|
||||
if !reflect.DeepEqual(got, tc.want) {
|
||||
t.Fatalf("got %q, want %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Test tool invoke endpoint
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
api string
|
||||
requestBody io.Reader
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "invoke my-simple-tool",
|
||||
api: "http://127.0.0.1:5000/api/tool/my-simple-tool/invoke",
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
want: "[{\"\":\"1\"}]",
|
||||
},
|
||||
}
|
||||
for _, tc := range invokeTcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
resp, err := http.Post(tc.api, "application/json", tc.requestBody)
|
||||
if err != nil {
|
||||
t.Fatalf("error when sending a request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != 200 {
|
||||
t.Fatalf("response status code is not 200")
|
||||
}
|
||||
|
||||
var body map[string]interface{}
|
||||
err = json.NewDecoder(resp.Body).Decode(&body)
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing response body")
|
||||
}
|
||||
got, ok := body["result"].(string)
|
||||
if !ok {
|
||||
t.Fatalf("unable to find result in response body")
|
||||
}
|
||||
|
||||
if got != tc.want {
|
||||
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
RunToolInvokeTest(t, select_1_want)
|
||||
RunMCPToolCallMethod(t, fail_invocation_want)
|
||||
}
|
||||
|
||||
@@ -28,6 +28,9 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"cloud.google.com/go/spanner"
|
||||
database "cloud.google.com/go/spanner/admin/database/apiv1"
|
||||
"cloud.google.com/go/spanner/admin/database/apiv1/databasepb"
|
||||
"github.com/googleapis/genai-toolbox/internal/server/mcp"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
@@ -119,6 +122,54 @@ func SetupMySQLTable(t *testing.T, ctx context.Context, pool *sql.DB, create_sta
|
||||
}
|
||||
}
|
||||
|
||||
// SetupSpannerTable creates and inserts data into a table of tool
|
||||
// compatible with spanner-sql tool
|
||||
func SetupSpannerTable(t *testing.T, ctx context.Context, adminClient *database.DatabaseAdminClient, dataClient *spanner.Client, create_statement, insert_statement, tableName, dbString string, params map[string]any) func(*testing.T) {
|
||||
|
||||
// Create table
|
||||
op, err := adminClient.UpdateDatabaseDdl(ctx, &databasepb.UpdateDatabaseDdlRequest{
|
||||
Database: dbString,
|
||||
Statements: []string{create_statement},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unable to start create table operation %s: %s", tableName, err)
|
||||
}
|
||||
err = op.Wait(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create test table %s: %s", tableName, err)
|
||||
}
|
||||
|
||||
// Insert test data
|
||||
_, err = dataClient.ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error {
|
||||
stmt := spanner.Statement{
|
||||
SQL: insert_statement,
|
||||
Params: params,
|
||||
}
|
||||
_, err := txn.Update(ctx, stmt)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unable to insert test data: %s", err)
|
||||
}
|
||||
|
||||
return func(t *testing.T) {
|
||||
// tear down test
|
||||
op, err = adminClient.UpdateDatabaseDdl(ctx, &databasepb.UpdateDatabaseDdlRequest{
|
||||
Database: dbString,
|
||||
Statements: []string{fmt.Sprintf("DROP TABLE %s", tableName)},
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("unable to start drop table operation: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = op.Wait(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("Teardown failed: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RunToolGet runs the tool get endpoint
|
||||
func RunToolGetTest(t *testing.T) {
|
||||
// Test tool get endpoint
|
||||
@@ -294,11 +345,13 @@ func RunToolInvokeTest(t *testing.T, select_1_want string) {
|
||||
if !ok {
|
||||
t.Fatalf("unable to find result in response body")
|
||||
}
|
||||
|
||||
// Remove `\` and `"` for string comparison
|
||||
got = strings.ReplaceAll(got, "\\", "")
|
||||
want := strings.ReplaceAll(tc.want, "\\", "")
|
||||
got = strings.ReplaceAll(got, "\"", "")
|
||||
want = strings.ReplaceAll(want, "\"", "")
|
||||
|
||||
if got != want {
|
||||
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
|
||||
}
|
||||
@@ -431,9 +484,14 @@ func RunMCPToolCallMethod(t *testing.T, fail_invocation_want string) {
|
||||
defer resp.Body.Close()
|
||||
got := string(bytes.TrimSpace(respBody))
|
||||
|
||||
if got != tc.want {
|
||||
fmt.Printf("res is %s\n\n", got)
|
||||
t.Fatalf("unexpected value: got %q, want %q", got, tc.want)
|
||||
// Remove `\` and `"` for string comparison
|
||||
got = strings.ReplaceAll(got, "\\", "")
|
||||
want := strings.ReplaceAll(tc.want, "\\", "")
|
||||
got = strings.ReplaceAll(got, "\"", "")
|
||||
want = strings.ReplaceAll(want, "\"", "")
|
||||
|
||||
if !strings.Contains(got, want) {
|
||||
t.Fatalf("Expected substring not found:\ngot: %q\nwant: %q (to be contained within got)", got, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user