Compare commits

..

1 Commits

Author SHA1 Message Date
Yuan Teoh
dce95db79d chore: update config after merging to main 2026-01-26 20:17:54 -08:00
5 changed files with 21 additions and 25 deletions

View File

@@ -26,7 +26,7 @@ import (
"google.golang.org/api/sqladmin/v1"
)
const kind string = "cloud-sql-restore-backup"
const resourceType string = "cloud-sql-restore-backup"
var _ tools.ToolConfig = Config{}
@@ -40,15 +40,15 @@ type compatibleSource interface {
// Config defines the configuration for the restore-backup tool.
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Type string `yaml:"type" validate:"required"`
Description string `yaml:"description"`
Source string `yaml:"source" validate:"required"`
AuthRequired []string `yaml:"authRequired"`
}
func init() {
if !tools.Register(kind, newConfig) {
panic(fmt.Sprintf("tool kind %q already registered", kind))
if !tools.Register(resourceType, newConfig) {
panic(fmt.Sprintf("tool type %q already registered", resourceType))
}
}
@@ -60,9 +60,9 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
return actual, nil
}
// ToolConfigKind returns the kind of the tool.
func (cfg Config) ToolConfigKind() string {
return kind
// ToolConfigType returns the type of the tool.
func (cfg Config) ToolConfigType() string {
return resourceType
}
// Initialize initializes the tool from the configuration.
@@ -73,7 +73,7 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
}
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", kind, cfg.Source)
return nil, fmt.Errorf("invalid source for %q tool: source %q not compatible", resourceType, cfg.Source)
}
project := s.GetDefaultProject()
@@ -121,7 +121,7 @@ func (t Tool) ToConfig() tools.ToolConfig {
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
if err != nil {
return nil, err
}
@@ -170,7 +170,7 @@ func (t Tool) Authorized(verifiedAuthServices []string) bool {
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Kind)
source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type)
if err != nil {
return false, err
}

View File

@@ -17,7 +17,6 @@ package cloudsqlrestorebackup_test
import (
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/testutils"
@@ -37,16 +36,16 @@ func TestParseFromYaml(t *testing.T) {
{
desc: "basic example",
in: `
tools:
restore-backup-tool:
kind: cloud-sql-restore-backup
description: a test description
source: a-source
kind: tools
name: restore-backup-tool
type: cloud-sql-restore-backup
description: a test description
source: a-source
`,
want: server.ToolConfigs{
"restore-backup-tool": cloudsqlrestorebackup.Config{
Name: "restore-backup-tool",
Kind: "cloud-sql-restore-backup",
Type: "cloud-sql-restore-backup",
Description: "a test description",
Source: "a-source",
AuthRequired: []string{},
@@ -56,14 +55,11 @@ func TestParseFromYaml(t *testing.T) {
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Tools server.ToolConfigs `yaml:"tools"`
}{}
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
_, _, _, got, _, _, err := server.UnmarshalResourceConfig(ctx, testutils.FormatYaml(tc.in))
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
if diff := cmp.Diff(tc.want, got); diff != "" {
t.Fatalf("incorrect parse: diff %v", diff)
}
})

View File

@@ -159,7 +159,7 @@ func TestAlloyDBPgToolEndpoints(t *testing.T) {
// Add semantic search tool config
insertStmt, searchStmt := tests.GetPostgresVectorSearchStmts(vectorTableName)
toolsFile = tests.AddSemanticSearchConfig(t, toolsFile, AlloyDBPostgresToolKind, insertStmt, searchStmt)
toolsFile = tests.AddSemanticSearchConfig(t, toolsFile, AlloyDBPostgresToolType, insertStmt, searchStmt)
toolsFile = tests.AddPostgresPrebuiltConfig(t, toolsFile)

View File

@@ -144,7 +144,7 @@ func TestCloudSQLPgSimpleToolEndpoints(t *testing.T) {
// Add semantic search tool config
insertStmt, searchStmt := tests.GetPostgresVectorSearchStmts(vectorTableName)
toolsFile = tests.AddSemanticSearchConfig(t, toolsFile, CloudSQLPostgresToolKind, insertStmt, searchStmt)
toolsFile = tests.AddSemanticSearchConfig(t, toolsFile, CloudSQLPostgresToolType, insertStmt, searchStmt)
toolsFile = tests.AddPostgresPrebuiltConfig(t, toolsFile)
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)

View File

@@ -124,7 +124,7 @@ func TestPostgres(t *testing.T) {
// Add semantic search tool config
insertStmt, searchStmt := tests.GetPostgresVectorSearchStmts(vectorTableName)
toolsFile = tests.AddSemanticSearchConfig(t, toolsFile, PostgresToolKind, insertStmt, searchStmt)
toolsFile = tests.AddSemanticSearchConfig(t, toolsFile, PostgresToolType, insertStmt, searchStmt)
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
if err != nil {