mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-10 07:58:12 -05:00
style: run linter (#1518)
This commit is contained in:
@@ -40,8 +40,8 @@ func TestParseFromYamlAlloyDBAdmin(t *testing.T) {
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
"my-alloydb-admin-instance": alloydbadmin.Config{
|
||||
Name: "my-alloydb-admin-instance",
|
||||
Kind: alloydbadmin.SourceKind,
|
||||
Name: "my-alloydb-admin-instance",
|
||||
Kind: alloydbadmin.SourceKind,
|
||||
UseClientOAuth: false,
|
||||
},
|
||||
},
|
||||
@@ -56,8 +56,8 @@ func TestParseFromYamlAlloyDBAdmin(t *testing.T) {
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
"my-alloydb-admin-instance": alloydbadmin.Config{
|
||||
Name: "my-alloydb-admin-instance",
|
||||
Kind: alloydbadmin.SourceKind,
|
||||
Name: "my-alloydb-admin-instance",
|
||||
Kind: alloydbadmin.SourceKind,
|
||||
UseClientOAuth: true,
|
||||
},
|
||||
},
|
||||
|
||||
@@ -145,17 +145,17 @@ var _ sources.Source = &Source{}
|
||||
|
||||
type Source struct {
|
||||
// BigQuery Google SQL struct with client
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Project string
|
||||
Location string
|
||||
Client *bigqueryapi.Client
|
||||
RestService *bigqueryrestapi.Service
|
||||
TokenSource oauth2.TokenSource
|
||||
MaxQueryResultRows int
|
||||
ClientCreator BigqueryClientCreator
|
||||
AllowedDatasets map[string]struct{}
|
||||
UseClientOAuth bool
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Project string
|
||||
Location string
|
||||
Client *bigqueryapi.Client
|
||||
RestService *bigqueryrestapi.Service
|
||||
TokenSource oauth2.TokenSource
|
||||
MaxQueryResultRows int
|
||||
ClientCreator BigqueryClientCreator
|
||||
AllowedDatasets map[string]struct{}
|
||||
UseClientOAuth bool
|
||||
makeDataplexCatalogClient func() (*dataplexapi.CatalogClient, DataplexClientCreator, error)
|
||||
}
|
||||
|
||||
@@ -405,4 +405,4 @@ func newDataplexClientCreator(
|
||||
return func(tokenString string) (*dataplexapi.CatalogClient, error) {
|
||||
return initDataplexConnectionWithOAuthToken(ctx, project, userAgent, tokenString)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -121,7 +121,7 @@ var _ sources.Source = &Source{}
|
||||
type Source struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
BaseURL string `yaml:"baseUrl"`
|
||||
BaseURL string `yaml:"baseUrl"`
|
||||
Client *http.Client
|
||||
UserAgent string
|
||||
UseClientOAuth bool
|
||||
|
||||
@@ -15,39 +15,39 @@
|
||||
package alloydbcreatecluster
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"google.golang.org/api/alloydb/v1"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"google.golang.org/api/alloydb/v1"
|
||||
)
|
||||
|
||||
const kind string = "alloydb-create-cluster"
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
||||
}
|
||||
if !tools.Register(kind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
// Configuration for the create-cluster tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
@@ -55,148 +55,148 @@ var _ tools.ToolConfig = Config{}
|
||||
|
||||
// ToolConfigKind returns the kind of the tool.
|
||||
func (cfg Config) ToolConfigKind() string {
|
||||
return kind
|
||||
return kind
|
||||
}
|
||||
|
||||
// Initialize initializes the tool from the configuration.
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("source %q not found", cfg.Source)
|
||||
}
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("source %q not found", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(*alloydbadmin.Source)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `alloydb-admin`", kind)
|
||||
}
|
||||
s, ok := rawS.(*alloydbadmin.Source)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `alloydb-admin`", kind)
|
||||
}
|
||||
|
||||
allParameters := tools.Parameters{
|
||||
tools.NewStringParameter("project", "The GCP project ID."),
|
||||
tools.NewStringParameterWithDefault("location", "us-central1", "The location to create the cluster in. The default value is us-central1. If quota is exhausted then use other regions."),
|
||||
tools.NewStringParameter("cluster", "A unique ID for the AlloyDB cluster."),
|
||||
tools.NewStringParameter("password", "A secure password for the initial user."),
|
||||
tools.NewStringParameterWithDefault("network", "default", "The name of the VPC network to connect the cluster to (e.g., 'default')."),
|
||||
tools.NewStringParameterWithDefault("user", "postgres", "The name for the initial superuser. Defaults to 'postgres' if not provided."),
|
||||
}
|
||||
paramManifest := allParameters.Manifest()
|
||||
allParameters := tools.Parameters{
|
||||
tools.NewStringParameter("project", "The GCP project ID."),
|
||||
tools.NewStringParameterWithDefault("location", "us-central1", "The location to create the cluster in. The default value is us-central1. If quota is exhausted then use other regions."),
|
||||
tools.NewStringParameter("cluster", "A unique ID for the AlloyDB cluster."),
|
||||
tools.NewStringParameter("password", "A secure password for the initial user."),
|
||||
tools.NewStringParameterWithDefault("network", "default", "The name of the VPC network to connect the cluster to (e.g., 'default')."),
|
||||
tools.NewStringParameterWithDefault("user", "postgres", "The name for the initial superuser. Defaults to 'postgres' if not provided."),
|
||||
}
|
||||
paramManifest := allParameters.Manifest()
|
||||
|
||||
inputSchema := allParameters.McpManifest()
|
||||
inputSchema.Required = []string{"project", "cluster", "password"}
|
||||
inputSchema := allParameters.McpManifest()
|
||||
inputSchema.Required = []string{"project", "cluster", "password"}
|
||||
|
||||
description := cfg.Description
|
||||
description := cfg.Description
|
||||
if description == "" {
|
||||
description = "Creates a new AlloyDB cluster. This is a long-running operation, but the API call returns quickly. This will return operation id to be used by get operations tool. Take all parameters from user in one go."
|
||||
}
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
Description: description,
|
||||
InputSchema: inputSchema,
|
||||
}
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
Description: description,
|
||||
InputSchema: inputSchema,
|
||||
}
|
||||
|
||||
return Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}, nil
|
||||
return Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Tool represents the create-cluster tool.
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Description string `yaml:"description"`
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Description string `yaml:"description"`
|
||||
|
||||
Source *alloydbadmin.Source
|
||||
AllParams tools.Parameters `yaml:"allParams"`
|
||||
Source *alloydbadmin.Source
|
||||
AllParams tools.Parameters `yaml:"allParams"`
|
||||
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
paramsMap := params.AsMap()
|
||||
project, ok := paramsMap["project"].(string)
|
||||
if !ok || project == "" {
|
||||
return nil, fmt.Errorf("invalid or missing 'project' parameter; expected a non-empty string")
|
||||
}
|
||||
paramsMap := params.AsMap()
|
||||
project, ok := paramsMap["project"].(string)
|
||||
if !ok || project == "" {
|
||||
return nil, fmt.Errorf("invalid or missing 'project' parameter; expected a non-empty string")
|
||||
}
|
||||
|
||||
location, ok := paramsMap["location"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid 'location' parameter; expected a string")
|
||||
}
|
||||
location, ok := paramsMap["location"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid 'location' parameter; expected a string")
|
||||
}
|
||||
|
||||
clusterID, ok := paramsMap["cluster"].(string)
|
||||
if !ok || clusterID == "" {
|
||||
return nil, fmt.Errorf("invalid or missing 'cluster' parameter; expected a non-empty string")
|
||||
}
|
||||
clusterID, ok := paramsMap["cluster"].(string)
|
||||
if !ok || clusterID == "" {
|
||||
return nil, fmt.Errorf("invalid or missing 'cluster' parameter; expected a non-empty string")
|
||||
}
|
||||
|
||||
password, ok := paramsMap["password"].(string)
|
||||
if !ok || password == "" {
|
||||
return nil, fmt.Errorf("invalid or missing 'password' parameter; expected a non-empty string")
|
||||
}
|
||||
password, ok := paramsMap["password"].(string)
|
||||
if !ok || password == "" {
|
||||
return nil, fmt.Errorf("invalid or missing 'password' parameter; expected a non-empty string")
|
||||
}
|
||||
|
||||
network, ok := paramsMap["network"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid 'network' parameter; expected a string")
|
||||
}
|
||||
network, ok := paramsMap["network"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid 'network' parameter; expected a string")
|
||||
}
|
||||
|
||||
user, ok := paramsMap["user"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid 'user' parameter; expected a string")
|
||||
}
|
||||
user, ok := paramsMap["user"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid 'user' parameter; expected a string")
|
||||
}
|
||||
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
urlString := fmt.Sprintf("projects/%s/locations/%s", project, location)
|
||||
urlString := fmt.Sprintf("projects/%s/locations/%s", project, location)
|
||||
|
||||
// Build the request body using the type-safe Cluster struct.
|
||||
clusterBody := &alloydb.Cluster{
|
||||
NetworkConfig: &alloydb.NetworkConfig{
|
||||
Network: fmt.Sprintf("projects/%s/global/networks/%s", project, network),
|
||||
},
|
||||
InitialUser: &alloydb.UserPassword{
|
||||
User: user,
|
||||
Password: password,
|
||||
},
|
||||
}
|
||||
// Build the request body using the type-safe Cluster struct.
|
||||
clusterBody := &alloydb.Cluster{
|
||||
NetworkConfig: &alloydb.NetworkConfig{
|
||||
Network: fmt.Sprintf("projects/%s/global/networks/%s", project, network),
|
||||
},
|
||||
InitialUser: &alloydb.UserPassword{
|
||||
User: user,
|
||||
Password: password,
|
||||
},
|
||||
}
|
||||
|
||||
// The Create API returns a long-running operation.
|
||||
resp, err := service.Projects.Locations.Clusters.Create(urlString, clusterBody).ClusterId(clusterID).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating AlloyDB cluster: %w", err)
|
||||
}
|
||||
// The Create API returns a long-running operation.
|
||||
resp, err := service.Projects.Locations.Clusters.Create(urlString, clusterBody).ClusterId(clusterID).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating AlloyDB cluster: %w", err)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// ParseParams parses the parameters for the tool.
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
return tools.ParseParams(t.AllParams, data, claims)
|
||||
return tools.ParseParams(t.AllParams, data, claims)
|
||||
}
|
||||
|
||||
// Manifest returns the tool's manifest.
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
// McpManifest returns the tool's MCP manifest.
|
||||
func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
// Authorized checks if the tool is authorized.
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization() bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
return t.Source.UseClientAuthorization()
|
||||
}
|
||||
|
||||
@@ -15,47 +15,47 @@
|
||||
package alloydbcreatecluster_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
alloydbcreatecluster "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydbcreatecluster"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
alloydbcreatecluster "github.com/googleapis/genai-toolbox/internal/tools/alloydb/alloydbcreatecluster"
|
||||
)
|
||||
|
||||
func TestParseFromYaml(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
create-my-cluster:
|
||||
kind: alloydb-create-cluster
|
||||
source: my-alloydb-admin-source
|
||||
description: some description
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"create-my-cluster": alloydbcreatecluster.Config{
|
||||
Name: "create-my-cluster",
|
||||
Kind: "alloydb-create-cluster",
|
||||
Source: "my-alloydb-admin-source",
|
||||
Description: "some description",
|
||||
AuthRequired: []string{},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "with auth required",
|
||||
in: `
|
||||
want: server.ToolConfigs{
|
||||
"create-my-cluster": alloydbcreatecluster.Config{
|
||||
Name: "create-my-cluster",
|
||||
Kind: "alloydb-create-cluster",
|
||||
Source: "my-alloydb-admin-source",
|
||||
Description: "some description",
|
||||
AuthRequired: []string{},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "with auth required",
|
||||
in: `
|
||||
tools:
|
||||
create-my-cluster-auth:
|
||||
kind: alloydb-create-cluster
|
||||
@@ -65,30 +65,30 @@ func TestParseFromYaml(t *testing.T) {
|
||||
- my-google-auth-service
|
||||
- other-auth-service
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"create-my-cluster-auth": alloydbcreatecluster.Config{
|
||||
Name: "create-my-cluster-auth",
|
||||
Kind: "alloydb-create-cluster",
|
||||
Source: "my-alloydb-admin-source",
|
||||
Description: "some description",
|
||||
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
want: server.ToolConfigs{
|
||||
"create-my-cluster-auth": alloydbcreatecluster.Config{
|
||||
Name: "create-my-cluster-auth",
|
||||
Kind: "alloydb-create-cluster",
|
||||
Source: "my-alloydb-admin-source",
|
||||
Description: "some description",
|
||||
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,13 +16,13 @@ package alloydbcreateinstance
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"fmt"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"google.golang.org/api/alloydb/v1"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"google.golang.org/api/alloydb/v1"
|
||||
)
|
||||
|
||||
const kind string = "alloydb-create-instance"
|
||||
@@ -43,11 +43,11 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
|
||||
// Configuration for the create-instance tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
@@ -61,14 +61,14 @@ func (cfg Config) ToolConfigKind() string {
|
||||
// Initialize initializes the tool from the configuration.
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("source %q not found", cfg.Source)
|
||||
}
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("source %q not found", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(*alloydbadmin.Source)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `alloydb-admin`", kind)
|
||||
}
|
||||
s, ok := rawS.(*alloydbadmin.Source)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `alloydb-admin`", kind)
|
||||
}
|
||||
|
||||
allParameters := tools.Parameters{
|
||||
tools.NewStringParameter("project", "The GCP project ID."),
|
||||
@@ -96,50 +96,50 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
}
|
||||
|
||||
return Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}, nil
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Tool represents the create-instance tool.
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Description string `yaml:"description"`
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Description string `yaml:"description"`
|
||||
|
||||
Source *alloydbadmin.Source
|
||||
AllParams tools.Parameters `yaml:"allParams"`
|
||||
Source *alloydbadmin.Source
|
||||
AllParams tools.Parameters `yaml:"allParams"`
|
||||
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
paramsMap := params.AsMap()
|
||||
project, ok := paramsMap["project"].(string)
|
||||
if !ok || project == "" {
|
||||
return nil, fmt.Errorf("invalid or missing 'project' parameter; expected a non-empty string")
|
||||
}
|
||||
if !ok || project == "" {
|
||||
return nil, fmt.Errorf("invalid or missing 'project' parameter; expected a non-empty string")
|
||||
}
|
||||
|
||||
location, ok := paramsMap["location"].(string)
|
||||
if !ok || location == "" {
|
||||
return nil, fmt.Errorf("invalid or missing 'location' parameter; expected a non-empty string")
|
||||
}
|
||||
if !ok || location == "" {
|
||||
return nil, fmt.Errorf("invalid or missing 'location' parameter; expected a non-empty string")
|
||||
}
|
||||
|
||||
cluster, ok := paramsMap["cluster"].(string)
|
||||
if !ok || cluster == "" {
|
||||
return nil, fmt.Errorf("invalid or missing 'cluster' parameter; expected a non-empty string")
|
||||
}
|
||||
cluster, ok := paramsMap["cluster"].(string)
|
||||
if !ok || cluster == "" {
|
||||
return nil, fmt.Errorf("invalid or missing 'cluster' parameter; expected a non-empty string")
|
||||
}
|
||||
|
||||
instanceID, ok := paramsMap["instance"].(string)
|
||||
if !ok || instanceID == "" {
|
||||
return nil, fmt.Errorf("invalid or missing 'instance' parameter; expected a non-empty string")
|
||||
}
|
||||
instanceID, ok := paramsMap["instance"].(string)
|
||||
if !ok || instanceID == "" {
|
||||
return nil, fmt.Errorf("invalid or missing 'instance' parameter; expected a non-empty string")
|
||||
}
|
||||
|
||||
instanceType, ok := paramsMap["instanceType"].(string)
|
||||
if !ok || (instanceType != "READ_POOL" && instanceType != "PRIMARY") {
|
||||
@@ -169,7 +169,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
}
|
||||
|
||||
if instanceType == "READ_POOL" {
|
||||
nodeCount, ok := paramsMap["nodeCount"].(int)
|
||||
nodeCount, ok := paramsMap["nodeCount"].(int)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid 'nodeCount' parameter; expected an integer for READ_POOL")
|
||||
}
|
||||
@@ -208,5 +208,5 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization() bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
return t.Source.UseClientAuthorization()
|
||||
}
|
||||
|
||||
@@ -25,37 +25,37 @@ import (
|
||||
)
|
||||
|
||||
func TestParseFromYaml(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
create-my-instance:
|
||||
kind: alloydb-create-instance
|
||||
source: my-alloydb-admin-source
|
||||
description: some description
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"create-my-instance": alloydbcreateinstance.Config{
|
||||
Name: "create-my-instance",
|
||||
Kind: "alloydb-create-instance",
|
||||
Source: "my-alloydb-admin-source",
|
||||
Description: "some description",
|
||||
AuthRequired: []string{},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "with auth required",
|
||||
in: `
|
||||
want: server.ToolConfigs{
|
||||
"create-my-instance": alloydbcreateinstance.Config{
|
||||
Name: "create-my-instance",
|
||||
Kind: "alloydb-create-instance",
|
||||
Source: "my-alloydb-admin-source",
|
||||
Description: "some description",
|
||||
AuthRequired: []string{},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "with auth required",
|
||||
in: `
|
||||
tools:
|
||||
create-my-instance-auth:
|
||||
kind: alloydb-create-instance
|
||||
@@ -65,30 +65,30 @@ func TestParseFromYaml(t *testing.T) {
|
||||
- my-google-auth-service
|
||||
- other-auth-service
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"create-my-instance-auth": alloydbcreateinstance.Config{
|
||||
Name: "create-my-instance-auth",
|
||||
Kind: "alloydb-create-instance",
|
||||
Source: "my-alloydb-admin-source",
|
||||
Description: "some description",
|
||||
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
want: server.ToolConfigs{
|
||||
"create-my-instance-auth": alloydbcreateinstance.Config{
|
||||
Name: "create-my-instance-auth",
|
||||
Kind: "alloydb-create-instance",
|
||||
Source: "my-alloydb-admin-source",
|
||||
Description: "some description",
|
||||
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,13 +16,13 @@ package alloydbcreateuser
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"fmt"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"google.golang.org/api/alloydb/v1"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"google.golang.org/api/alloydb/v1"
|
||||
)
|
||||
|
||||
const kind string = "alloydb-create-user"
|
||||
@@ -43,11 +43,11 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
|
||||
// Configuration for the create-user tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
@@ -61,14 +61,14 @@ func (cfg Config) ToolConfigKind() string {
|
||||
// Initialize initializes the tool from the configuration.
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("source %q not found", cfg.Source)
|
||||
}
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("source %q not found", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(*alloydbadmin.Source)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `alloydb-admin`", kind)
|
||||
}
|
||||
s, ok := rawS.(*alloydbadmin.Source)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `alloydb-admin`", kind)
|
||||
}
|
||||
|
||||
allParameters := tools.Parameters{
|
||||
tools.NewStringParameter("project", "The GCP project ID."),
|
||||
@@ -96,26 +96,26 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
}
|
||||
|
||||
return Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}, nil
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Tool represents the create-user tool.
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Description string `yaml:"description"`
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Description string `yaml:"description"`
|
||||
|
||||
Source *alloydbadmin.Source
|
||||
AllParams tools.Parameters `yaml:"allParams"`
|
||||
Source *alloydbadmin.Source
|
||||
AllParams tools.Parameters `yaml:"allParams"`
|
||||
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
|
||||
@@ -25,37 +25,37 @@ import (
|
||||
)
|
||||
|
||||
func TestParseFromYaml(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
create-my-user:
|
||||
kind: alloydb-create-user
|
||||
source: my-alloydb-admin-source
|
||||
description: some description
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"create-my-user": alloydbcreateuser.Config{
|
||||
Name: "create-my-user",
|
||||
Kind: "alloydb-create-user",
|
||||
Source: "my-alloydb-admin-source",
|
||||
Description: "some description",
|
||||
AuthRequired: []string{},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "with auth required",
|
||||
in: `
|
||||
want: server.ToolConfigs{
|
||||
"create-my-user": alloydbcreateuser.Config{
|
||||
Name: "create-my-user",
|
||||
Kind: "alloydb-create-user",
|
||||
Source: "my-alloydb-admin-source",
|
||||
Description: "some description",
|
||||
AuthRequired: []string{},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "with auth required",
|
||||
in: `
|
||||
tools:
|
||||
create-my-user-auth:
|
||||
kind: alloydb-create-user
|
||||
@@ -65,30 +65,30 @@ func TestParseFromYaml(t *testing.T) {
|
||||
- my-google-auth-service
|
||||
- other-auth-service
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"create-my-user-auth": alloydbcreateuser.Config{
|
||||
Name: "create-my-user-auth",
|
||||
Kind: "alloydb-create-user",
|
||||
Source: "my-alloydb-admin-source",
|
||||
Description: "some description",
|
||||
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
want: server.ToolConfigs{
|
||||
"create-my-user-auth": alloydbcreateuser.Config{
|
||||
Name: "create-my-user-auth",
|
||||
Kind: "alloydb-create-user",
|
||||
Source: "my-alloydb-admin-source",
|
||||
Description: "some description",
|
||||
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,12 +42,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
|
||||
// Configuration for the get-cluster tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
BaseURL string `yaml:"baseURL"`
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
BaseURL string `yaml:"baseURL"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
@@ -92,12 +92,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
}
|
||||
|
||||
return Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -45,10 +45,10 @@ func TestParseFromYaml(t *testing.T) {
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"get-my-cluster": alloydbgetcluster.Config{
|
||||
Name: "get-my-cluster",
|
||||
Kind: "alloydb-get-cluster",
|
||||
Source: "my-alloydb-admin-source",
|
||||
Description: "some description",
|
||||
Name: "get-my-cluster",
|
||||
Kind: "alloydb-get-cluster",
|
||||
Source: "my-alloydb-admin-source",
|
||||
Description: "some description",
|
||||
AuthRequired: []string{},
|
||||
},
|
||||
},
|
||||
@@ -67,10 +67,10 @@ func TestParseFromYaml(t *testing.T) {
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"get-my-cluster-auth": alloydbgetcluster.Config{
|
||||
Name: "get-my-cluster-auth",
|
||||
Kind: "alloydb-get-cluster",
|
||||
Source: "my-alloydb-admin-source",
|
||||
Description: "some description",
|
||||
Name: "get-my-cluster-auth",
|
||||
Kind: "alloydb-get-cluster",
|
||||
Source: "my-alloydb-admin-source",
|
||||
Description: "some description",
|
||||
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
|
||||
},
|
||||
},
|
||||
@@ -91,4 +91,4 @@ func TestParseFromYaml(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,12 +42,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
|
||||
// Configuration for the get-instance tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
BaseURL string `yaml:"baseURL"`
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
BaseURL string `yaml:"baseURL"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
@@ -93,12 +93,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
}
|
||||
|
||||
return Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -45,10 +45,10 @@ func TestParseFromYaml(t *testing.T) {
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"get-my-instance": alloydbgetinstance.Config{
|
||||
Name: "get-my-instance",
|
||||
Kind: "alloydb-get-instance",
|
||||
Source: "my-alloydb-admin-source",
|
||||
Description: "some description",
|
||||
Name: "get-my-instance",
|
||||
Kind: "alloydb-get-instance",
|
||||
Source: "my-alloydb-admin-source",
|
||||
Description: "some description",
|
||||
AuthRequired: []string{},
|
||||
},
|
||||
},
|
||||
@@ -67,10 +67,10 @@ func TestParseFromYaml(t *testing.T) {
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"get-my-instance-auth": alloydbgetinstance.Config{
|
||||
Name: "get-my-instance-auth",
|
||||
Kind: "alloydb-get-instance",
|
||||
Source: "my-alloydb-admin-source",
|
||||
Description: "some description",
|
||||
Name: "get-my-instance-auth",
|
||||
Kind: "alloydb-get-instance",
|
||||
Source: "my-alloydb-admin-source",
|
||||
Description: "some description",
|
||||
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -43,12 +42,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
|
||||
// Configuration for the get-user tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
BaseURL string `yaml:"baseURL"`
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
BaseURL string `yaml:"baseURL"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
@@ -94,12 +93,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
}
|
||||
|
||||
return Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -46,10 +45,10 @@ func TestParseFromYaml(t *testing.T) {
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"get-my-user": alloydbgetuser.Config{
|
||||
Name: "get-my-user",
|
||||
Kind: "alloydb-get-user",
|
||||
Source: "my-alloydb-admin-source",
|
||||
Description: "some description",
|
||||
Name: "get-my-user",
|
||||
Kind: "alloydb-get-user",
|
||||
Source: "my-alloydb-admin-source",
|
||||
Description: "some description",
|
||||
AuthRequired: []string{},
|
||||
},
|
||||
},
|
||||
@@ -68,10 +67,10 @@ func TestParseFromYaml(t *testing.T) {
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"get-my-user-auth": alloydbgetuser.Config{
|
||||
Name: "get-my-user-auth",
|
||||
Kind: "alloydb-get-user",
|
||||
Source: "my-alloydb-admin-source",
|
||||
Description: "some description",
|
||||
Name: "get-my-user-auth",
|
||||
Kind: "alloydb-get-user",
|
||||
Source: "my-alloydb-admin-source",
|
||||
Description: "some description",
|
||||
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -42,12 +42,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
|
||||
// Configuration for the list-clusters tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
BaseURL string `yaml:"baseURL"`
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
BaseURL string `yaml:"baseURL"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
@@ -91,21 +91,21 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
}
|
||||
|
||||
return Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Tool represents the list-clusters tool.
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Description string `yaml:"description"`
|
||||
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Description string `yaml:"description"`
|
||||
|
||||
Source *alloydbadmin.Source
|
||||
AllParams tools.Parameters `yaml:"allParams"`
|
||||
|
||||
@@ -122,7 +122,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
return nil, fmt.Errorf("invalid or missing 'project' parameter; expected a string")
|
||||
}
|
||||
location, ok := paramsMap["location"].(string)
|
||||
if !ok {
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid 'location' parameter; expected a string")
|
||||
}
|
||||
|
||||
|
||||
@@ -45,10 +45,10 @@ func TestParseFromYaml(t *testing.T) {
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"list-my-clusters": alloydblistclusters.Config{
|
||||
Name: "list-my-clusters",
|
||||
Kind: "alloydb-list-clusters",
|
||||
Source: "my-alloydb-admin-source",
|
||||
Description: "some description",
|
||||
Name: "list-my-clusters",
|
||||
Kind: "alloydb-list-clusters",
|
||||
Source: "my-alloydb-admin-source",
|
||||
Description: "some description",
|
||||
AuthRequired: []string{},
|
||||
},
|
||||
},
|
||||
@@ -67,10 +67,10 @@ func TestParseFromYaml(t *testing.T) {
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"list-my-clusters-auth": alloydblistclusters.Config{
|
||||
Name: "list-my-clusters-auth",
|
||||
Kind: "alloydb-list-clusters",
|
||||
Source: "my-alloydb-admin-source",
|
||||
Description: "some description",
|
||||
Name: "list-my-clusters-auth",
|
||||
Kind: "alloydb-list-clusters",
|
||||
Source: "my-alloydb-admin-source",
|
||||
Description: "some description",
|
||||
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
|
||||
},
|
||||
},
|
||||
@@ -91,4 +91,4 @@ func TestParseFromYaml(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,12 +42,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
|
||||
// Configuration for the list-instances tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
BaseURL string `yaml:"baseURL"`
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
BaseURL string `yaml:"baseURL"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
@@ -92,21 +92,21 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
}
|
||||
|
||||
return Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Tool represents the list-instances tool.
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Description string `yaml:"description"`
|
||||
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Description string `yaml:"description"`
|
||||
|
||||
Source *alloydbadmin.Source
|
||||
AllParams tools.Parameters `yaml:"allParams"`
|
||||
|
||||
@@ -123,11 +123,11 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
return nil, fmt.Errorf("invalid or missing 'project' parameter; expected a string")
|
||||
}
|
||||
location, ok := paramsMap["location"].(string)
|
||||
if !ok {
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid 'location' parameter; expected a string")
|
||||
}
|
||||
cluster, ok := paramsMap["cluster"].(string)
|
||||
if !ok {
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid 'cluster' parameter; expected a string")
|
||||
}
|
||||
|
||||
|
||||
@@ -45,10 +45,10 @@ func TestParseFromYaml(t *testing.T) {
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"list-my-instances": alloydblistinstances.Config{
|
||||
Name: "list-my-instances",
|
||||
Kind: "alloydb-list-instances",
|
||||
Source: "my-alloydb-admin-source",
|
||||
Description: "some description",
|
||||
Name: "list-my-instances",
|
||||
Kind: "alloydb-list-instances",
|
||||
Source: "my-alloydb-admin-source",
|
||||
Description: "some description",
|
||||
AuthRequired: []string{},
|
||||
},
|
||||
},
|
||||
@@ -67,10 +67,10 @@ func TestParseFromYaml(t *testing.T) {
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"list-my-instances-auth": alloydblistinstances.Config{
|
||||
Name: "list-my-instances-auth",
|
||||
Kind: "alloydb-list-instances",
|
||||
Source: "my-alloydb-admin-source",
|
||||
Description: "some description",
|
||||
Name: "list-my-instances-auth",
|
||||
Kind: "alloydb-list-instances",
|
||||
Source: "my-alloydb-admin-source",
|
||||
Description: "some description",
|
||||
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
|
||||
},
|
||||
},
|
||||
@@ -91,4 +91,4 @@ func TestParseFromYaml(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,12 +42,12 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
|
||||
// Configuration for the list-users tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
BaseURL string `yaml:"baseURL"`
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
BaseURL string `yaml:"baseURL"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
@@ -92,21 +92,21 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
}
|
||||
|
||||
return Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Source: s,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Tool represents the list-users tool.
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Description string `yaml:"description"`
|
||||
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Description string `yaml:"description"`
|
||||
|
||||
Source *alloydbadmin.Source
|
||||
AllParams tools.Parameters `yaml:"allParams"`
|
||||
|
||||
@@ -123,11 +123,11 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
return nil, fmt.Errorf("invalid or missing 'project' parameter; expected a string")
|
||||
}
|
||||
location, ok := paramsMap["location"].(string)
|
||||
if !ok {
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid 'location' parameter; expected a string")
|
||||
}
|
||||
cluster, ok := paramsMap["cluster"].(string)
|
||||
if !ok {
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid 'cluster' parameter; expected a string")
|
||||
}
|
||||
|
||||
|
||||
@@ -45,10 +45,10 @@ func TestParseFromYaml(t *testing.T) {
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"list-my-users": alloydblistusers.Config{
|
||||
Name: "list-my-users",
|
||||
Kind: "alloydb-list-users",
|
||||
Source: "my-alloydb-admin-source",
|
||||
Description: "some description",
|
||||
Name: "list-my-users",
|
||||
Kind: "alloydb-list-users",
|
||||
Source: "my-alloydb-admin-source",
|
||||
Description: "some description",
|
||||
AuthRequired: []string{},
|
||||
},
|
||||
},
|
||||
@@ -67,10 +67,10 @@ func TestParseFromYaml(t *testing.T) {
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"list-my-users-auth": alloydblistusers.Config{
|
||||
Name: "list-my-users-auth",
|
||||
Kind: "alloydb-list-users",
|
||||
Source: "my-alloydb-admin-source",
|
||||
Description: "some description",
|
||||
Name: "list-my-users-auth",
|
||||
Kind: "alloydb-list-users",
|
||||
Source: "my-alloydb-admin-source",
|
||||
Description: "some description",
|
||||
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
|
||||
},
|
||||
},
|
||||
@@ -91,4 +91,4 @@ func TestParseFromYaml(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -274,7 +274,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
if msg, ok := t.generateAlloyDBConnectionMessage(map[string]any{"response": op.Response}); ok {
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
|
||||
return string(opBytes), nil
|
||||
}
|
||||
fmt.Printf("Operation not complete, retrying in %v\n", delay)
|
||||
|
||||
@@ -247,7 +247,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
getInsightsSQL := fmt.Sprintf("SELECT * FROM ML.GET_INSIGHTS(MODEL %s)", modelID)
|
||||
|
||||
getInsightsQuery := bqClient.Query(getInsightsSQL)
|
||||
getInsightsQuery.QueryConfig.ConnectionProperties = []*bigqueryapi.ConnectionProperty{
|
||||
getInsightsQuery.ConnectionProperties = []*bigqueryapi.ConnectionProperty{
|
||||
{Key: "session_id", Value: sessionID},
|
||||
}
|
||||
|
||||
|
||||
@@ -99,18 +99,18 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
Description: description,
|
||||
Description: description,
|
||||
InputSchema: parameters.McpManifest(),
|
||||
}
|
||||
|
||||
t := Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Parameters: parameters,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Parameters: parameters,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
UseClientOAuth: s.UseClientAuthorization(),
|
||||
MakeCatalogClient: makeCatalogClient,
|
||||
ProjectID: s.BigQueryProject(),
|
||||
ProjectID: s.BigQueryProject(),
|
||||
manifest: tools.Manifest{
|
||||
Description: cfg.Description,
|
||||
Parameters: parameters.Manifest(),
|
||||
@@ -122,15 +122,15 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
}
|
||||
|
||||
type Tool struct {
|
||||
Name string
|
||||
Kind string
|
||||
Parameters tools.Parameters
|
||||
AuthRequired []string
|
||||
UseClientOAuth bool
|
||||
MakeCatalogClient func() (*dataplexapi.CatalogClient, bigqueryds.DataplexClientCreator, error)
|
||||
ProjectID string
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
Name string
|
||||
Kind string
|
||||
Parameters tools.Parameters
|
||||
AuthRequired []string
|
||||
UseClientOAuth bool
|
||||
MakeCatalogClient func() (*dataplexapi.CatalogClient, bigqueryds.DataplexClientCreator, error)
|
||||
ProjectID string
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
|
||||
@@ -12,14 +12,11 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
|
||||
package bigquerysearchcatalog_test
|
||||
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
@@ -27,7 +24,6 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/bigquery/bigquerysearchcatalog"
|
||||
)
|
||||
|
||||
|
||||
func TestParseFromYamlBigQuerySearch(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
@@ -73,4 +69,4 @@ func TestParseFromYamlBigQuerySearch(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,11 +44,11 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
@@ -81,9 +81,9 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
Description: cfg.Description,
|
||||
AllParams: allParameters,
|
||||
BaseURL: s.BaseURL,
|
||||
UserAgent: s.UserAgent,
|
||||
AllParams: allParameters,
|
||||
BaseURL: s.BaseURL,
|
||||
UserAgent: s.UserAgent,
|
||||
Client: s.Client,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: allParameters.Manifest()},
|
||||
mcpManifest: tools.McpManifest{Name: cfg.Name, Description: cfg.Description, InputSchema: allParameters.McpManifest()},
|
||||
@@ -94,12 +94,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Description string `yaml:"description"`
|
||||
AllParams tools.Parameters `yaml:"allParams"`
|
||||
BaseURL string `yaml:"baseURL"`
|
||||
UserAgent string
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Description string `yaml:"description"`
|
||||
AllParams tools.Parameters `yaml:"allParams"`
|
||||
BaseURL string `yaml:"baseURL"`
|
||||
UserAgent string
|
||||
Client *http.Client
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
|
||||
@@ -22,7 +22,7 @@ import (
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmssql"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/mssql"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/mssql"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
)
|
||||
|
||||
@@ -302,11 +302,11 @@ var _ compatibleSource = &mssql.Source{}
|
||||
var compatibleSources = [...]string{cloudsqlmssql.SourceKind, mssql.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
@@ -344,13 +344,13 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
AllParams: allParameters,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Db: s.MSSQLDB(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
AllParams: allParameters,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Db: s.MSSQLDB(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -359,10 +359,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
AllParams tools.Parameters `yaml:"allParams"`
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
AllParams tools.Parameters `yaml:"allParams"`
|
||||
|
||||
Db *sql.DB
|
||||
manifest tools.Manifest
|
||||
@@ -373,52 +373,52 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
outputFormat, _ := paramsMap["output_format"].(string)
|
||||
if outputFormat != "simple" && outputFormat != "detailed" {
|
||||
return nil, fmt.Errorf("invalid value for output_format: must be 'simple' or 'detailed', but got %q", outputFormat)
|
||||
}
|
||||
if outputFormat != "simple" && outputFormat != "detailed" {
|
||||
return nil, fmt.Errorf("invalid value for output_format: must be 'simple' or 'detailed', but got %q", outputFormat)
|
||||
}
|
||||
|
||||
namedArgs := []any{
|
||||
sql.Named("table_names", paramsMap["table_names"]),
|
||||
sql.Named("output_format", outputFormat),
|
||||
}
|
||||
sql.Named("table_names", paramsMap["table_names"]),
|
||||
sql.Named("output_format", outputFormat),
|
||||
}
|
||||
|
||||
rows, err := t.Db.QueryContext(ctx, listTablesStatement, namedArgs...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
rows, err := t.Db.QueryContext(ctx, listTablesStatement, namedArgs...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
cols, err := rows.Columns()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to fetch column names: %w", err)
|
||||
}
|
||||
cols, err := rows.Columns()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to fetch column names: %w", err)
|
||||
}
|
||||
|
||||
// create an array of values for each column, which can be re-used to scan each row
|
||||
rawValues := make([]any, len(cols))
|
||||
values := make([]any, len(cols))
|
||||
for i := range rawValues {
|
||||
values[i] = &rawValues[i]
|
||||
}
|
||||
rawValues := make([]any, len(cols))
|
||||
values := make([]any, len(cols))
|
||||
for i := range rawValues {
|
||||
values[i] = &rawValues[i]
|
||||
}
|
||||
|
||||
var out []any
|
||||
for rows.Next() {
|
||||
err = rows.Scan(values...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse row: %w", err)
|
||||
}
|
||||
vMap := make(map[string]any)
|
||||
for i, name := range cols {
|
||||
vMap[name] = rawValues[i]
|
||||
}
|
||||
out = append(out, vMap)
|
||||
}
|
||||
var out []any
|
||||
for rows.Next() {
|
||||
err = rows.Scan(values...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse row: %w", err)
|
||||
}
|
||||
vMap := make(map[string]any)
|
||||
for i, name := range cols {
|
||||
vMap[name] = rawValues[i]
|
||||
}
|
||||
out = append(out, vMap)
|
||||
}
|
||||
|
||||
// Check if error occurred during iteration
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("errors encountered during row iteration: %w", err)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("errors encountered during row iteration: %w", err)
|
||||
}
|
||||
|
||||
return out, nil
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
|
||||
@@ -73,4 +73,4 @@ func TestParseFromYamlmssqlListTables(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,9 +22,9 @@ import (
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmysql"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/mysql"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/mysql"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/mysql/mysqlcommon"
|
||||
)
|
||||
|
||||
const kind string = "mysql-list-tables"
|
||||
@@ -208,11 +208,11 @@ var _ compatibleSource = &mysql.Source{}
|
||||
var compatibleSources = [...]string{cloudsqlmysql.SourceKind, mysql.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
@@ -250,13 +250,13 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
|
||||
// finish tool setup
|
||||
t := Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
AllParams: allParameters,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Pool: s.MySQLPool(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
AllParams: allParameters,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Pool: s.MySQLPool(),
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
@@ -265,10 +265,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
AllParams tools.Parameters `yaml:"allParams"`
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
AllParams tools.Parameters `yaml:"allParams"`
|
||||
|
||||
Pool *sql.DB
|
||||
manifest tools.Manifest
|
||||
@@ -283,9 +283,9 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
return nil, fmt.Errorf("invalid '%s' parameter; expected a string", tableNames)
|
||||
}
|
||||
outputFormat, _ := paramsMap["output_format"].(string)
|
||||
if outputFormat != "simple" && outputFormat != "detailed" {
|
||||
return nil, fmt.Errorf("invalid value for output_format: must be 'simple' or 'detailed', but got %q", outputFormat)
|
||||
}
|
||||
if outputFormat != "simple" && outputFormat != "detailed" {
|
||||
return nil, fmt.Errorf("invalid value for output_format: must be 'simple' or 'detailed', but got %q", outputFormat)
|
||||
}
|
||||
|
||||
results, err := t.Pool.QueryContext(ctx, listTablesStatement, tableNames, outputFormat)
|
||||
if err != nil {
|
||||
@@ -357,4 +357,4 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
|
||||
func (t Tool) RequiresClientAuthorization() bool {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,4 +72,4 @@ func TestParseFromYamlMySQLListTables(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -190,10 +190,10 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
AllParams tools.Parameters `yaml:"allParams"`
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
AllParams tools.Parameters `yaml:"allParams"`
|
||||
|
||||
Pool *pgxpool.Pool
|
||||
manifest tools.Manifest
|
||||
@@ -208,9 +208,9 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
return nil, fmt.Errorf("invalid 'table_names' parameter; expected a string")
|
||||
}
|
||||
outputFormat, _ := paramsMap["output_format"].(string)
|
||||
if outputFormat != "simple" && outputFormat != "detailed" {
|
||||
return nil, fmt.Errorf("invalid value for output_format: must be 'simple' or 'detailed', but got %q", outputFormat)
|
||||
}
|
||||
if outputFormat != "simple" && outputFormat != "detailed" {
|
||||
return nil, fmt.Errorf("invalid value for output_format: must be 'simple' or 'detailed', but got %q", outputFormat)
|
||||
}
|
||||
|
||||
results, err := t.Pool.Query(ctx, listTablesStatement, tableNames, outputFormat)
|
||||
if err != nil {
|
||||
@@ -258,4 +258,4 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
|
||||
func (t Tool) RequiresClientAuthorization() bool {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -172,19 +172,19 @@ func (t Tool) getStatement() string {
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
|
||||
// Get the appropriate SQL statement based on dialect
|
||||
statement := t.getStatement()
|
||||
|
||||
// Prepare parameters based on dialect
|
||||
var stmtParams map[string]interface{}
|
||||
|
||||
tableNames, _ := paramsMap["table_names"].(string)
|
||||
outputFormat, _ := paramsMap["output_format"].(string)
|
||||
if outputFormat == "" {
|
||||
outputFormat = "detailed"
|
||||
}
|
||||
|
||||
|
||||
tableNames, _ := paramsMap["table_names"].(string)
|
||||
outputFormat, _ := paramsMap["output_format"].(string)
|
||||
if outputFormat == "" {
|
||||
outputFormat = "detailed"
|
||||
}
|
||||
|
||||
switch strings.ToLower(t.dialect) {
|
||||
case "postgresql":
|
||||
// PostgreSQL uses positional parameters ($1, $2)
|
||||
@@ -192,7 +192,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
"p1": tableNames,
|
||||
"p2": outputFormat,
|
||||
}
|
||||
|
||||
|
||||
case "googlesql":
|
||||
// GoogleSQL uses named parameters (@table_names, @output_format)
|
||||
stmtParams = map[string]interface{}{
|
||||
@@ -203,10 +203,10 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
return nil, fmt.Errorf("unsupported dialect: %s", t.dialect)
|
||||
}
|
||||
|
||||
stmt := spanner.Statement{
|
||||
SQL: statement,
|
||||
Params: stmtParams,
|
||||
}
|
||||
stmt := spanner.Statement{
|
||||
SQL: statement,
|
||||
Params: stmtParams,
|
||||
}
|
||||
|
||||
// Execute the query (read-only)
|
||||
iter := t.Client.Single().Query(ctx, stmt)
|
||||
|
||||
@@ -20,13 +20,13 @@ import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/sqlite/sqliteexecutesql"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
func TestParseFromYamlExecuteSql(t *testing.T) {
|
||||
|
||||
@@ -20,13 +20,13 @@ import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/sqlite/sqlitesql"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
func TestParseFromYamlSQLite(t *testing.T) {
|
||||
|
||||
@@ -35,15 +35,14 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/tests"
|
||||
|
||||
)
|
||||
|
||||
var (
|
||||
AlloyDBProject = os.Getenv("ALLOYDB_PROJECT")
|
||||
AlloyDBLocation = os.Getenv("ALLOYDB_REGION")
|
||||
AlloyDBCluster = os.Getenv("ALLOYDB_CLUSTER")
|
||||
AlloyDBInstance = os.Getenv("ALLOYDB_INSTANCE")
|
||||
AlloyDBUser = os.Getenv("ALLOYDB_POSTGRES_USER")
|
||||
AlloyDBProject = os.Getenv("ALLOYDB_PROJECT")
|
||||
AlloyDBLocation = os.Getenv("ALLOYDB_REGION")
|
||||
AlloyDBCluster = os.Getenv("ALLOYDB_CLUSTER")
|
||||
AlloyDBInstance = os.Getenv("ALLOYDB_INSTANCE")
|
||||
AlloyDBUser = os.Getenv("ALLOYDB_POSTGRES_USER")
|
||||
)
|
||||
|
||||
func getAlloyDBVars(t *testing.T) map[string]string {
|
||||
@@ -258,8 +257,8 @@ func runAlloyDBMCPToolCallMethod(t *testing.T, vars map[string]string) {
|
||||
},
|
||||
},
|
||||
},
|
||||
wantContains: fmt.Sprintf(`"name\":\"projects/%s/locations/%s/clusters/%s\"`, vars["project"], vars["location"], vars["cluster"]),
|
||||
isErr: false,
|
||||
wantContains: fmt.Sprintf(`"name\":\"projects/%s/locations/%s/clusters/%s\"`, vars["project"], vars["location"], vars["cluster"]),
|
||||
isErr: false,
|
||||
},
|
||||
{
|
||||
name: "MCP Invoke my-fail-tool",
|
||||
@@ -274,8 +273,8 @@ func runAlloyDBMCPToolCallMethod(t *testing.T, vars map[string]string) {
|
||||
},
|
||||
},
|
||||
},
|
||||
wantContains: `parameter \"project\" is required`,
|
||||
isErr: true,
|
||||
wantContains: `parameter \"project\" is required`,
|
||||
isErr: true,
|
||||
},
|
||||
{
|
||||
name: "MCP Invoke invalid tool",
|
||||
@@ -302,8 +301,8 @@ func runAlloyDBMCPToolCallMethod(t *testing.T, vars map[string]string) {
|
||||
"arguments": map[string]any{"location": vars["location"]},
|
||||
},
|
||||
},
|
||||
wantContains: `parameter \"project\" is required`,
|
||||
isErr: true,
|
||||
wantContains: `parameter \"project\" is required`,
|
||||
isErr: true,
|
||||
},
|
||||
{
|
||||
name: "MCP Invoke my-auth-required-tool",
|
||||
@@ -357,30 +356,30 @@ func runAlloyDBMCPToolCallMethod(t *testing.T, vars map[string]string) {
|
||||
func runAlloyDBListClustersTest(t *testing.T, vars map[string]string) {
|
||||
|
||||
type ListClustersResponse struct {
|
||||
Clusters []struct {
|
||||
Name string `json:"name"`
|
||||
} `json:"clusters"`
|
||||
}
|
||||
Clusters []struct {
|
||||
Name string `json:"name"`
|
||||
} `json:"clusters"`
|
||||
}
|
||||
|
||||
type ToolResponse struct {
|
||||
Result string `json:"result"`
|
||||
}
|
||||
|
||||
// NOTE: If clusters are added, removed or changed in the test project,
|
||||
// this list must be updated for the "list clusters specific locations" test to pass
|
||||
// this list must be updated for the "list clusters specific locations" test to pass
|
||||
wantForSpecificLocation := []string{
|
||||
fmt.Sprintf("projects/%s/locations/us-central1/clusters/alloydb-ai-nl-testing", vars["project"]),
|
||||
fmt.Sprintf("projects/%s/locations/us-central1/clusters/alloydb-pg-testing", vars["project"]),
|
||||
}
|
||||
fmt.Sprintf("projects/%s/locations/us-central1/clusters/alloydb-ai-nl-testing", vars["project"]),
|
||||
fmt.Sprintf("projects/%s/locations/us-central1/clusters/alloydb-pg-testing", vars["project"]),
|
||||
}
|
||||
|
||||
// NOTE: If clusters are added, removed, or changed in the test project,
|
||||
// this list must be updated for the "list clusters all locations" test to pass
|
||||
// this list must be updated for the "list clusters all locations" test to pass
|
||||
wantForAllLocations := []string{
|
||||
fmt.Sprintf("projects/%s/locations/us-central1/clusters/alloydb-ai-nl-testing", vars["project"]),
|
||||
fmt.Sprintf("projects/%s/locations/us-central1/clusters/alloydb-pg-testing", vars["project"]),
|
||||
fmt.Sprintf("projects/%s/locations/us-east4/clusters/alloydb-private-pg-testing", vars["project"]),
|
||||
fmt.Sprintf("projects/%s/locations/us-east4/clusters/colab-testing", vars["project"]),
|
||||
}
|
||||
fmt.Sprintf("projects/%s/locations/us-central1/clusters/alloydb-ai-nl-testing", vars["project"]),
|
||||
fmt.Sprintf("projects/%s/locations/us-central1/clusters/alloydb-pg-testing", vars["project"]),
|
||||
fmt.Sprintf("projects/%s/locations/us-east4/clusters/alloydb-private-pg-testing", vars["project"]),
|
||||
fmt.Sprintf("projects/%s/locations/us-east4/clusters/colab-testing", vars["project"]),
|
||||
}
|
||||
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
@@ -496,7 +495,7 @@ func runAlloyDBListUsersTest(t *testing.T, vars map[string]string) {
|
||||
name: "list users success",
|
||||
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"project": "%s", "location": "%s", "cluster": "%s"}`, vars["project"], vars["location"], vars["cluster"])),
|
||||
wantContains: fmt.Sprintf("projects/%s/locations/%s/clusters/%s/users/%s", vars["project"], vars["location"], vars["cluster"], AlloyDBUser),
|
||||
wantCount: 3, // NOTE: If users are added or removed in the test project, update the number of users here must be updated for this test to pass
|
||||
wantCount: 3, // NOTE: If users are added or removed in the test project, update the number of users here must be updated for this test to pass
|
||||
wantStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
@@ -613,7 +612,7 @@ func runAlloyDBListInstancesTest(t *testing.T, vars map[string]string) {
|
||||
fmt.Sprintf("projects/%s/locations/us-central1/clusters/alloydb-ai-nl-testing/instances/alloydb-ai-nl-testing-instance", vars["project"]),
|
||||
fmt.Sprintf("projects/%s/locations/us-central1/clusters/alloydb-pg-testing/instances/alloydb-pg-testing-instance", vars["project"]),
|
||||
fmt.Sprintf("projects/%s/locations/us-east4/clusters/alloydb-private-pg-testing/instances/alloydb-private-pg-testing-instance", vars["project"]),
|
||||
fmt.Sprintf("projects/%s/locations/us-east4/clusters/colab-testing/instances/colab-testing-primary", vars["project"]),
|
||||
fmt.Sprintf("projects/%s/locations/us-east4/clusters/colab-testing/instances/colab-testing-primary", vars["project"]),
|
||||
}
|
||||
|
||||
invokeTcs := []struct {
|
||||
@@ -721,12 +720,12 @@ func runAlloyDBGetClusterTest(t *testing.T, vars map[string]string) {
|
||||
wantStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "get cluster success",
|
||||
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"project": "%s", "location": "%s", "cluster": "%s"}`, vars["project"], vars["location"], vars["cluster"])),
|
||||
want: map[string]any{
|
||||
"clusterType": "PRIMARY",
|
||||
"name": fmt.Sprintf("projects/%s/locations/%s/clusters/%s", vars["project"], vars["location"], vars["cluster"]),
|
||||
},
|
||||
name: "get cluster success",
|
||||
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"project": "%s", "location": "%s", "cluster": "%s"}`, vars["project"], vars["location"], vars["cluster"])),
|
||||
want: map[string]any{
|
||||
"clusterType": "PRIMARY",
|
||||
"name": fmt.Sprintf("projects/%s/locations/%s/clusters/%s", vars["project"], vars["location"], vars["cluster"]),
|
||||
},
|
||||
wantStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
@@ -811,12 +810,12 @@ func runAlloyDBGetInstanceTest(t *testing.T, vars map[string]string) {
|
||||
wantStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "get instance success",
|
||||
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"project": "%s", "location": "%s", "cluster": "%s", "instance": "%s"}`, vars["project"], vars["location"], vars["cluster"], vars["instance"])),
|
||||
want: map[string]any{
|
||||
"instanceType": "PRIMARY",
|
||||
"name": fmt.Sprintf("projects/%s/locations/%s/clusters/%s/instances/%s", vars["project"], vars["location"], vars["cluster"], vars["instance"]),
|
||||
},
|
||||
name: "get instance success",
|
||||
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"project": "%s", "location": "%s", "cluster": "%s", "instance": "%s"}`, vars["project"], vars["location"], vars["cluster"], vars["instance"])),
|
||||
want: map[string]any{
|
||||
"instanceType": "PRIMARY",
|
||||
"name": fmt.Sprintf("projects/%s/locations/%s/clusters/%s/instances/%s", vars["project"], vars["location"], vars["cluster"], vars["instance"]),
|
||||
},
|
||||
wantStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
@@ -906,12 +905,12 @@ func runAlloyDBGetUserTest(t *testing.T, vars map[string]string) {
|
||||
wantStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "get user success",
|
||||
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"project": "%s", "location": "%s", "cluster": "%s", "user": "%s"}`, vars["project"], vars["location"], vars["cluster"], vars["user"])),
|
||||
want: map[string]any{
|
||||
"name": fmt.Sprintf("projects/%s/locations/%s/clusters/%s/users/%s", vars["project"], vars["location"], vars["cluster"], vars["user"]),
|
||||
"userType": "ALLOYDB_BUILT_IN",
|
||||
},
|
||||
name: "get user success",
|
||||
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"project": "%s", "location": "%s", "cluster": "%s", "user": "%s"}`, vars["project"], vars["location"], vars["cluster"], vars["user"])),
|
||||
want: map[string]any{
|
||||
"name": fmt.Sprintf("projects/%s/locations/%s/clusters/%s/users/%s", vars["project"], vars["location"], vars["cluster"], vars["user"]),
|
||||
"userType": "ALLOYDB_BUILT_IN",
|
||||
},
|
||||
wantStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
@@ -1003,7 +1002,7 @@ func (t *mockAlloyDBTransport) RoundTrip(req *http.Request) (*http.Response, err
|
||||
}
|
||||
|
||||
type mockAlloyDBHandler struct {
|
||||
t *testing.T
|
||||
t *testing.T
|
||||
idParam string
|
||||
}
|
||||
|
||||
@@ -1122,9 +1121,9 @@ func TestAlloyDBCreateCluster(t *testing.T) {
|
||||
}
|
||||
|
||||
tcs := []struct {
|
||||
name string
|
||||
body string
|
||||
want string
|
||||
name string
|
||||
body string
|
||||
want string
|
||||
wantStatusCode int
|
||||
}{
|
||||
{
|
||||
@@ -1330,7 +1329,7 @@ func TestAlloyDBCreateInstance(t *testing.T) {
|
||||
if !reflect.DeepEqual(want, got) {
|
||||
t.Errorf("unexpected result:\n- want: %+v\n- got: %+v", want, got)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -66,7 +66,7 @@ type operation struct {
|
||||
type handler struct {
|
||||
mu sync.Mutex
|
||||
operations map[string]*operation
|
||||
t *testing.T
|
||||
t *testing.T
|
||||
}
|
||||
|
||||
func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -256,4 +256,4 @@ func getWaitToolsConfig() map[string]any {
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -272,4 +272,4 @@ func TestAlloyDBPgIAMConnection(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,9 +50,9 @@ func TestTool_Invoke(t *testing.T) {
|
||||
Name: "test-cloudmonitoring",
|
||||
Kind: "cloud-monitoring-query-prometheus",
|
||||
Description: "Test Cloudmonitoring Tool",
|
||||
AllParams: tools.Parameters{},
|
||||
BaseURL: server.URL,
|
||||
Client: &http.Client{},
|
||||
AllParams: tools.Parameters{},
|
||||
BaseURL: server.URL,
|
||||
Client: &http.Client{},
|
||||
}
|
||||
|
||||
// Define the test parameters
|
||||
@@ -94,9 +94,9 @@ func TestTool_Invoke_Error(t *testing.T) {
|
||||
Name: "test-cloudmonitoring",
|
||||
Kind: "clou-monitoring-query-prometheus",
|
||||
Description: "Test Cloudmonitoring Tool",
|
||||
AllParams: tools.Parameters{},
|
||||
BaseURL: server.URL,
|
||||
Client: &http.Client{},
|
||||
AllParams: tools.Parameters{},
|
||||
BaseURL: server.URL,
|
||||
Client: &http.Client{},
|
||||
}
|
||||
|
||||
// Define the test parameters
|
||||
@@ -110,4 +110,4 @@ func TestTool_Invoke_Error(t *testing.T) {
|
||||
if err == nil {
|
||||
t.Fatal("Invoke() error = nil, want error")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -317,7 +317,7 @@ func AddMySQLPrebuiltToolConfig(t *testing.T, config map[string]any) map[string]
|
||||
"kind": "mysql-list-tables-missing-unique-indexes",
|
||||
"source": "my-instance",
|
||||
"description": "Lists tables that do not have primary or unique indexes in the database.",
|
||||
}
|
||||
}
|
||||
tools["list_table_fragmentation"] = map[string]any{
|
||||
"kind": "mysql-list-table-fragmentation",
|
||||
"source": "my-instance",
|
||||
|
||||
@@ -37,14 +37,14 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
MSSQLSourceKind = "mssql"
|
||||
MSSQLToolKind = "mssql-sql"
|
||||
MSSQLSourceKind = "mssql"
|
||||
MSSQLToolKind = "mssql-sql"
|
||||
MSSQLListTablesToolKind = "mssql-list-tables"
|
||||
MSSQLDatabase = os.Getenv("MSSQL_DATABASE")
|
||||
MSSQLHost = os.Getenv("MSSQL_HOST")
|
||||
MSSQLPort = os.Getenv("MSSQL_PORT")
|
||||
MSSQLUser = os.Getenv("MSSQL_USER")
|
||||
MSSQLPass = os.Getenv("MSSQL_PASS")
|
||||
MSSQLDatabase = os.Getenv("MSSQL_DATABASE")
|
||||
MSSQLHost = os.Getenv("MSSQL_HOST")
|
||||
MSSQLPort = os.Getenv("MSSQL_PORT")
|
||||
MSSQLUser = os.Getenv("MSSQL_USER")
|
||||
MSSQLPass = os.Getenv("MSSQL_PASS")
|
||||
)
|
||||
|
||||
func getMsSQLVars(t *testing.T) map[string]any {
|
||||
@@ -234,13 +234,13 @@ func runMSSQLListTablesTest(t *testing.T, tableNameParam, tableNameAuth string)
|
||||
{
|
||||
name: "invoke list_tables with invalid output format",
|
||||
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
||||
requestBody: fmt.Sprintf(`{"table_names": "", "output_format": "abcd"}`),
|
||||
requestBody: `{"table_names": "", "output_format": "abcd"}`,
|
||||
wantStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "invoke list_tables with malformed table_names parameter",
|
||||
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
||||
requestBody: fmt.Sprintf(`{"table_names": 12345, "output_format": "detailed"}`),
|
||||
requestBody: `{"table_names": 12345, "output_format": "detailed"}`,
|
||||
wantStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
@@ -253,7 +253,7 @@ func runMSSQLListTablesTest(t *testing.T, tableNameParam, tableNameAuth string)
|
||||
{
|
||||
name: "invoke list_tables with non-existent table",
|
||||
api: "http://127.0.0.1:5000/api/tool/list_tables/invoke",
|
||||
requestBody: fmt.Sprintf(`{"table_names": "non_existent_table"}`),
|
||||
requestBody: `{"table_names": "non_existent_table"}`,
|
||||
wantStatusCode: http.StatusOK,
|
||||
want: `null`,
|
||||
},
|
||||
|
||||
@@ -137,6 +137,6 @@ func TestMySQLToolEndpoints(t *testing.T) {
|
||||
// Run specific MySQL tool tests
|
||||
tests.RunMySQLListTablesTest(t, MySQLDatabase, tableNameParam, tableNameAuth)
|
||||
tests.RunMySQLListActiveQueriesTest(t, ctx, pool)
|
||||
tests.RunMySQLListTablesMissingUniqueIndexes(t, ctx, pool, MySQLDatabase);
|
||||
tests.RunMySQLListTablesMissingUniqueIndexes(t, ctx, pool, MySQLDatabase)
|
||||
tests.RunMySQLListTableFragmentationTest(t, MySQLDatabase, tableNameParam, tableNameAuth)
|
||||
}
|
||||
|
||||
@@ -311,14 +311,14 @@ func addSpannerListTablesConfig(t *testing.T, config map[string]any) map[string]
|
||||
if !ok {
|
||||
t.Fatalf("unable to get tools from config")
|
||||
}
|
||||
|
||||
|
||||
// Add spanner-list-tables tool
|
||||
tools["list-tables-tool"] = map[string]any{
|
||||
"kind": "spanner-list-tables",
|
||||
"source": "my-instance",
|
||||
"description": "Lists tables with their schema information",
|
||||
}
|
||||
|
||||
|
||||
config["tools"] = tools
|
||||
return config
|
||||
}
|
||||
@@ -547,7 +547,6 @@ func runSpannerExecuteSqlToolInvokeTest(t *testing.T, select1Want, invokeParamWa
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Helper function to verify table list results
|
||||
func verifyTableListResult(t *testing.T, body map[string]interface{}, expectedTables []string, expectedSimpleFormat bool) {
|
||||
// Parse the result
|
||||
@@ -555,13 +554,13 @@ func verifyTableListResult(t *testing.T, body map[string]interface{}, expectedTa
|
||||
if !ok {
|
||||
t.Fatalf("unable to find result in response body")
|
||||
}
|
||||
|
||||
|
||||
var tables []interface{}
|
||||
err := json.Unmarshal([]byte(result), &tables)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to parse result as JSON array: %s", err)
|
||||
}
|
||||
|
||||
|
||||
// If we expect specific tables, verify they exist
|
||||
if len(expectedTables) > 0 {
|
||||
tableNames := make(map[string]bool)
|
||||
@@ -575,7 +574,7 @@ func verifyTableListResult(t *testing.T, body map[string]interface{}, expectedTa
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
// Parse object_details JSON string into map[string]interface{}
|
||||
if objectDetailsStr, ok := tableMap["object_details"].(string); ok {
|
||||
var objectDetails map[string]interface{}
|
||||
@@ -586,7 +585,7 @@ func verifyTableListResult(t *testing.T, body map[string]interface{}, expectedTa
|
||||
|
||||
for _, reqKey := range requiredKeys {
|
||||
if _, hasKey := objectDetails[reqKey]; !hasKey {
|
||||
t.Errorf("missing required key '%s', for object_details: %v",reqKey, objectDetails)
|
||||
t.Errorf("missing required key '%s', for object_details: %v", reqKey, objectDetails)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -595,7 +594,7 @@ func verifyTableListResult(t *testing.T, body map[string]interface{}, expectedTa
|
||||
tableNames[name] = true
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
for _, expected := range expectedTables {
|
||||
if !tableNames[expected] {
|
||||
t.Errorf("expected table %s not found in results", expected)
|
||||
@@ -607,47 +606,46 @@ func verifyTableListResult(t *testing.T, body map[string]interface{}, expectedTa
|
||||
// runSpannerListTablesTest tests the spanner-list-tables tool
|
||||
func runSpannerListTablesTest(t *testing.T, tableNameParam, tableNameAuth, tableNameTemplateParam string) {
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
requestBody io.Reader
|
||||
expectedTables []string // empty means don't check specific tables
|
||||
useSimpleFormat bool
|
||||
name string
|
||||
requestBody io.Reader
|
||||
expectedTables []string // empty means don't check specific tables
|
||||
useSimpleFormat bool
|
||||
}{
|
||||
{
|
||||
name: "list all tables with detailed format",
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
expectedTables: []string{tableNameParam, tableNameAuth, tableNameTemplateParam},
|
||||
name: "list all tables with detailed format",
|
||||
requestBody: bytes.NewBuffer([]byte(`{}`)),
|
||||
expectedTables: []string{tableNameParam, tableNameAuth, tableNameTemplateParam},
|
||||
},
|
||||
{
|
||||
name: "list tables with simple format",
|
||||
requestBody: bytes.NewBuffer([]byte(`{"output_format": "simple"}`)),
|
||||
expectedTables: []string{tableNameParam, tableNameAuth, tableNameTemplateParam},
|
||||
useSimpleFormat: true,
|
||||
name: "list tables with simple format",
|
||||
requestBody: bytes.NewBuffer([]byte(`{"output_format": "simple"}`)),
|
||||
expectedTables: []string{tableNameParam, tableNameAuth, tableNameTemplateParam},
|
||||
useSimpleFormat: true,
|
||||
},
|
||||
{
|
||||
name: "list specific tables",
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"table_names": "%s,%s"}`, tableNameParam, tableNameAuth))),
|
||||
expectedTables: []string{tableNameParam, tableNameAuth},
|
||||
name: "list specific tables",
|
||||
requestBody: bytes.NewBuffer([]byte(fmt.Sprintf(`{"table_names": "%s,%s"}`, tableNameParam, tableNameAuth))),
|
||||
expectedTables: []string{tableNameParam, tableNameAuth},
|
||||
},
|
||||
{
|
||||
name: "list non-existent table",
|
||||
requestBody: bytes.NewBuffer([]byte(`{"table_names": "non_existent_table_xyz"}`)),
|
||||
expectedTables: []string{},
|
||||
name: "list non-existent table",
|
||||
requestBody: bytes.NewBuffer([]byte(`{"table_names": "non_existent_table_xyz"}`)),
|
||||
expectedTables: []string{},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
for _, tc := range invokeTcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Use RunRequest helper function from tests package
|
||||
url := "http://127.0.0.1:5000/api/tool/list-tables-tool/invoke"
|
||||
headers := map[string]string{}
|
||||
|
||||
|
||||
resp, respBody := tests.RunRequest(t, http.MethodPost, url, tc.requestBody, headers)
|
||||
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
|
||||
// Check response body
|
||||
var body map[string]interface{}
|
||||
err := json.Unmarshal(respBody, &body)
|
||||
|
||||
@@ -220,7 +220,7 @@ func TestSQLiteExecuteSqlTool(t *testing.T) {
|
||||
name string
|
||||
sql string
|
||||
wantStatus int
|
||||
wantBody string
|
||||
wantBody string
|
||||
}{
|
||||
{
|
||||
name: "select existing row",
|
||||
|
||||
Reference in New Issue
Block a user