mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-04-09 03:02:26 -04:00
chore: refactor sources/tools each into their own package (#42)
Moves all of the "source" and "tool" implementations into their own packages. This layout makes it a bit more clear where the implementations are, and seems likely to scale more cleanly as more sources and tools are added.
This commit is contained in:
12
cmd/root.go
12
cmd/root.go
@@ -22,8 +22,6 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/spf13/cobra"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
@@ -62,7 +60,7 @@ func Execute() {
|
||||
type Command struct {
|
||||
*cobra.Command
|
||||
|
||||
cfg server.Config
|
||||
cfg server.ServerConfig
|
||||
tools_file string
|
||||
}
|
||||
|
||||
@@ -88,11 +86,11 @@ func NewCommand() *Command {
|
||||
}
|
||||
|
||||
// parseToolsFile parses the provided yaml into appropriate configs.
|
||||
func parseToolsFile(raw []byte) (sources.Configs, tools.Configs, tools.ToolsetConfigs, error) {
|
||||
func parseToolsFile(raw []byte) (server.SourceConfigs, server.ToolConfigs, server.ToolsetConfigs, error) {
|
||||
tools_file := &struct {
|
||||
Sources sources.Configs `yaml:"sources"`
|
||||
Tools tools.Configs `yaml:"tools"`
|
||||
Toolsets tools.ToolsetConfigs `yaml:"toolsets"`
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
Toolsets server.ToolsetConfigs `yaml:"toolsets"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(raw, tools_file)
|
||||
|
||||
@@ -23,9 +23,10 @@ import (
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
cloudsqlpgsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
cloudsqlpgtool "github.com/googleapis/genai-toolbox/internal/tools/cloudsqlpg"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
@@ -73,12 +74,12 @@ func TestAddrPort(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
args []string
|
||||
want server.Config
|
||||
want server.ServerConfig
|
||||
}{
|
||||
{
|
||||
desc: "default values",
|
||||
args: []string{},
|
||||
want: server.Config{
|
||||
want: server.ServerConfig{
|
||||
Address: "127.0.0.1",
|
||||
Port: 5000,
|
||||
},
|
||||
@@ -86,7 +87,7 @@ func TestAddrPort(t *testing.T) {
|
||||
{
|
||||
desc: "address short",
|
||||
args: []string{"-a", "127.0.1.1"},
|
||||
want: server.Config{
|
||||
want: server.ServerConfig{
|
||||
Address: "127.0.1.1",
|
||||
Port: 5000,
|
||||
},
|
||||
@@ -94,7 +95,7 @@ func TestAddrPort(t *testing.T) {
|
||||
{
|
||||
desc: "address long",
|
||||
args: []string{"--address", "0.0.0.0"},
|
||||
want: server.Config{
|
||||
want: server.ServerConfig{
|
||||
Address: "0.0.0.0",
|
||||
Port: 5000,
|
||||
},
|
||||
@@ -102,7 +103,7 @@ func TestAddrPort(t *testing.T) {
|
||||
{
|
||||
desc: "port short",
|
||||
args: []string{"-p", "5052"},
|
||||
want: server.Config{
|
||||
want: server.ServerConfig{
|
||||
Address: "127.0.0.1",
|
||||
Port: 5052,
|
||||
},
|
||||
@@ -110,7 +111,7 @@ func TestAddrPort(t *testing.T) {
|
||||
{
|
||||
desc: "port long",
|
||||
args: []string{"--port", "5050"},
|
||||
want: server.Config{
|
||||
want: server.ServerConfig{
|
||||
Address: "127.0.0.1",
|
||||
Port: 5050,
|
||||
},
|
||||
@@ -169,9 +170,9 @@ func TestParseToolFile(t *testing.T) {
|
||||
tcs := []struct {
|
||||
description string
|
||||
in string
|
||||
wantSources sources.Configs
|
||||
wantTools tools.Configs
|
||||
wantToolsets tools.ToolsetConfigs
|
||||
wantSources server.SourceConfigs
|
||||
wantTools server.ToolConfigs
|
||||
wantToolsets server.ToolsetConfigs
|
||||
}{
|
||||
{
|
||||
description: "basic example",
|
||||
@@ -198,20 +199,20 @@ func TestParseToolFile(t *testing.T) {
|
||||
example_toolset:
|
||||
- example_tool
|
||||
`,
|
||||
wantSources: sources.Configs{
|
||||
"my-pg-instance": sources.CloudSQLPgConfig{
|
||||
wantSources: server.SourceConfigs{
|
||||
"my-pg-instance": cloudsqlpgsrc.Config{
|
||||
Name: "my-pg-instance",
|
||||
Kind: sources.CloudSQLPgKind,
|
||||
Kind: cloudsqlpgsrc.SourceKind,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Instance: "my-instance",
|
||||
Database: "my_db",
|
||||
},
|
||||
},
|
||||
wantTools: tools.Configs{
|
||||
"example_tool": tools.CloudSQLPgGenericConfig{
|
||||
wantTools: server.ToolConfigs{
|
||||
"example_tool": cloudsqlpgtool.GenericConfig{
|
||||
Name: "example_tool",
|
||||
Kind: tools.CloudSQLPgSQLGenericKind,
|
||||
Kind: cloudsqlpgtool.ToolKind,
|
||||
Source: "my-pg-instance",
|
||||
Description: "some description",
|
||||
Statement: "SELECT * FROM SQL_STATEMENT;\n",
|
||||
@@ -220,7 +221,7 @@ func TestParseToolFile(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
wantToolsets: tools.ToolsetConfigs{
|
||||
wantToolsets: server.ToolsetConfigs{
|
||||
"example_toolset": tools.ToolsetConfig{
|
||||
Name: "example_toolset",
|
||||
ToolNames: []string{"example_tool"},
|
||||
|
||||
@@ -27,12 +27,12 @@ func (t MockTool) ParseParams(data map[string]any) ([]any, error) {
|
||||
return tools.ParseParams(t.Params, data)
|
||||
}
|
||||
|
||||
func (t MockTool) Manifest() tools.ToolManifest {
|
||||
func (t MockTool) Manifest() tools.Manifest {
|
||||
pMs := make([]tools.ParameterManifest, 0, len(t.Params))
|
||||
for _, p := range t.Params {
|
||||
pMs = append(pMs, p.Manifest())
|
||||
}
|
||||
return tools.ToolManifest{Description: t.Description, Parameters: pMs}
|
||||
return tools.Manifest{Description: t.Description, Parameters: pMs}
|
||||
}
|
||||
|
||||
func TestToolsetEndpoint(t *testing.T) {
|
||||
|
||||
@@ -14,11 +14,20 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
alloydbpgsrc "github.com/googleapis/genai-toolbox/internal/sources/alloydbpg"
|
||||
cloudsqlpgsrc "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg"
|
||||
postgressrc "github.com/googleapis/genai-toolbox/internal/sources/postgres"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
alloydbpgtool "github.com/googleapis/genai-toolbox/internal/tools/alloydbpg"
|
||||
cloudsqlpgtool "github.com/googleapis/genai-toolbox/internal/tools/cloudsqlpg"
|
||||
postgrestool "github.com/googleapis/genai-toolbox/internal/tools/postgres"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
type ServerConfig struct {
|
||||
// Server version
|
||||
Version string
|
||||
// Address is the address of the interface the server will listen on.
|
||||
@@ -26,9 +35,126 @@ type Config struct {
|
||||
// Port is the port the server will listen on.
|
||||
Port int
|
||||
// SourceConfigs defines what sources of data are available for tools.
|
||||
SourceConfigs sources.Configs
|
||||
SourceConfigs SourceConfigs
|
||||
// ToolConfigs defines what tools are available.
|
||||
ToolConfigs tools.Configs
|
||||
ToolConfigs ToolConfigs
|
||||
// ToolsetConfigs defines what tools are available.
|
||||
ToolsetConfigs tools.ToolsetConfigs
|
||||
ToolsetConfigs ToolsetConfigs
|
||||
}
|
||||
|
||||
// SourceConfigs is a type used to allow unmarshal of the data source config map
|
||||
type SourceConfigs map[string]sources.SourceConfig
|
||||
|
||||
// validate interface
|
||||
var _ yaml.Unmarshaler = &SourceConfigs{}
|
||||
|
||||
func (c *SourceConfigs) UnmarshalYAML(node *yaml.Node) error {
|
||||
*c = make(SourceConfigs)
|
||||
// Parse the 'kind' fields for each source
|
||||
var raw map[string]yaml.Node
|
||||
if err := node.Decode(&raw); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for name, n := range raw {
|
||||
var k struct {
|
||||
Kind string `yaml:"kind"`
|
||||
}
|
||||
err := n.Decode(&k)
|
||||
if err != nil {
|
||||
return fmt.Errorf("missing 'kind' field for %q", k)
|
||||
}
|
||||
switch k.Kind {
|
||||
case alloydbpgsrc.SourceKind:
|
||||
actual := alloydbpgsrc.Config{Name: name}
|
||||
if err := n.Decode(&actual); err != nil {
|
||||
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
|
||||
}
|
||||
(*c)[name] = actual
|
||||
case cloudsqlpgsrc.SourceKind:
|
||||
actual := cloudsqlpgsrc.Config{Name: name}
|
||||
if err := n.Decode(&actual); err != nil {
|
||||
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
|
||||
}
|
||||
(*c)[name] = actual
|
||||
case postgressrc.SourceKind:
|
||||
actual := postgressrc.Config{Name: name}
|
||||
if err := n.Decode(&actual); err != nil {
|
||||
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
|
||||
}
|
||||
(*c)[name] = actual
|
||||
default:
|
||||
return fmt.Errorf("%q is not a valid kind of data source", k.Kind)
|
||||
}
|
||||
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ToolConfigs is a type used to allow unmarshal of the tool configs
|
||||
type ToolConfigs map[string]tools.Config
|
||||
|
||||
// validate interface
|
||||
var _ yaml.Unmarshaler = &ToolConfigs{}
|
||||
|
||||
func (c *ToolConfigs) UnmarshalYAML(node *yaml.Node) error {
|
||||
*c = make(ToolConfigs)
|
||||
// Parse the 'kind' fields for each source
|
||||
var raw map[string]yaml.Node
|
||||
if err := node.Decode(&raw); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for name, n := range raw {
|
||||
var k struct {
|
||||
Kind string `yaml:"kind"`
|
||||
}
|
||||
err := n.Decode(&k)
|
||||
if err != nil {
|
||||
return fmt.Errorf("missing 'kind' field for %q", name)
|
||||
}
|
||||
switch k.Kind {
|
||||
case alloydbpgtool.ToolKind:
|
||||
actual := alloydbpgtool.GenericConfig{Name: name}
|
||||
if err := n.Decode(&actual); err != nil {
|
||||
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
|
||||
}
|
||||
(*c)[name] = actual
|
||||
case cloudsqlpgtool.ToolKind:
|
||||
actual := cloudsqlpgtool.GenericConfig{Name: name}
|
||||
if err := n.Decode(&actual); err != nil {
|
||||
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
|
||||
}
|
||||
(*c)[name] = actual
|
||||
case postgrestool.ToolKind:
|
||||
actual := postgrestool.GenericConfig{Name: name}
|
||||
if err := n.Decode(&actual); err != nil {
|
||||
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
|
||||
}
|
||||
(*c)[name] = actual
|
||||
default:
|
||||
return fmt.Errorf("%q is not a valid kind of tool", k.Kind)
|
||||
}
|
||||
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type ToolsetConfigs map[string]tools.ToolsetConfig
|
||||
|
||||
// validate interface
|
||||
var _ yaml.Unmarshaler = &ToolsetConfigs{}
|
||||
|
||||
func (c *ToolsetConfigs) UnmarshalYAML(node *yaml.Node) error {
|
||||
*c = make(ToolsetConfigs)
|
||||
|
||||
var raw map[string][]string
|
||||
if err := node.Decode(&raw); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for name, toolList := range raw {
|
||||
(*c)[name] = tools.ToolsetConfig{Name: name, ToolNames: toolList}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -30,7 +30,7 @@ import (
|
||||
|
||||
// Server contains info for running an instance of Toolbox. Should be instantiated with NewServer().
|
||||
type Server struct {
|
||||
conf Config
|
||||
conf ServerConfig
|
||||
root chi.Router
|
||||
|
||||
sources map[string]sources.Source
|
||||
@@ -39,7 +39,7 @@ type Server struct {
|
||||
}
|
||||
|
||||
// NewServer returns a Server object based on provided Config.
|
||||
func NewServer(cfg Config) (*Server, error) {
|
||||
func NewServer(cfg ServerConfig) (*Server, error) {
|
||||
r := chi.NewRouter()
|
||||
r.Use(middleware.Logger)
|
||||
r.Use(middleware.Recoverer)
|
||||
@@ -76,7 +76,7 @@ func NewServer(cfg Config) (*Server, error) {
|
||||
allToolNames = append(allToolNames, name)
|
||||
}
|
||||
if cfg.ToolsetConfigs == nil {
|
||||
cfg.ToolsetConfigs = make(tools.ToolsetConfigs)
|
||||
cfg.ToolsetConfigs = make(ToolsetConfigs)
|
||||
}
|
||||
cfg.ToolsetConfigs[""] = tools.ToolsetConfig{Name: "", ToolNames: allToolNames}
|
||||
// initalize and validate the toolsets
|
||||
|
||||
@@ -43,7 +43,7 @@ func TestServe(t *testing.T) {
|
||||
defer cancel()
|
||||
|
||||
addr, port := "127.0.0.1", 5000
|
||||
cfg := server.Config{
|
||||
cfg := server.ServerConfig{
|
||||
Address: addr,
|
||||
Port: port,
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sources
|
||||
package alloydbpg
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -20,15 +20,16 @@ import (
|
||||
"net"
|
||||
|
||||
"cloud.google.com/go/alloydbconn"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
const AlloyDBPgKind string = "alloydb-postgres"
|
||||
const SourceKind string = "alloydb-postgres"
|
||||
|
||||
// validate interface
|
||||
var _ Config = AlloyDBPgConfig{}
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
type AlloyDBPgConfig struct {
|
||||
type Config struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Project string `yaml:"project"`
|
||||
@@ -40,11 +41,11 @@ type AlloyDBPgConfig struct {
|
||||
Database string `yaml:"database"`
|
||||
}
|
||||
|
||||
func (r AlloyDBPgConfig) sourceKind() string {
|
||||
return AlloyDBPgKind
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (r AlloyDBPgConfig) Initialize() (Source, error) {
|
||||
func (r Config) Initialize() (sources.Source, error) {
|
||||
pool, err := initAlloyDBPgConnectionPool(r.Project, r.Region, r.Cluster, r.Instance, r.User, r.Password, r.Database)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create pool: %w", err)
|
||||
@@ -55,22 +56,26 @@ func (r AlloyDBPgConfig) Initialize() (Source, error) {
|
||||
return nil, fmt.Errorf("unable to connect successfully: %w", err)
|
||||
}
|
||||
|
||||
s := AlloyDBPgSource{
|
||||
s := Source{
|
||||
Name: r.Name,
|
||||
Kind: AlloyDBPgKind,
|
||||
Kind: SourceKind,
|
||||
Pool: pool,
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
var _ Source = AlloyDBPgSource{}
|
||||
var _ sources.Source = Source{}
|
||||
|
||||
type AlloyDBPgSource struct {
|
||||
type Source struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
func (s Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func initAlloyDBPgConnectionPool(project, region, cluster, instance, user, pass, dbname string) (*pgxpool.Pool, error) {
|
||||
// Configure the driver to connect to the database
|
||||
dsn := fmt.Sprintf("user=%s password=%s dbname=%s sslmode=disable", user, pass, dbname)
|
||||
@@ -12,13 +12,15 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sources_test
|
||||
package alloydbpg_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/alloydbpg"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
@@ -27,7 +29,7 @@ func TestParseFromYamlAlloyDBPg(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want sources.Configs
|
||||
want server.SourceConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
@@ -41,10 +43,10 @@ func TestParseFromYamlAlloyDBPg(t *testing.T) {
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
`,
|
||||
want: sources.Configs{
|
||||
"my-pg-instance": sources.AlloyDBPgConfig{
|
||||
want: map[string]sources.SourceConfig{
|
||||
"my-pg-instance": alloydbpg.Config{
|
||||
Name: "my-pg-instance",
|
||||
Kind: sources.AlloyDBPgKind,
|
||||
Kind: alloydbpg.SourceKind,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Cluster: "my-cluster",
|
||||
@@ -57,7 +59,7 @@ func TestParseFromYamlAlloyDBPg(t *testing.T) {
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Sources sources.Configs `yaml:"sources"`
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
@@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sources
|
||||
package cloudsqlpg
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -20,15 +20,16 @@ import (
|
||||
"net"
|
||||
|
||||
"cloud.google.com/go/cloudsqlconn"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
const CloudSQLPgKind string = "cloud-sql-postgres"
|
||||
const SourceKind string = "cloud-sql-postgres"
|
||||
|
||||
// validate interface
|
||||
var _ Config = CloudSQLPgConfig{}
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
type CloudSQLPgConfig struct {
|
||||
type Config struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Project string `yaml:"project"`
|
||||
@@ -39,11 +40,11 @@ type CloudSQLPgConfig struct {
|
||||
Database string `yaml:"database"`
|
||||
}
|
||||
|
||||
func (r CloudSQLPgConfig) sourceKind() string {
|
||||
return CloudSQLPgKind
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (r CloudSQLPgConfig) Initialize() (Source, error) {
|
||||
func (r Config) Initialize() (sources.Source, error) {
|
||||
pool, err := initCloudSQLPgConnectionPool(r.Project, r.Region, r.Instance, r.User, r.Password, r.Database)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create pool: %w", err)
|
||||
@@ -54,22 +55,26 @@ func (r CloudSQLPgConfig) Initialize() (Source, error) {
|
||||
return nil, fmt.Errorf("unable to connect successfully: %w", err)
|
||||
}
|
||||
|
||||
s := CloudSQLPgSource{
|
||||
s := Source{
|
||||
Name: r.Name,
|
||||
Kind: CloudSQLPgKind,
|
||||
Kind: SourceKind,
|
||||
Pool: pool,
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
var _ Source = CloudSQLPgSource{}
|
||||
var _ sources.Source = Source{}
|
||||
|
||||
type CloudSQLPgSource struct {
|
||||
type Source struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
func (s Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func initCloudSQLPgConnectionPool(project, region, instance, user, pass, dbname string) (*pgxpool.Pool, error) {
|
||||
// Configure the driver to connect to the database
|
||||
dsn := fmt.Sprintf("user=%s password=%s dbname=%s sslmode=disable", user, pass, dbname)
|
||||
@@ -12,13 +12,14 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sources_test
|
||||
package cloudsqlpg_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
@@ -27,7 +28,7 @@ func TestParseFromYamlCloudSQLPg(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want sources.Configs
|
||||
want server.SourceConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
@@ -40,10 +41,10 @@ func TestParseFromYamlCloudSQLPg(t *testing.T) {
|
||||
instance: my-instance
|
||||
database: my_db
|
||||
`,
|
||||
want: sources.Configs{
|
||||
"my-pg-instance": sources.CloudSQLPgConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-pg-instance": cloudsqlpg.Config{
|
||||
Name: "my-pg-instance",
|
||||
Kind: sources.CloudSQLPgKind,
|
||||
Kind: cloudsqlpg.SourceKind,
|
||||
Project: "my-project",
|
||||
Region: "my-region",
|
||||
Instance: "my-instance",
|
||||
@@ -55,7 +56,7 @@ func TestParseFromYamlCloudSQLPg(t *testing.T) {
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Sources sources.Configs `yaml:"sources"`
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
@@ -12,21 +12,22 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sources
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
const PostgresKind string = "postgres"
|
||||
const SourceKind string = "postgres"
|
||||
|
||||
// validate interface
|
||||
var _ Config = PostgresConfig{}
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
type PostgresConfig struct {
|
||||
type Config struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Host string `yaml:"host"`
|
||||
@@ -36,11 +37,11 @@ type PostgresConfig struct {
|
||||
Database string `yaml:"database"`
|
||||
}
|
||||
|
||||
func (r PostgresConfig) sourceKind() string {
|
||||
return PostgresKind
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (r PostgresConfig) Initialize() (Source, error) {
|
||||
func (r Config) Initialize() (sources.Source, error) {
|
||||
pool, err := initPostgresConnectionPool(r.Host, r.Port, r.User, r.Password, r.Database)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Unable to create pool: %w", err)
|
||||
@@ -51,22 +52,26 @@ func (r PostgresConfig) Initialize() (Source, error) {
|
||||
return nil, fmt.Errorf("Unable to connect successfully: %w", err)
|
||||
}
|
||||
|
||||
s := PostgresSource{
|
||||
s := Source{
|
||||
Name: r.Name,
|
||||
Kind: PostgresKind,
|
||||
Kind: SourceKind,
|
||||
Pool: pool,
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
var _ Source = PostgresSource{}
|
||||
var _ sources.Source = Source{}
|
||||
|
||||
type PostgresSource struct {
|
||||
type Source struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
func (s Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func initPostgresConnectionPool(host, port, user, pass, dbname string) (*pgxpool.Pool, error) {
|
||||
// urlExample := "postgres:dd//username:password@localhost:5432/database_name"
|
||||
i := fmt.Sprintf("postgres://%s:%s@%s:%s/%s", user, pass, host, port, dbname)
|
||||
@@ -12,13 +12,14 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sources_test
|
||||
package postgres_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/postgres"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
@@ -27,7 +28,7 @@ func TestParseFromYamlPostgres(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want sources.Configs
|
||||
want server.SourceConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
@@ -39,10 +40,10 @@ func TestParseFromYamlPostgres(t *testing.T) {
|
||||
port: 0000
|
||||
database: my_db
|
||||
`,
|
||||
want: sources.Configs{
|
||||
"my-pg-instance": sources.PostgresConfig{
|
||||
want: server.SourceConfigs{
|
||||
"my-pg-instance": postgres.Config{
|
||||
Name: "my-pg-instance",
|
||||
Kind: sources.PostgresKind,
|
||||
Kind: postgres.SourceKind,
|
||||
Host: "my-host",
|
||||
Port: "0000",
|
||||
Database: "my_db",
|
||||
@@ -53,7 +54,7 @@ func TestParseFromYamlPostgres(t *testing.T) {
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Sources sources.Configs `yaml:"sources"`
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
@@ -14,65 +14,13 @@
|
||||
|
||||
package sources
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type Config interface {
|
||||
sourceKind() string
|
||||
// SourceConfig is the interface for configuring a source.
|
||||
type SourceConfig interface {
|
||||
SourceConfigKind() string
|
||||
Initialize() (Source, error)
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ yaml.Unmarshaler = &Configs{}
|
||||
|
||||
// Configs is a type used to allow unmarshal of the data source config map
|
||||
type Configs map[string]Config
|
||||
|
||||
func (c *Configs) UnmarshalYAML(node *yaml.Node) error {
|
||||
*c = make(Configs)
|
||||
// Parse the 'kind' fields for each source
|
||||
var raw map[string]yaml.Node
|
||||
if err := node.Decode(&raw); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for name, n := range raw {
|
||||
var k struct {
|
||||
Kind string `yaml:"kind"`
|
||||
}
|
||||
err := n.Decode(&k)
|
||||
if err != nil {
|
||||
return fmt.Errorf("missing 'kind' field for %q", k)
|
||||
}
|
||||
switch k.Kind {
|
||||
case CloudSQLPgKind:
|
||||
actual := CloudSQLPgConfig{Name: name}
|
||||
if err := n.Decode(&actual); err != nil {
|
||||
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
|
||||
}
|
||||
(*c)[name] = actual
|
||||
case AlloyDBPgKind:
|
||||
actual := AlloyDBPgConfig{Name: name}
|
||||
if err := n.Decode(&actual); err != nil {
|
||||
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
|
||||
}
|
||||
(*c)[name] = actual
|
||||
case PostgresKind:
|
||||
actual := PostgresConfig{Name: name}
|
||||
if err := n.Decode(&actual); err != nil {
|
||||
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
|
||||
}
|
||||
(*c)[name] = actual
|
||||
default:
|
||||
return fmt.Errorf("%q is not a valid kind of data source", k.Kind)
|
||||
}
|
||||
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Source is the interface for the source itself.
|
||||
type Source interface {
|
||||
SourceKind() string
|
||||
}
|
||||
|
||||
@@ -1,73 +0,0 @@
|
||||
// Copyright 2024 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package tools
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
)
|
||||
|
||||
const AlloyDBPgSQLGenericKind string = "alloydb-postgres-generic"
|
||||
|
||||
// validate interface
|
||||
var _ Config = AlloyDBPgGenericConfig{}
|
||||
|
||||
type AlloyDBPgGenericConfig struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Source string `yaml:"source"`
|
||||
Description string `yaml:"description"`
|
||||
Statement string `yaml:"statement"`
|
||||
Parameters Parameters `yaml:"parameters"`
|
||||
}
|
||||
|
||||
func (cfg AlloyDBPgGenericConfig) toolKind() string {
|
||||
return AlloyDBPgSQLGenericKind
|
||||
}
|
||||
|
||||
func (cfg AlloyDBPgGenericConfig) Initialize(srcs map[string]sources.Source) (Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is the right kind
|
||||
s, ok := rawS.(sources.AlloyDBPgSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("sources for %q tools must be of kind %q", AlloyDBPgSQLGenericKind, sources.AlloyDBPgKind)
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
t := AlloyDBPgGenericTool{
|
||||
PostgresGenericTool: PostgresGenericTool{
|
||||
Name: cfg.Name,
|
||||
Kind: AlloyDBPgSQLGenericKind,
|
||||
Pool: s.Pool,
|
||||
Statement: cfg.Statement,
|
||||
Parameters: cfg.Parameters,
|
||||
manifest: ToolManifest{cfg.Description, generateManifests(cfg.Parameters)},
|
||||
},
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ Tool = AlloyDBPgGenericTool{}
|
||||
|
||||
type AlloyDBPgGenericTool struct {
|
||||
PostgresGenericTool
|
||||
}
|
||||
69
internal/tools/alloydbpg/alloydb_pg.go
Normal file
69
internal/tools/alloydbpg/alloydb_pg.go
Normal file
@@ -0,0 +1,69 @@
|
||||
// Copyright 2024 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package alloydbpg
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/alloydbpg"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/postgres"
|
||||
)
|
||||
|
||||
const ToolKind string = "alloydb-postgres-generic"
|
||||
|
||||
// validate interface
|
||||
var _ tools.Config = GenericConfig{}
|
||||
|
||||
type GenericConfig struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Source string `yaml:"source"`
|
||||
Description string `yaml:"description"`
|
||||
Statement string `yaml:"statement"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
}
|
||||
|
||||
func (cfg GenericConfig) ToolKind() string {
|
||||
return ToolKind
|
||||
}
|
||||
|
||||
func (cfg GenericConfig) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is the right kind
|
||||
s, ok := rawS.(alloydbpg.Source)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("sources for %q tools must be of kind %q", ToolKind, alloydbpg.SourceKind)
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
t := GenericTool{
|
||||
GenericTool: postgres.NewGenericTool(cfg.Name, cfg.Statement, cfg.Description, s.Pool, cfg.Parameters),
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.Tool = GenericTool{}
|
||||
|
||||
type GenericTool struct {
|
||||
postgres.GenericTool
|
||||
}
|
||||
@@ -12,14 +12,16 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package tools_test
|
||||
package alloydbpg_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/alloydbpg"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
@@ -27,7 +29,7 @@ func TestParseFromYamlAlloyDBPg(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want tools.Configs
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
@@ -44,10 +46,10 @@ func TestParseFromYamlAlloyDBPg(t *testing.T) {
|
||||
type: string
|
||||
description: some description
|
||||
`,
|
||||
want: tools.Configs{
|
||||
"example_tool": tools.AlloyDBPgGenericConfig{
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": alloydbpg.GenericConfig{
|
||||
Name: "example_tool",
|
||||
Kind: tools.AlloyDBPgSQLGenericKind,
|
||||
Kind: alloydbpg.ToolKind,
|
||||
Source: "my-pg-instance",
|
||||
Description: "some description",
|
||||
Statement: "SELECT * FROM SQL_STATEMENT;\n",
|
||||
@@ -61,7 +63,7 @@ func TestParseFromYamlAlloyDBPg(t *testing.T) {
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools tools.Configs `yaml:"tools"`
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
@@ -1,73 +0,0 @@
|
||||
// Copyright 2024 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package tools
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
)
|
||||
|
||||
const CloudSQLPgSQLGenericKind string = "cloud-sql-postgres-generic"
|
||||
|
||||
// validate interface
|
||||
var _ Config = CloudSQLPgGenericConfig{}
|
||||
|
||||
type CloudSQLPgGenericConfig struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Source string `yaml:"source"`
|
||||
Description string `yaml:"description"`
|
||||
Statement string `yaml:"statement"`
|
||||
Parameters Parameters `yaml:"parameters"`
|
||||
}
|
||||
|
||||
func (cfg CloudSQLPgGenericConfig) toolKind() string {
|
||||
return CloudSQLPgSQLGenericKind
|
||||
}
|
||||
|
||||
func (cfg CloudSQLPgGenericConfig) Initialize(srcs map[string]sources.Source) (Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is the right kind
|
||||
s, ok := rawS.(sources.CloudSQLPgSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("sources for %q tools must be of kind %q", CloudSQLPgSQLGenericKind, sources.CloudSQLPgKind)
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
t := CloudSQLPgGenericTool{
|
||||
PostgresGenericTool: PostgresGenericTool{
|
||||
Name: cfg.Name,
|
||||
Kind: CloudSQLPgSQLGenericKind,
|
||||
Pool: s.Pool,
|
||||
Statement: cfg.Statement,
|
||||
Parameters: cfg.Parameters,
|
||||
manifest: ToolManifest{cfg.Description, generateManifests(cfg.Parameters)},
|
||||
},
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ Tool = CloudSQLPgGenericTool{}
|
||||
|
||||
type CloudSQLPgGenericTool struct {
|
||||
PostgresGenericTool
|
||||
}
|
||||
69
internal/tools/cloudsqlpg/cloud_sql_pg.go
Normal file
69
internal/tools/cloudsqlpg/cloud_sql_pg.go
Normal file
@@ -0,0 +1,69 @@
|
||||
// Copyright 2024 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cloudsqlpg
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/postgres"
|
||||
)
|
||||
|
||||
const ToolKind string = "cloud-sql-postgres-generic"
|
||||
|
||||
// validate interface
|
||||
var _ tools.Config = GenericConfig{}
|
||||
|
||||
type GenericConfig struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Source string `yaml:"source"`
|
||||
Description string `yaml:"description"`
|
||||
Statement string `yaml:"statement"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
}
|
||||
|
||||
func (cfg GenericConfig) ToolKind() string {
|
||||
return ToolKind
|
||||
}
|
||||
|
||||
func (cfg GenericConfig) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is the right kind
|
||||
s, ok := rawS.(cloudsqlpg.Source)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("sources for %q tools must be of kind %q", ToolKind, cloudsqlpg.SourceKind)
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
t := GenericTool{
|
||||
GenericTool: postgres.NewGenericTool(cfg.Name, cfg.Statement, cfg.Description, s.Pool, cfg.Parameters),
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.Tool = GenericTool{}
|
||||
|
||||
type GenericTool struct {
|
||||
postgres.GenericTool
|
||||
}
|
||||
@@ -12,14 +12,16 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package tools_test
|
||||
package cloudsqlpg_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/cloudsqlpg"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
@@ -27,7 +29,7 @@ func TestParseFromYamlCloudSQLPg(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want tools.Configs
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
@@ -44,10 +46,10 @@ func TestParseFromYamlCloudSQLPg(t *testing.T) {
|
||||
type: string
|
||||
description: some description
|
||||
`,
|
||||
want: tools.Configs{
|
||||
"example_tool": tools.CloudSQLPgGenericConfig{
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": cloudsqlpg.GenericConfig{
|
||||
Name: "example_tool",
|
||||
Kind: tools.CloudSQLPgSQLGenericKind,
|
||||
Kind: cloudsqlpg.ToolKind,
|
||||
Source: "my-pg-instance",
|
||||
Description: "some description",
|
||||
Statement: "SELECT * FROM SQL_STATEMENT;\n",
|
||||
@@ -61,7 +63,7 @@ func TestParseFromYamlCloudSQLPg(t *testing.T) {
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools tools.Configs `yaml:"tools"`
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
@@ -115,7 +115,7 @@ func parseFromYamlNode(node *yaml.Node) (Parameter, error) {
|
||||
return nil, fmt.Errorf("%q is not valid type for a parameter!", p.Type)
|
||||
}
|
||||
|
||||
func generateManifests(ps []Parameter) []ParameterManifest {
|
||||
func (ps Parameters) Manifest() []ParameterManifest {
|
||||
rtn := make([]ParameterManifest, 0, len(ps))
|
||||
for _, p := range ps {
|
||||
rtn = append(rtn, p.Manifest())
|
||||
|
||||
@@ -1,106 +0,0 @@
|
||||
// Copyright 2024 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
const PostgresSQLGenericKind string = "postgres-generic"
|
||||
|
||||
// validate interface
|
||||
var _ Config = PostgresGenericConfig{}
|
||||
|
||||
type PostgresGenericConfig struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Source string `yaml:"source"`
|
||||
Description string `yaml:"description"`
|
||||
Statement string `yaml:"statement"`
|
||||
Parameters Parameters `yaml:"parameters"`
|
||||
}
|
||||
|
||||
func (cfg PostgresGenericConfig) toolKind() string {
|
||||
return PostgresSQLGenericKind
|
||||
}
|
||||
|
||||
func (cfg PostgresGenericConfig) Initialize(srcs map[string]sources.Source) (Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is the right kind
|
||||
s, ok := rawS.(sources.PostgresSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("sources for %q tools must be of kind %q", PostgresSQLGenericKind, sources.PostgresKind)
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
t := PostgresGenericTool{
|
||||
Name: cfg.Name,
|
||||
Kind: PostgresSQLGenericKind,
|
||||
Pool: s.Pool,
|
||||
Statement: cfg.Statement,
|
||||
Parameters: cfg.Parameters,
|
||||
manifest: ToolManifest{cfg.Description, generateManifests(cfg.Parameters)},
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ Tool = PostgresGenericTool{}
|
||||
|
||||
type PostgresGenericTool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Pool *pgxpool.Pool
|
||||
Statement string
|
||||
Parameters Parameters `yaml:"parameters"`
|
||||
manifest ToolManifest
|
||||
}
|
||||
|
||||
func (t PostgresGenericTool) Invoke(params []any) (string, error) {
|
||||
fmt.Printf("Invoked tool %s\n", t.Name)
|
||||
results, err := t.Pool.Query(context.Background(), t.Statement, params...)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
|
||||
var out strings.Builder
|
||||
for results.Next() {
|
||||
v, err := results.Values()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("unable to parse row: %w", err)
|
||||
}
|
||||
out.WriteString(fmt.Sprintf("%s", v))
|
||||
}
|
||||
|
||||
return fmt.Sprintf("Stub tool call for %q! Parameters parsed: %q \n Output: %s", t.Name, params, out.String()), nil
|
||||
}
|
||||
|
||||
func (t PostgresGenericTool) ParseParams(data map[string]any) ([]any, error) {
|
||||
return ParseParams(t.Parameters, data)
|
||||
}
|
||||
|
||||
func (t PostgresGenericTool) Manifest() ToolManifest {
|
||||
return t.manifest
|
||||
}
|
||||
113
internal/tools/postgres/generic.go
Normal file
113
internal/tools/postgres/generic.go
Normal file
@@ -0,0 +1,113 @@
|
||||
// Copyright 2024 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/postgres"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
const ToolKind string = "postgres-generic"
|
||||
|
||||
// validate interface
|
||||
var _ tools.Config = GenericConfig{}
|
||||
|
||||
type GenericConfig struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Source string `yaml:"source"`
|
||||
Description string `yaml:"description"`
|
||||
Statement string `yaml:"statement"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
}
|
||||
|
||||
func (cfg GenericConfig) ToolKind() string {
|
||||
return ToolKind
|
||||
}
|
||||
|
||||
func (cfg GenericConfig) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is the right kind
|
||||
s, ok := rawS.(postgres.Source)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("sources for %q tools must be of kind %q", ToolKind, postgres.SourceKind)
|
||||
}
|
||||
|
||||
// finish tool setup
|
||||
t := NewGenericTool(cfg.Name, cfg.Statement, cfg.Description, s.Pool, cfg.Parameters)
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func NewGenericTool(name, stmt, desc string, pool *pgxpool.Pool, parameters tools.Parameters) GenericTool {
|
||||
return GenericTool{
|
||||
Name: name,
|
||||
Kind: ToolKind,
|
||||
Statement: stmt,
|
||||
Pool: pool,
|
||||
manifest: tools.Manifest{Description: desc, Parameters: parameters.Manifest()},
|
||||
Parameters: parameters,
|
||||
}
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.Tool = GenericTool{}
|
||||
|
||||
type GenericTool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Parameters tools.Parameters `yaml:"parameters"`
|
||||
|
||||
Pool *pgxpool.Pool
|
||||
Statement string
|
||||
manifest tools.Manifest
|
||||
}
|
||||
|
||||
func (t GenericTool) Invoke(params []any) (string, error) {
|
||||
fmt.Printf("Invoked tool %s\n", t.Name)
|
||||
results, err := t.Pool.Query(context.Background(), t.Statement, params...)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
|
||||
var out strings.Builder
|
||||
for results.Next() {
|
||||
v, err := results.Values()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("unable to parse row: %w", err)
|
||||
}
|
||||
out.WriteString(fmt.Sprintf("%s", v))
|
||||
}
|
||||
|
||||
return fmt.Sprintf("Stub tool call for %q! Parameters parsed: %q \n Output: %s", t.Name, params, out.String()), nil
|
||||
}
|
||||
|
||||
func (t GenericTool) ParseParams(data map[string]any) ([]any, error) {
|
||||
return tools.ParseParams(t.Parameters, data)
|
||||
}
|
||||
|
||||
func (t GenericTool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
@@ -12,14 +12,16 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package tools_test
|
||||
package postgres_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/postgres"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
@@ -27,7 +29,7 @@ func TestParseFromYamlPostgres(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want tools.Configs
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
@@ -44,10 +46,10 @@ func TestParseFromYamlPostgres(t *testing.T) {
|
||||
type: string
|
||||
description: some description
|
||||
`,
|
||||
want: tools.Configs{
|
||||
"example_tool": tools.PostgresGenericConfig{
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": postgres.GenericConfig{
|
||||
Name: "example_tool",
|
||||
Kind: tools.PostgresSQLGenericKind,
|
||||
Kind: postgres.ToolKind,
|
||||
Source: "my-pg-instance",
|
||||
Description: "some description",
|
||||
Statement: "SELECT * FROM SQL_STATEMENT;\n",
|
||||
@@ -61,7 +63,7 @@ func TestParseFromYamlPostgres(t *testing.T) {
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools tools.Configs `yaml:"tools"`
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
@@ -15,73 +15,22 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type Config interface {
|
||||
toolKind() string
|
||||
ToolKind() string
|
||||
Initialize(map[string]sources.Source) (Tool, error)
|
||||
}
|
||||
|
||||
// SourceConfigs is a type used to allow unmarshal of the data source config map
|
||||
type Configs map[string]Config
|
||||
|
||||
// validate interface
|
||||
var _ yaml.Unmarshaler = &Configs{}
|
||||
|
||||
func (c *Configs) UnmarshalYAML(node *yaml.Node) error {
|
||||
*c = make(Configs)
|
||||
// Parse the 'kind' fields for each source
|
||||
var raw map[string]yaml.Node
|
||||
if err := node.Decode(&raw); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for name, n := range raw {
|
||||
var k struct {
|
||||
Kind string `yaml:"kind"`
|
||||
}
|
||||
err := n.Decode(&k)
|
||||
if err != nil {
|
||||
return fmt.Errorf("missing 'kind' field for %q", name)
|
||||
}
|
||||
switch k.Kind {
|
||||
case CloudSQLPgSQLGenericKind:
|
||||
actual := CloudSQLPgGenericConfig{Name: name}
|
||||
if err := n.Decode(&actual); err != nil {
|
||||
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
|
||||
}
|
||||
(*c)[name] = actual
|
||||
case AlloyDBPgSQLGenericKind:
|
||||
actual := AlloyDBPgGenericConfig{Name: name}
|
||||
if err := n.Decode(&actual); err != nil {
|
||||
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
|
||||
}
|
||||
(*c)[name] = actual
|
||||
case PostgresSQLGenericKind:
|
||||
actual := PostgresGenericConfig{Name: name}
|
||||
if err := n.Decode(&actual); err != nil {
|
||||
return fmt.Errorf("unable to parse as %q: %w", k.Kind, err)
|
||||
}
|
||||
(*c)[name] = actual
|
||||
default:
|
||||
return fmt.Errorf("%q is not a valid kind of tool", k.Kind)
|
||||
}
|
||||
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type Tool interface {
|
||||
Invoke([]any) (string, error)
|
||||
ParseParams(data map[string]any) ([]any, error)
|
||||
Manifest() ToolManifest
|
||||
Manifest() Manifest
|
||||
}
|
||||
|
||||
type ToolManifest struct {
|
||||
// Manifest is the representation of tools sent to Client SDKs.
|
||||
type Manifest struct {
|
||||
Description string `json:"description"`
|
||||
Parameters []ParameterManifest `json:"parameters"`
|
||||
}
|
||||
|
||||
@@ -16,8 +16,6 @@ package tools
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type ToolsetConfig struct {
|
||||
@@ -25,8 +23,6 @@ type ToolsetConfig struct {
|
||||
ToolNames []string `yaml:",inline"`
|
||||
}
|
||||
|
||||
type ToolsetConfigs map[string]ToolsetConfig
|
||||
|
||||
type Toolset struct {
|
||||
Name string `yaml:"name"`
|
||||
Tools []*Tool `yaml:",inline"`
|
||||
@@ -34,25 +30,8 @@ type Toolset struct {
|
||||
}
|
||||
|
||||
type ToolsetManifest struct {
|
||||
ServerVersion string `json:"serverVersion"`
|
||||
ToolsManifest map[string]ToolManifest `json:"tools"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ yaml.Unmarshaler = &ToolsetConfigs{}
|
||||
|
||||
func (c *ToolsetConfigs) UnmarshalYAML(node *yaml.Node) error {
|
||||
*c = make(ToolsetConfigs)
|
||||
|
||||
var raw map[string][]string
|
||||
if err := node.Decode(&raw); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for name, toolList := range raw {
|
||||
(*c)[name] = ToolsetConfig{Name: name, ToolNames: toolList}
|
||||
}
|
||||
return nil
|
||||
ServerVersion string `json:"serverVersion"`
|
||||
ToolsManifest map[string]Manifest `json:"tools"`
|
||||
}
|
||||
|
||||
func (t ToolsetConfig) Initialize(serverVersion string, toolsMap map[string]Tool) (Toolset, error) {
|
||||
@@ -66,7 +45,7 @@ func (t ToolsetConfig) Initialize(serverVersion string, toolsMap map[string]Tool
|
||||
toolset.Tools = make([]*Tool, len(t.ToolNames))
|
||||
toolset.Manifest = ToolsetManifest{
|
||||
ServerVersion: serverVersion,
|
||||
ToolsManifest: make(map[string]ToolManifest),
|
||||
ToolsManifest: make(map[string]Manifest),
|
||||
}
|
||||
for _, toolName := range t.ToolNames {
|
||||
tool, ok := toolsMap[toolName]
|
||||
|
||||
Reference in New Issue
Block a user