mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-09 15:38:08 -05:00
feat: add Spanner source and tool (#90)
Add Spanner source and tool.
Spanner source is initialize with the following config:
```
sources:
my-spanner-source:
kind: spanner
project: my-project-name
instance: my-instance-name
database: my_db
# dialect: postgresql # The default dialect is google_standard_sql.
```
Spanner tool (with gsql dialect) is initialize with the following
config.
```
tools:
get_flight_by_id:
kind: spanner
source: my-cloud-sql-source
description: >
Use this tool to list all airports matching search criteria. Takes
at least one of country, city, name, or all and returns all matching
airports. The agent can decide to return the results directly to
the user.
statement: "SELECT * FROM flights WHERE id = @id"
parameters:
- name: id
type: int
description: 'id' represents the unique ID for each flight.
```
Spanner tool (with postgresql dialect) is initialize with the following
config.
```
tools:
get_flight_by_id:
kind: spanner
source: my-cloud-sql-source
description: >
Use this tool to list all airports matching search criteria. Takes
at least one of country, city, name, or all and returns all matching
airports. The agent can decide to return the results directly to
the user.
statement: "SELECT * FROM flights WHERE id = $1"
parameters:
- name: id
type: int
description: 'id' represents the unique ID for each flight.
```
Note: the only difference in config for both dialects is the sql
statement.
---------
Co-authored-by: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com>
This commit is contained in:
13
go.mod
13
go.mod
@@ -5,23 +5,32 @@ go 1.22.2
|
||||
require (
|
||||
cloud.google.com/go/alloydbconn v1.13.0
|
||||
cloud.google.com/go/cloudsqlconn v1.12.1
|
||||
cloud.google.com/go/spanner v1.67.0
|
||||
github.com/go-chi/chi/v5 v5.1.0
|
||||
github.com/go-chi/httplog/v2 v2.1.1
|
||||
github.com/go-chi/render v1.0.3
|
||||
github.com/google/go-cmp v0.6.0
|
||||
github.com/jackc/pgx/v5 v5.7.1
|
||||
github.com/spf13/cobra v1.8.1
|
||||
google.golang.org/api v0.199.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
require (
|
||||
cel.dev/expr v0.16.0 // indirect
|
||||
cloud.google.com/go v0.115.1 // indirect
|
||||
cloud.google.com/go/alloydb v1.12.1 // indirect
|
||||
cloud.google.com/go/auth v0.9.5 // indirect
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.4 // indirect
|
||||
cloud.google.com/go/compute/metadata v0.5.2 // indirect
|
||||
cloud.google.com/go/longrunning v0.6.0 // indirect
|
||||
github.com/GoogleCloudPlatform/grpc-gcp-go/grpcgcp v1.5.0 // indirect
|
||||
github.com/ajg/form v1.5.1 // indirect
|
||||
github.com/census-instrumentation/opencensus-proto v0.4.1 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/cncf/xds/go v0.0.0-20240723142845-024c85f92f20 // indirect
|
||||
github.com/envoyproxy/go-control-plane v0.13.0 // indirect
|
||||
github.com/envoyproxy/protoc-gen-validate v1.1.0 // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/go-logr/logr v1.4.2 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
@@ -34,7 +43,7 @@ require (
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect
|
||||
github.com/rogpeppe/go-internal v1.13.1 // indirect
|
||||
github.com/spf13/pflag v1.0.5 // indirect
|
||||
go.opencensus.io v0.24.0 // indirect
|
||||
@@ -50,7 +59,7 @@ require (
|
||||
golang.org/x/sys v0.26.0 // indirect
|
||||
golang.org/x/text v0.19.0 // indirect
|
||||
golang.org/x/time v0.7.0 // indirect
|
||||
google.golang.org/api v0.199.0 // indirect
|
||||
golang.org/x/xerrors v0.0.0-20240716161551-93cc26a95ae9 // indirect
|
||||
google.golang.org/genproto v0.0.0-20240903143218-8af14fe29dc1 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20240903143218-8af14fe29dc1 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20241007155032-5fefd90f89a9 // indirect
|
||||
|
||||
@@ -21,8 +21,10 @@ import (
|
||||
alloydbpgsrc "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg"
|
||||
cloudsqlpgsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg"
|
||||
postgressrc "github.com/googleapis/genai-toolbox/internal/sources/postgres"
|
||||
spannersrc "github.com/googleapis/genai-toolbox/internal/sources/spanner"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/postgressql"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/spanner"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
@@ -138,6 +140,12 @@ func (c *SourceConfigs) UnmarshalYAML(node *yaml.Node) error {
|
||||
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
|
||||
}
|
||||
(*c)[name] = actual
|
||||
case spannersrc.SourceKind:
|
||||
actual := spannersrc.Config{Name: name, Dialect: "google_standard_sql"}
|
||||
if err := n.Decode(&actual); err != nil {
|
||||
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
|
||||
}
|
||||
(*c)[name] = actual
|
||||
default:
|
||||
return fmt.Errorf("%q is not a valid kind of data source", k.Kind)
|
||||
}
|
||||
@@ -175,6 +183,12 @@ func (c *ToolConfigs) UnmarshalYAML(node *yaml.Node) error {
|
||||
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
|
||||
}
|
||||
(*c)[name] = actual
|
||||
case spanner.ToolKind:
|
||||
actual := spanner.Config{Name: name}
|
||||
if err := n.Decode(&actual); err != nil {
|
||||
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
|
||||
}
|
||||
(*c)[name] = actual
|
||||
default:
|
||||
return fmt.Errorf("%q is not a valid kind of tool", k.Kind)
|
||||
}
|
||||
|
||||
46
internal/sources/dialect.go
Normal file
46
internal/sources/dialect.go
Normal file
@@ -0,0 +1,46 @@
|
||||
// Copyright 2024 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sources
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// Dialect represents the dialect type of a database.
|
||||
type Dialect string
|
||||
|
||||
func (i *Dialect) String() string {
|
||||
if string(*i) != "" {
|
||||
return strings.ToLower(string(*i))
|
||||
}
|
||||
return "google_standard_sql"
|
||||
}
|
||||
|
||||
func (i *Dialect) UnmarshalYAML(node *yaml.Node) error {
|
||||
var dialect string
|
||||
if err := node.Decode(&dialect); err != nil {
|
||||
return err
|
||||
}
|
||||
switch strings.ToLower(dialect) {
|
||||
case "google_standard_sql", "postgresql":
|
||||
*i = Dialect(strings.ToLower(dialect))
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf(`dialect invalid: must be one of "google_standard_sql", or "postgresql"`)
|
||||
}
|
||||
}
|
||||
100
internal/sources/spanner/spanner.go
Normal file
100
internal/sources/spanner/spanner.go
Normal file
@@ -0,0 +1,100 @@
|
||||
// Copyright 2024 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package spanner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"cloud.google.com/go/spanner"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
)
|
||||
|
||||
const SourceKind string = "spanner"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Project string `yaml:"project"`
|
||||
Instance string `yaml:"instance"`
|
||||
Dialect sources.Dialect `yaml:"dialect"`
|
||||
Database string `yaml:"database"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (r Config) Initialize() (sources.Source, error) {
|
||||
client, err := initSpannerClient(r.Project, r.Instance, r.Database)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create client: %w", err)
|
||||
}
|
||||
|
||||
s := &Source{
|
||||
Name: r.Name,
|
||||
Kind: SourceKind,
|
||||
Client: client,
|
||||
Dialect: r.Dialect.String(),
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
var _ sources.Source = &Source{}
|
||||
|
||||
type Source struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Client *spanner.Client
|
||||
Dialect string
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) SpannerClient() *spanner.Client {
|
||||
return s.Client
|
||||
}
|
||||
|
||||
func (s *Source) DatabaseDialect() string {
|
||||
return s.Dialect
|
||||
}
|
||||
|
||||
func initSpannerClient(project, instance, dbname string) (*spanner.Client, error) {
|
||||
// Configure the connection to the database
|
||||
db := fmt.Sprintf("projects/%s/instances/%s/databases/%s", project, instance, dbname)
|
||||
|
||||
// Configure session pool to automatically clean inactive transactions
|
||||
sessionPoolConfig := spanner.SessionPoolConfig{
|
||||
TrackSessionHandles: true,
|
||||
InactiveTransactionRemovalOptions: spanner.InactiveTransactionRemovalOptions{
|
||||
ActionOnInactiveTransaction: spanner.WarnAndClose,
|
||||
},
|
||||
}
|
||||
|
||||
// Create spanner client
|
||||
ctx := context.Background()
|
||||
client, err := spanner.NewClientWithConfig(ctx, db, spanner.ClientConfig{SessionPoolConfig: sessionPoolConfig})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create new client: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
return client, nil
|
||||
}
|
||||
148
internal/sources/spanner/spanner_test.go
Normal file
148
internal/sources/spanner/spanner_test.go
Normal file
@@ -0,0 +1,148 @@
|
||||
// Copyright 2024 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package spanner_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/spanner"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func TestParseFromYamlSpannerDb(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.SourceConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
sources:
|
||||
my-spanner-instance:
|
||||
kind: spanner
|
||||
project: my-project
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
"my-spanner-instance": spanner.Config{
|
||||
Name: "my-spanner-instance",
|
||||
Kind: spanner.SourceKind,
|
||||
Project: "my-project",
|
||||
Instance: "my-instance",
|
||||
Dialect: "google_standard_sql",
|
||||
Database: "my_db",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "gsql dialect",
|
||||
in: `
|
||||
sources:
|
||||
my-spanner-instance:
|
||||
kind: spanner
|
||||
project: my-project
|
||||
instance: my-instance
|
||||
dialect: Google_standard_sql
|
||||
database: my_db
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
"my-spanner-instance": spanner.Config{
|
||||
Name: "my-spanner-instance",
|
||||
Kind: spanner.SourceKind,
|
||||
Project: "my-project",
|
||||
Instance: "my-instance",
|
||||
Dialect: "google_standard_sql",
|
||||
Database: "my_db",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "postgresql dialect",
|
||||
in: `
|
||||
sources:
|
||||
my-spanner-instance:
|
||||
kind: spanner
|
||||
project: my-project
|
||||
instance: my-instance
|
||||
dialect: postgresql
|
||||
database: my_db
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
"my-spanner-instance": spanner.Config{
|
||||
Name: "my-spanner-instance",
|
||||
Kind: spanner.SourceKind,
|
||||
Project: "my-project",
|
||||
Instance: "my-instance",
|
||||
Dialect: "postgresql",
|
||||
Database: "my_db",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func FailParseFromYamlSpanner(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
}{
|
||||
{
|
||||
desc: "invalid dialect",
|
||||
in: `
|
||||
sources:
|
||||
my-spanner-instance:
|
||||
kind: spanner
|
||||
project: my-project
|
||||
instance: my-instance
|
||||
dialect: fail
|
||||
database: my_db
|
||||
`,
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail: %s", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -56,6 +56,18 @@ func (p ParamValues) AsMap() map[string]interface{} {
|
||||
return params
|
||||
}
|
||||
|
||||
// AsMapByOrderedKeys returns a map of a key's position to it's value, as neccesary for Spanner PSQL.
|
||||
// Example { $1 -> "value1", $2 -> "value2" }
|
||||
func (p ParamValues) AsMapByOrderedKeys() map[string]interface{} {
|
||||
params := make(map[string]interface{})
|
||||
|
||||
for i, p := range p {
|
||||
key := fmt.Sprintf("p%d", i+1)
|
||||
params[key] = p.Value
|
||||
}
|
||||
return params
|
||||
}
|
||||
|
||||
// ParseParams parses specified Parameters from data and returns them as ParamValues.
|
||||
func ParseParams(ps Parameters, data map[string]any) (ParamValues, error) {
|
||||
params := make([]ParamValue, 0, len(ps))
|
||||
|
||||
@@ -256,22 +256,25 @@ func TestParametersParse(t *testing.T) {
|
||||
|
||||
func TestParamValues(t *testing.T) {
|
||||
tcs := []struct {
|
||||
name string
|
||||
in tools.ParamValues
|
||||
wantSlice []any
|
||||
wantMap map[string]interface{}
|
||||
name string
|
||||
in tools.ParamValues
|
||||
wantSlice []any
|
||||
wantMap map[string]interface{}
|
||||
wantMapOrdered map[string]interface{}
|
||||
}{
|
||||
{
|
||||
name: "string",
|
||||
in: tools.ParamValues{tools.ParamValue{Name: "my_bool", Value: true}, tools.ParamValue{Name: "my_string", Value: "hello world"}},
|
||||
wantSlice: []any{true, "hello world"},
|
||||
wantMap: map[string]interface{}{"my_bool": true, "my_string": "hello world"},
|
||||
name: "string",
|
||||
in: tools.ParamValues{tools.ParamValue{Name: "my_bool", Value: true}, tools.ParamValue{Name: "my_string", Value: "hello world"}},
|
||||
wantSlice: []any{true, "hello world"},
|
||||
wantMap: map[string]interface{}{"my_bool": true, "my_string": "hello world"},
|
||||
wantMapOrdered: map[string]interface{}{"p1": true, "p2": "hello world"},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
gotSlice := tc.in.AsSlice()
|
||||
gotMap := tc.in.AsMap()
|
||||
gotMapOrdered := tc.in.AsMapByOrderedKeys()
|
||||
|
||||
for i, got := range gotSlice {
|
||||
want := tc.wantSlice[i]
|
||||
@@ -285,6 +288,12 @@ func TestParamValues(t *testing.T) {
|
||||
t.Fatalf("unexpected value: got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
for i, got := range gotMapOrdered {
|
||||
want := tc.wantMapOrdered[i]
|
||||
if got != want {
|
||||
t.Fatalf("unexpected value: got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
162
internal/tools/spanner/spanner.go
Normal file
162
internal/tools/spanner/spanner.go
Normal file
@@ -0,0 +1,162 @@
|
||||
// Copyright 2024 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package spanner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"cloud.google.com/go/spanner"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
spannerdb "github.com/googleapis/genai-toolbox/internal/sources/spanner"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"google.golang.org/api/iterator"
|
||||
)
|
||||
|
||||
const ToolKind string = "spanner-sql"
|
||||
|
||||
type compatibleSource interface {
|
||||
SpannerClient() *spanner.Client
|
||||
DatabaseDialect() string
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &spannerdb.Source{}
|
||||
|
||||
var compatibleSources = [...]string{spannerdb.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Source string `yaml:"source"`
|
||||
Description string `yaml:"description"`
|
||||
Statement string `yaml:"statement"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
func (cfg Config) ToolConfigKind() string {
|
||||
return ToolKind
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", ToolKind, compatibleSources)
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: ToolKind,
|
||||
Parameters: cfg.Parameters,
|
||||
Statement: cfg.Statement,
|
||||
Client: s.SpannerClient(),
|
||||
dialect: s.DatabaseDialect(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()},
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func NewGenericTool(name, stmt, desc string, client *spanner.Client, dialect string, parameters tools.Parameters) Tool {
|
||||
return Tool{
|
||||
Name: name,
|
||||
Kind: ToolKind,
|
||||
Statement: stmt,
|
||||
Client: client,
|
||||
dialect: dialect,
|
||||
manifest: tools.Manifest{Description: desc, Parameters: parameters.Manifest()},
|
||||
Parameters: parameters,
|
||||
}
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
|
||||
Client *spanner.Client
|
||||
dialect string
|
||||
Statement string
|
||||
manifest tools.Manifest
|
||||
}
|
||||
|
||||
func getMapParams(params tools.ParamValues, dialect string) (map[string]interface{}, error) {
|
||||
switch strings.ToLower(dialect) {
|
||||
case "google_standard_sql":
|
||||
return params.AsMap(), nil
|
||||
case "postgresql":
|
||||
return params.AsMapByOrderedKeys(), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid dialect %s", dialect)
|
||||
}
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(params tools.ParamValues) (string, error) {
|
||||
mapParams, err := getMapParams(params, t.dialect)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("fail to get map params: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Invoked tool %s\n", t.Name)
|
||||
var out strings.Builder
|
||||
|
||||
ctx := context.Background()
|
||||
_, err = t.Client.ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error {
|
||||
stmt := spanner.Statement{
|
||||
SQL: t.Statement,
|
||||
Params: mapParams,
|
||||
}
|
||||
iter := txn.Query(ctx, stmt)
|
||||
defer iter.Stop()
|
||||
|
||||
for {
|
||||
row, err := iter.Next()
|
||||
if err == iterator.Done {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to parse row: %w", err)
|
||||
}
|
||||
out.WriteString(row.String())
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("unable to execute client: %w", err)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("Stub tool call for %q! Parameters parsed: %q \n Output: %s", t.Name, params, out.String()), nil
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any) (tools.ParamValues, error) {
|
||||
return tools.ParseParams(t.Parameters, data)
|
||||
}
|
||||
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
79
internal/tools/spanner/spanner_test.go
Normal file
79
internal/tools/spanner/spanner_test.go
Normal file
@@ -0,0 +1,79 @@
|
||||
// Copyright 2024 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package spanner_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/spanner"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func TestParseFromYamlSpanner(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: spanner-sql
|
||||
source: my-pg-instance
|
||||
description: some description
|
||||
statement: |
|
||||
SELECT * FROM SQL_STATEMENT;
|
||||
parameters:
|
||||
- name: country
|
||||
type: string
|
||||
description: some description
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": spanner.Config{
|
||||
Name: "example_tool",
|
||||
Kind: spanner.ToolKind,
|
||||
Source: "my-pg-instance",
|
||||
Description: "some description",
|
||||
Statement: "SELECT * FROM SQL_STATEMENT;\n",
|
||||
Parameters: []tools.Parameter{
|
||||
tools.NewStringParameter("country", "some description"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user