feat: add Spanner source and tool (#90)

Add Spanner source and tool.

The Spanner source is initialized with the following config:
```
sources:
    my-spanner-source:
        kind: spanner
        project: my-project-name
        instance: my-instance-name
        database: my_db
        # dialect: postgresql # The default dialect is google_standard_sql.
```

The Spanner tool (with the gsql dialect) is initialized with the following
config:
```
tools:
    get_flight_by_id:
        kind: spanner-sql
        source: my-spanner-source
        description: >
            Use this tool to list all airports matching search criteria. Takes 
            at least one of country, city, name, or all and returns all matching
            airports. The agent can decide to return the results directly to 
            the user.
        statement: "SELECT * FROM flights WHERE id = @id"
        parameters:
        - name: id
          type: int
          description: 'id' represents the unique ID for each flight. 
```

The Spanner tool (with the postgresql dialect) is initialized with the following
config:
```
tools:
    get_flight_by_id:
        kind: spanner-sql
        source: my-spanner-source
        description: >
            Use this tool to list all airports matching search criteria. Takes 
            at least one of country, city, name, or all and returns all matching
            airports. The agent can decide to return the results directly to 
            the user.
        statement: "SELECT * FROM flights WHERE id = $1"
        parameters:
        - name: id
          type: int
          description: 'id' represents the unique ID for each flight. 
```

Note: the only difference in config for both dialects is the sql
statement.

---------

Co-authored-by: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com>
This commit is contained in:
Yuan
2024-12-06 16:38:03 -08:00
committed by GitHub
parent 5528bec8ed
commit 890914aae0
10 changed files with 2034 additions and 11 deletions

13
go.mod
View File

@@ -5,23 +5,32 @@ go 1.22.2
require (
cloud.google.com/go/alloydbconn v1.13.0
cloud.google.com/go/cloudsqlconn v1.12.1
cloud.google.com/go/spanner v1.67.0
github.com/go-chi/chi/v5 v5.1.0
github.com/go-chi/httplog/v2 v2.1.1
github.com/go-chi/render v1.0.3
github.com/google/go-cmp v0.6.0
github.com/jackc/pgx/v5 v5.7.1
github.com/spf13/cobra v1.8.1
google.golang.org/api v0.199.0
gopkg.in/yaml.v3 v3.0.1
)
require (
cel.dev/expr v0.16.0 // indirect
cloud.google.com/go v0.115.1 // indirect
cloud.google.com/go/alloydb v1.12.1 // indirect
cloud.google.com/go/auth v0.9.5 // indirect
cloud.google.com/go/auth/oauth2adapt v0.2.4 // indirect
cloud.google.com/go/compute/metadata v0.5.2 // indirect
cloud.google.com/go/longrunning v0.6.0 // indirect
github.com/GoogleCloudPlatform/grpc-gcp-go/grpcgcp v1.5.0 // indirect
github.com/ajg/form v1.5.1 // indirect
github.com/census-instrumentation/opencensus-proto v0.4.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/cncf/xds/go v0.0.0-20240723142845-024c85f92f20 // indirect
github.com/envoyproxy/go-control-plane v0.13.0 // indirect
github.com/envoyproxy/protoc-gen-validate v1.1.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/go-logr/logr v1.4.2 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
@@ -34,7 +43,7 @@ require (
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect
github.com/rogpeppe/go-internal v1.13.1 // indirect
github.com/spf13/pflag v1.0.5 // indirect
go.opencensus.io v0.24.0 // indirect
@@ -50,7 +59,7 @@ require (
golang.org/x/sys v0.26.0 // indirect
golang.org/x/text v0.19.0 // indirect
golang.org/x/time v0.7.0 // indirect
google.golang.org/api v0.199.0 // indirect
golang.org/x/xerrors v0.0.0-20240716161551-93cc26a95ae9 // indirect
google.golang.org/genproto v0.0.0-20240903143218-8af14fe29dc1 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240903143218-8af14fe29dc1 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20241007155032-5fefd90f89a9 // indirect

1446
go.sum

File diff suppressed because it is too large Load Diff

View File

@@ -21,8 +21,10 @@ import (
alloydbpgsrc "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg"
cloudsqlpgsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg"
postgressrc "github.com/googleapis/genai-toolbox/internal/sources/postgres"
spannersrc "github.com/googleapis/genai-toolbox/internal/sources/spanner"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/postgressql"
"github.com/googleapis/genai-toolbox/internal/tools/spanner"
"gopkg.in/yaml.v3"
)
@@ -138,6 +140,12 @@ func (c *SourceConfigs) UnmarshalYAML(node *yaml.Node) error {
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
}
(*c)[name] = actual
case spannersrc.SourceKind:
actual := spannersrc.Config{Name: name, Dialect: "google_standard_sql"}
if err := n.Decode(&actual); err != nil {
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
}
(*c)[name] = actual
default:
return fmt.Errorf("%q is not a valid kind of data source", k.Kind)
}
@@ -175,6 +183,12 @@ func (c *ToolConfigs) UnmarshalYAML(node *yaml.Node) error {
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
}
(*c)[name] = actual
case spanner.ToolKind:
actual := spanner.Config{Name: name}
if err := n.Decode(&actual); err != nil {
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
}
(*c)[name] = actual
default:
return fmt.Errorf("%q is not a valid kind of tool", k.Kind)
}

View File

@@ -0,0 +1,46 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sources
import (
"fmt"
"strings"
"gopkg.in/yaml.v3"
)
// Dialect names the SQL dialect spoken by a database.
type Dialect string

// String returns the dialect in lower case. An empty dialect is reported as
// "google_standard_sql".
func (i *Dialect) String() string {
	name := string(*i)
	if name == "" {
		return "google_standard_sql"
	}
	return strings.ToLower(name)
}
// UnmarshalYAML decodes the node as a string and stores it as the dialect.
// Matching is case-insensitive and the stored value is always lower case.
// Any value other than "google_standard_sql" or "postgresql" is rejected
// with an error.
func (i *Dialect) UnmarshalYAML(node *yaml.Node) error {
	var raw string
	if err := node.Decode(&raw); err != nil {
		return err
	}
	normalized := strings.ToLower(raw)
	if normalized != "google_standard_sql" && normalized != "postgresql" {
		return fmt.Errorf(`dialect invalid: must be one of "google_standard_sql", or "postgresql"`)
	}
	*i = Dialect(normalized)
	return nil
}

View File

@@ -0,0 +1,100 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package spanner
import (
"context"
"fmt"
"cloud.google.com/go/spanner"
"github.com/googleapis/genai-toolbox/internal/sources"
)
// SourceKind is the kind string that identifies a Spanner source.
const SourceKind string = "spanner"
// Compile-time check that Config implements sources.SourceConfig.
var _ sources.SourceConfig = Config{}
// Config is the YAML configuration of a Spanner source.
type Config struct {
Name string `yaml:"name"`
Kind string `yaml:"kind"`
// Project, Instance and Database are handed to initSpannerClient by
// Initialize to build the database path.
Project string `yaml:"project"`
Instance string `yaml:"instance"`
// Dialect is the SQL dialect of the database; Initialize stores its
// String() form on the resulting Source.
Dialect sources.Dialect `yaml:"dialect"`
Database string `yaml:"database"`
}
// SourceConfigKind returns SourceKind.
func (r Config) SourceConfigKind() string {
return SourceKind
}
// Initialize creates a Spanner client for the configured project, instance
// and database and wraps it in a Source that carries the config's name and
// dialect. It returns an error when the client cannot be created.
func (r Config) Initialize() (sources.Source, error) {
	spannerClient, err := initSpannerClient(r.Project, r.Instance, r.Database)
	if err != nil {
		return nil, fmt.Errorf("unable to create client: %w", err)
	}
	return &Source{
		Name:    r.Name,
		Kind:    SourceKind,
		Client:  spannerClient,
		Dialect: r.Dialect.String(),
	}, nil
}
// Compile-time check that *Source implements sources.Source.
var _ sources.Source = &Source{}
// Source is an initialized Spanner source: a client plus the dialect name of
// the database it talks to.
type Source struct {
Name string `yaml:"name"`
Kind string `yaml:"kind"`
// Client is the Spanner client created by Config.Initialize.
Client *spanner.Client
// Dialect is the dialect name produced by sources.Dialect.String().
Dialect string
}
// SourceKind returns the package-level SourceKind constant.
func (s *Source) SourceKind() string {
return SourceKind
}
// SpannerClient returns the Spanner client held by the source.
func (s *Source) SpannerClient() *spanner.Client {
return s.Client
}
// DatabaseDialect returns the dialect name held by the source.
func (s *Source) DatabaseDialect() string {
return s.Dialect
}
// initSpannerClient creates a Spanner client for the database identified by
// project, instance and dbname.
//
// The session pool is configured to track session handles and to warn about
// and close inactive transactions.
//
// The returned client is open; the caller owns it and is responsible for
// closing it once it is no longer needed. It must not be closed here,
// otherwise every caller would receive an unusable client.
func initSpannerClient(project, instance, dbname string) (*spanner.Client, error) {
	// Configure the connection to the database
	db := fmt.Sprintf("projects/%s/instances/%s/databases/%s", project, instance, dbname)

	// Configure session pool to automatically clean inactive transactions
	sessionPoolConfig := spanner.SessionPoolConfig{
		TrackSessionHandles: true,
		InactiveTransactionRemovalOptions: spanner.InactiveTransactionRemovalOptions{
			ActionOnInactiveTransaction: spanner.WarnAndClose,
		},
	}

	// Create spanner client
	ctx := context.Background()
	client, err := spanner.NewClientWithConfig(ctx, db, spanner.ClientConfig{SessionPoolConfig: sessionPoolConfig})
	if err != nil {
		return nil, fmt.Errorf("unable to create new client: %w", err)
	}
	return client, nil
}

View File

@@ -0,0 +1,148 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package spanner_test
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/spanner"
"github.com/googleapis/genai-toolbox/internal/testutils"
"gopkg.in/yaml.v3"
)
// TestParseFromYamlSpannerDb checks that Spanner source configs unmarshal from
// YAML into the expected spanner.Config values, covering an omitted dialect,
// a mixed-case google_standard_sql dialect and the postgresql dialect.
func TestParseFromYamlSpannerDb(t *testing.T) {
tcs := []struct {
desc string
in string
want server.SourceConfigs
}{
{
// No dialect key: the expected config carries "google_standard_sql".
desc: "basic example",
in: `
sources:
my-spanner-instance:
kind: spanner
project: my-project
instance: my-instance
database: my_db
`,
want: map[string]sources.SourceConfig{
"my-spanner-instance": spanner.Config{
Name: "my-spanner-instance",
Kind: spanner.SourceKind,
Project: "my-project",
Instance: "my-instance",
Dialect: "google_standard_sql",
Database: "my_db",
},
},
},
{
// Mixed-case input is expected to be stored in lower case.
desc: "gsql dialect",
in: `
sources:
my-spanner-instance:
kind: spanner
project: my-project
instance: my-instance
dialect: Google_standard_sql
database: my_db
`,
want: map[string]sources.SourceConfig{
"my-spanner-instance": spanner.Config{
Name: "my-spanner-instance",
Kind: spanner.SourceKind,
Project: "my-project",
Instance: "my-instance",
Dialect: "google_standard_sql",
Database: "my_db",
},
},
},
{
desc: "postgresql dialect",
in: `
sources:
my-spanner-instance:
kind: spanner
project: my-project
instance: my-instance
dialect: postgresql
database: my_db
`,
want: map[string]sources.SourceConfig{
"my-spanner-instance": spanner.Config{
Name: "my-spanner-instance",
Kind: spanner.SourceKind,
Project: "my-project",
Instance: "my-instance",
Dialect: "postgresql",
Database: "my_db",
},
},
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
// Decode into a wrapper struct so the top-level "sources" key is consumed.
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
}
})
}
}
// TestFailParseFromYamlSpanner checks that Spanner source configs which must
// be rejected fail to unmarshal.
//
// The function name carries the Test prefix so that `go test` actually runs
// it; without the prefix the cases below are never executed.
func TestFailParseFromYamlSpanner(t *testing.T) {
	tcs := []struct {
		desc string
		in   string
	}{
		{
			desc: "invalid dialect",
			in: `
sources:
my-spanner-instance:
kind: spanner
project: my-project
instance: my-instance
dialect: fail
database: my_db
`,
		},
	}
	for _, tc := range tcs {
		t.Run(tc.desc, func(t *testing.T) {
			got := struct {
				Sources server.SourceConfigs `yaml:"sources"`
			}{}
			// Parse contents
			err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
			if err == nil {
				// err is nil on this path, so there is nothing useful to format.
				t.Fatal("expect parsing to fail")
			}
		})
	}
}

View File

@@ -56,6 +56,18 @@ func (p ParamValues) AsMap() map[string]interface{} {
return params
}
// AsMapByOrderedKeys returns the parameter values keyed by their 1-based
// position, using keys of the form "p1", "p2", and so on. This is the shape
// needed for the Spanner PostgreSQL dialect, where the positional placeholder
// $N in a statement is bound to the parameter named "pN".
// Example: { "p1": "value1", "p2": "value2" }
func (p ParamValues) AsMapByOrderedKeys() map[string]interface{} {
	params := make(map[string]interface{}, len(p))
	for i, pv := range p {
		params[fmt.Sprintf("p%d", i+1)] = pv.Value
	}
	return params
}
// ParseParams parses specified Parameters from data and returns them as ParamValues.
func ParseParams(ps Parameters, data map[string]any) (ParamValues, error) {
params := make([]ParamValue, 0, len(ps))

View File

@@ -256,22 +256,25 @@ func TestParametersParse(t *testing.T) {
func TestParamValues(t *testing.T) {
tcs := []struct {
name string
in tools.ParamValues
wantSlice []any
wantMap map[string]interface{}
name string
in tools.ParamValues
wantSlice []any
wantMap map[string]interface{}
wantMapOrdered map[string]interface{}
}{
{
name: "string",
in: tools.ParamValues{tools.ParamValue{Name: "my_bool", Value: true}, tools.ParamValue{Name: "my_string", Value: "hello world"}},
wantSlice: []any{true, "hello world"},
wantMap: map[string]interface{}{"my_bool": true, "my_string": "hello world"},
name: "string",
in: tools.ParamValues{tools.ParamValue{Name: "my_bool", Value: true}, tools.ParamValue{Name: "my_string", Value: "hello world"}},
wantSlice: []any{true, "hello world"},
wantMap: map[string]interface{}{"my_bool": true, "my_string": "hello world"},
wantMapOrdered: map[string]interface{}{"p1": true, "p2": "hello world"},
},
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
gotSlice := tc.in.AsSlice()
gotMap := tc.in.AsMap()
gotMapOrdered := tc.in.AsMapByOrderedKeys()
for i, got := range gotSlice {
want := tc.wantSlice[i]
@@ -285,6 +288,12 @@ func TestParamValues(t *testing.T) {
t.Fatalf("unexpected value: got %q, want %q", got, want)
}
}
for i, got := range gotMapOrdered {
want := tc.wantMapOrdered[i]
if got != want {
t.Fatalf("unexpected value: got %q, want %q", got, want)
}
}
})
}
}

View File

@@ -0,0 +1,162 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package spanner
import (
"context"
"fmt"
"strings"
"cloud.google.com/go/spanner"
"github.com/googleapis/genai-toolbox/internal/sources"
spannerdb "github.com/googleapis/genai-toolbox/internal/sources/spanner"
"github.com/googleapis/genai-toolbox/internal/tools"
"google.golang.org/api/iterator"
)
// ToolKind is the kind string that identifies the Spanner SQL tool.
const ToolKind string = "spanner-sql"
// compatibleSource is the set of methods a source must provide to back this
// tool: access to a Spanner client and the name of the database dialect.
type compatibleSource interface {
SpannerClient() *spanner.Client
DatabaseDialect() string
}
// validate compatible sources are still compatible
var _ compatibleSource = &spannerdb.Source{}
// compatibleSources lists the source kinds named in the error returned by
// Config.Initialize when a source is not compatible.
var compatibleSources = [...]string{spannerdb.SourceKind}
// Config is the YAML configuration of a Spanner SQL tool.
type Config struct {
Name string `yaml:"name"`
Kind string `yaml:"kind"`
// Source is the name of the source to look up in Initialize.
Source string `yaml:"source"`
// Description is copied into the tool's manifest.
Description string `yaml:"description"`
// Statement is the SQL text executed when the tool is invoked.
Statement string `yaml:"statement"`
Parameters tools.Parameters `yaml:"parameters"`
}
// Compile-time check that Config implements tools.ToolConfig.
var _ tools.ToolConfig = Config{}
// ToolConfigKind returns ToolKind.
func (cfg Config) ToolConfigKind() string {
return ToolKind
}
// Initialize builds a Tool from the config. It looks up cfg.Source in srcs,
// requires that source to satisfy compatibleSource, and takes the Spanner
// client and dialect from it. An error is returned when the source is missing
// or incompatible.
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
	// the named source has to be configured
	rawSrc, found := srcs[cfg.Source]
	if !found {
		return nil, fmt.Errorf("no source named %q configured", cfg.Source)
	}

	// the source has to expose a Spanner client and a dialect
	src, compatible := rawSrc.(compatibleSource)
	if !compatible {
		return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", ToolKind, compatibleSources)
	}

	// assemble the tool
	client := src.SpannerClient()
	dialect := src.DatabaseDialect()
	manifest := tools.Manifest{Description: cfg.Description, Parameters: cfg.Parameters.Manifest()}
	return Tool{
		Name:       cfg.Name,
		Kind:       ToolKind,
		Parameters: cfg.Parameters,
		Statement:  cfg.Statement,
		Client:     client,
		dialect:    dialect,
		manifest:   manifest,
	}, nil
}
// NewGenericTool assembles a Tool directly from its parts, without going
// through a Config. The manifest is built from desc and the manifest of
// parameters.
func NewGenericTool(name, stmt, desc string, client *spanner.Client, dialect string, parameters tools.Parameters) Tool {
	manifest := tools.Manifest{Description: desc, Parameters: parameters.Manifest()}
	return Tool{
		Name:       name,
		Kind:       ToolKind,
		Parameters: parameters,
		Client:     client,
		dialect:    dialect,
		Statement:  stmt,
		manifest:   manifest,
	}
}
// Compile-time check that Tool implements tools.Tool.
var _ tools.Tool = Tool{}
// Tool is an initialized Spanner SQL tool: a statement, its parameters and
// the client and dialect used to run it.
type Tool struct {
Name string `yaml:"name"`
Kind string `yaml:"kind"`
Parameters tools.Parameters `yaml:"parameters"`
// Client is the Spanner client used by Invoke.
Client *spanner.Client
// dialect selects how Invoke maps parameter values to statement parameters.
dialect string
// Statement is the SQL text executed by Invoke.
Statement string
// manifest is the value returned by Manifest.
manifest tools.Manifest
}
// getMapParams converts params into the statement parameter map for the given
// dialect, compared case-insensitively: "google_standard_sql" uses
// params.AsMap and "postgresql" uses params.AsMapByOrderedKeys. Any other
// dialect yields an error.
func getMapParams(params tools.ParamValues, dialect string) (map[string]interface{}, error) {
	normalized := strings.ToLower(dialect)
	if normalized == "google_standard_sql" {
		return params.AsMap(), nil
	}
	if normalized == "postgresql" {
		return params.AsMapByOrderedKeys(), nil
	}
	return nil, fmt.Errorf("invalid dialect %s", dialect)
}
// Invoke runs the tool's statement with params inside a Spanner read-write
// transaction and returns a string containing the tool name, the parameters
// and the string form of every returned row.
//
// It returns an error when the parameters cannot be mapped for the tool's
// dialect or when the transaction fails.
func (t Tool) Invoke(params tools.ParamValues) (string, error) {
	mapParams, err := getMapParams(params, t.dialect)
	if err != nil {
		return "", fmt.Errorf("fail to get map params: %w", err)
	}

	fmt.Printf("Invoked tool %s\n", t.Name)

	var out strings.Builder
	ctx := context.Background()
	_, err = t.Client.ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error {
		// The client may run this callback more than once when the
		// transaction is aborted and retried; discard rows collected by an
		// earlier attempt so they are not reported twice.
		out.Reset()

		stmt := spanner.Statement{
			SQL:    t.Statement,
			Params: mapParams,
		}
		iter := txn.Query(ctx, stmt)
		defer iter.Stop()

		for {
			row, err := iter.Next()
			if err == iterator.Done {
				return nil
			}
			if err != nil {
				return fmt.Errorf("unable to parse row: %w", err)
			}
			out.WriteString(row.String())
		}
	})
	if err != nil {
		return "", fmt.Errorf("unable to execute client: %w", err)
	}

	return fmt.Sprintf("Stub tool call for %q! Parameters parsed: %q \n Output: %s", t.Name, params, out.String()), nil
}
// ParseParams parses data into parameter values by passing the tool's
// declared Parameters to tools.ParseParams.
func (t Tool) ParseParams(data map[string]any) (tools.ParamValues, error) {
return tools.ParseParams(t.Parameters, data)
}
// Manifest returns the manifest stored on the tool when it was created.
func (t Tool) Manifest() tools.Manifest {
return t.manifest
}

View File

@@ -0,0 +1,79 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package spanner_test
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/testutils"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/tools/spanner"
"gopkg.in/yaml.v3"
)
// TestParseFromYamlSpanner checks that a spanner-sql tool config unmarshals
// from YAML into the expected spanner.Config value.
func TestParseFromYamlSpanner(t *testing.T) {
tcs := []struct {
desc string
in string
want server.ToolConfigs
}{
{
desc: "basic example",
in: `
tools:
example_tool:
kind: spanner-sql
source: my-pg-instance
description: some description
statement: |
SELECT * FROM SQL_STATEMENT;
parameters:
- name: country
type: string
description: some description
`,
want: server.ToolConfigs{
"example_tool": spanner.Config{
Name: "example_tool",
Kind: spanner.ToolKind,
Source: "my-pg-instance",
Description: "some description",
// The YAML block scalar above keeps its trailing newline.
Statement: "SELECT * FROM SQL_STATEMENT;\n",
Parameters: []tools.Parameter{
tools.NewStringParameter("country", "some description"),
},
},
},
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
// Decode into a wrapper struct so the top-level "tools" key is consumed.
got := struct {
Tools server.ToolConfigs `yaml:"tools"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
t.Fatalf("incorrect parse: diff %v", diff)
}
})
}
}