refactor(sources/alloydbadmin, sources/alloydbpg): move source implementation in Invoke() function to Source (#2226)

Move source-related queries from `Invoke()` function into Source.

This is an effort to generalizing tools to work with any Source that
implements a specific interface. This will provide a better segregation
of the roles for Tools vs Source.

Tool's role will be limited to the following:
* Resolve any pre-implementation steps or parameters (e.g. template
parameters)
* Retrieving Source
* Calling the source's implementation


Along with these updates, this PR also resolve some comments from
Gemini:
* update `fmt.Printf()` to logging as a Debug log -- within
`GetOperations()`
* update `fmt.Printf()` during failure to retrieve user agent into
throwing an error. UserAgent are expected to be retrieved successfully
during source initialization. Failure to retrieve will indicate a server
error.
This commit is contained in:
Yuan Teoh
2025-12-24 01:09:22 -08:00
committed by GitHub
parent 9695fc5eeb
commit 0202709efc
13 changed files with 348 additions and 299 deletions

View File

@@ -15,8 +15,12 @@ package alloydbadmin
import (
"context"
"encoding/json"
"fmt"
"html/template"
"net/http"
"strings"
"time"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
@@ -61,7 +65,7 @@ func (r Config) SourceConfigKind() string {
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
ua, err := util.UserAgentFromContext(ctx)
if err != nil {
fmt.Printf("Error in User Agent retrieval: %s", err)
return nil, fmt.Errorf("error in User Agent retrieval: %s", err)
}
var client *http.Client
@@ -114,7 +118,7 @@ func (s *Source) GetDefaultProject() string {
return s.DefaultProject
}
func (s *Source) GetService(ctx context.Context, accessToken string) (*alloydbrestapi.Service, error) {
func (s *Source) getService(ctx context.Context, accessToken string) (*alloydbrestapi.Service, error) {
if s.UseClientOAuth {
token := &oauth2.Token{AccessToken: accessToken}
client := oauth2.NewClient(ctx, oauth2.StaticTokenSource(token))
@@ -130,3 +134,287 @@ func (s *Source) GetService(ctx context.Context, accessToken string) (*alloydbre
func (s *Source) UseClientAuthorization() bool {
return s.UseClientOAuth
}
func (s *Source) CreateCluster(ctx context.Context, project, location, network, user, password, cluster, accessToken string) (any, error) {
// Build the request body using the type-safe Cluster struct.
clusterBody := &alloydbrestapi.Cluster{
NetworkConfig: &alloydbrestapi.NetworkConfig{
Network: fmt.Sprintf("projects/%s/global/networks/%s", project, network),
},
InitialUser: &alloydbrestapi.UserPassword{
User: user,
Password: password,
},
}
service, err := s.getService(ctx, accessToken)
if err != nil {
return nil, err
}
urlString := fmt.Sprintf("projects/%s/locations/%s", project, location)
// The Create API returns a long-running operation.
resp, err := service.Projects.Locations.Clusters.Create(urlString, clusterBody).ClusterId(cluster).Do()
if err != nil {
return nil, fmt.Errorf("error creating AlloyDB cluster: %w", err)
}
return resp, nil
}
func (s *Source) CreateInstance(ctx context.Context, project, location, cluster, instanceID, instanceType, displayName string, nodeCount int, accessToken string) (any, error) {
// Build the request body using the type-safe Instance struct.
instance := &alloydbrestapi.Instance{
InstanceType: instanceType,
NetworkConfig: &alloydbrestapi.InstanceNetworkConfig{
EnablePublicIp: true,
},
DatabaseFlags: map[string]string{
"password.enforce_complexity": "on",
},
}
if displayName != "" {
instance.DisplayName = displayName
}
if instanceType == "READ_POOL" {
instance.ReadPoolConfig = &alloydbrestapi.ReadPoolConfig{
NodeCount: int64(nodeCount),
}
}
service, err := s.getService(ctx, accessToken)
if err != nil {
return nil, err
}
urlString := fmt.Sprintf("projects/%s/locations/%s/clusters/%s", project, location, cluster)
// The Create API returns a long-running operation.
resp, err := service.Projects.Locations.Clusters.Instances.Create(urlString, instance).InstanceId(instanceID).Do()
if err != nil {
return nil, fmt.Errorf("error creating AlloyDB instance: %w", err)
}
return resp, nil
}
func (s *Source) CreateUser(ctx context.Context, userType, password string, roles []string, accessToken, project, location, cluster, userID string) (any, error) {
// Build the request body using the type-safe User struct.
user := &alloydbrestapi.User{
UserType: userType,
}
if userType == "ALLOYDB_BUILT_IN" {
user.Password = password
}
if len(roles) > 0 {
user.DatabaseRoles = roles
}
service, err := s.getService(ctx, accessToken)
if err != nil {
return nil, err
}
urlString := fmt.Sprintf("projects/%s/locations/%s/clusters/%s", project, location, cluster)
// The Create API returns a long-running operation.
resp, err := service.Projects.Locations.Clusters.Users.Create(urlString, user).UserId(userID).Do()
if err != nil {
return nil, fmt.Errorf("error creating AlloyDB user: %w", err)
}
return resp, nil
}
func (s *Source) GetCluster(ctx context.Context, project, location, cluster, accessToken string) (any, error) {
service, err := s.getService(ctx, accessToken)
if err != nil {
return nil, err
}
urlString := fmt.Sprintf("projects/%s/locations/%s/clusters/%s", project, location, cluster)
resp, err := service.Projects.Locations.Clusters.Get(urlString).Do()
if err != nil {
return nil, fmt.Errorf("error getting AlloyDB cluster: %w", err)
}
return resp, nil
}
func (s *Source) GetInstance(ctx context.Context, project, location, cluster, instance, accessToken string) (any, error) {
service, err := s.getService(ctx, accessToken)
if err != nil {
return nil, err
}
urlString := fmt.Sprintf("projects/%s/locations/%s/clusters/%s/instances/%s", project, location, cluster, instance)
resp, err := service.Projects.Locations.Clusters.Instances.Get(urlString).Do()
if err != nil {
return nil, fmt.Errorf("error getting AlloyDB instance: %w", err)
}
return resp, nil
}
func (s *Source) GetUsers(ctx context.Context, project, location, cluster, user, accessToken string) (any, error) {
service, err := s.getService(ctx, accessToken)
if err != nil {
return nil, err
}
urlString := fmt.Sprintf("projects/%s/locations/%s/clusters/%s/users/%s", project, location, cluster, user)
resp, err := service.Projects.Locations.Clusters.Users.Get(urlString).Do()
if err != nil {
return nil, fmt.Errorf("error getting AlloyDB user: %w", err)
}
return resp, nil
}
func (s *Source) ListCluster(ctx context.Context, project, location, accessToken string) (any, error) {
service, err := s.getService(ctx, accessToken)
if err != nil {
return nil, err
}
urlString := fmt.Sprintf("projects/%s/locations/%s", project, location)
resp, err := service.Projects.Locations.Clusters.List(urlString).Do()
if err != nil {
return nil, fmt.Errorf("error listing AlloyDB clusters: %w", err)
}
return resp, nil
}
func (s *Source) ListInstance(ctx context.Context, project, location, cluster, accessToken string) (any, error) {
service, err := s.getService(ctx, accessToken)
if err != nil {
return nil, err
}
urlString := fmt.Sprintf("projects/%s/locations/%s/clusters/%s", project, location, cluster)
resp, err := service.Projects.Locations.Clusters.Instances.List(urlString).Do()
if err != nil {
return nil, fmt.Errorf("error listing AlloyDB instances: %w", err)
}
return resp, nil
}
func (s *Source) ListUsers(ctx context.Context, project, location, cluster, accessToken string) (any, error) {
service, err := s.getService(ctx, accessToken)
if err != nil {
return nil, err
}
urlString := fmt.Sprintf("projects/%s/locations/%s/clusters/%s", project, location, cluster)
resp, err := service.Projects.Locations.Clusters.Users.List(urlString).Do()
if err != nil {
return nil, fmt.Errorf("error listing AlloyDB users: %w", err)
}
return resp, nil
}
func (s *Source) GetOperations(ctx context.Context, project, location, operation, connectionMessageTemplate string, delay time.Duration, accessToken string) (any, error) {
logger, err := util.LoggerFromContext(ctx)
if err != nil {
return nil, err
}
service, err := s.getService(ctx, accessToken)
if err != nil {
return nil, err
}
name := fmt.Sprintf("projects/%s/locations/%s/operations/%s", project, location, operation)
op, err := service.Projects.Locations.Operations.Get(name).Do()
if err != nil {
logger.DebugContext(ctx, fmt.Sprintf("error getting operation: %s, retrying in %v\n", err, delay))
} else {
if op.Done {
if op.Error != nil {
var errorBytes []byte
errorBytes, err = json.Marshal(op.Error)
if err != nil {
return nil, fmt.Errorf("operation finished with error but could not marshal error object: %w", err)
}
return nil, fmt.Errorf("operation finished with error: %s", string(errorBytes))
}
var opBytes []byte
opBytes, err = op.MarshalJSON()
if err != nil {
return nil, fmt.Errorf("could not marshal operation: %w", err)
}
if op.Response != nil {
var responseData map[string]any
if err := json.Unmarshal(op.Response, &responseData); err == nil && responseData != nil {
if msg, ok := generateAlloyDBConnectionMessage(responseData, connectionMessageTemplate); ok {
return msg, nil
}
}
}
return string(opBytes), nil
}
logger.DebugContext(ctx, fmt.Sprintf("Operation not complete, retrying in %v\n", delay))
}
return nil, nil
}
func generateAlloyDBConnectionMessage(responseData map[string]any, connectionMessageTemplate string) (string, bool) {
resourceName, ok := responseData["name"].(string)
if !ok {
return "", false
}
parts := strings.Split(resourceName, "/")
var project, region, cluster, instance string
// Expected format: projects/{project}/locations/{location}/clusters/{cluster}
// or projects/{project}/locations/{location}/clusters/{cluster}/instances/{instance}
if len(parts) < 6 || parts[0] != "projects" || parts[2] != "locations" || parts[4] != "clusters" {
return "", false
}
project = parts[1]
region = parts[3]
cluster = parts[5]
if len(parts) >= 8 && parts[6] == "instances" {
instance = parts[7]
} else {
return "", false
}
tmpl, err := template.New("alloydb-connection").Parse(connectionMessageTemplate)
if err != nil {
// This should not happen with a static template
return fmt.Sprintf("template parsing error: %v", err), false
}
data := struct {
Project string
Region string
Cluster string
Instance string
}{
Project: project,
Region: region,
Cluster: cluster,
Instance: instance,
}
var b strings.Builder
if err := tmpl.Execute(&b, data); err != nil {
return fmt.Sprintf("template execution error: %v", err), false
}
return b.String(), true
}

View File

@@ -101,6 +101,33 @@ func (s *Source) PostgresPool() *pgxpool.Pool {
return s.Pool
}
func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) {
results, err := s.Pool.Query(ctx, statement, params...)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w. Query: %v , Values: %v. Toolbox v0.19.0+ is only compatible with AlloyDB AI NL v1.0.3+. Please ensure that you are using the latest AlloyDB AI NL extension", err, statement, params)
}
fields := results.FieldDescriptions()
var out []any
for results.Next() {
v, err := results.Values()
if err != nil {
return nil, fmt.Errorf("unable to parse row: %w", err)
}
vMap := make(map[string]any)
for i, f := range fields {
vMap[f.Name] = v[i]
}
out = append(out, vMap)
}
// this will catch actual query execution errors
if err := results.Err(); err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
return out, nil
}
func getOpts(ipType, userAgent string, useIAM bool) ([]alloydbconn.Option, error) {
opts := []alloydbconn.Option{alloydbconn.WithUserAgent(userAgent)}
switch strings.ToLower(ipType) {

View File

@@ -22,7 +22,6 @@ import (
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/alloydb/v1"
)
const kind string = "alloydb-create-cluster"
@@ -44,7 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
GetDefaultProject() string
UseClientAuthorization() bool
GetService(context.Context, string) (*alloydb.Service, error)
CreateCluster(context.Context, string, string, string, string, string, string, string) (any, error)
}
// Configuration for the create-cluster tool.
@@ -159,31 +158,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid 'user' parameter; expected a string")
}
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
urlString := fmt.Sprintf("projects/%s/locations/%s", project, location)
// Build the request body using the type-safe Cluster struct.
clusterBody := &alloydb.Cluster{
NetworkConfig: &alloydb.NetworkConfig{
Network: fmt.Sprintf("projects/%s/global/networks/%s", project, network),
},
InitialUser: &alloydb.UserPassword{
User: user,
Password: password,
},
}
// The Create API returns a long-running operation.
resp, err := service.Projects.Locations.Clusters.Create(urlString, clusterBody).ClusterId(clusterID).Do()
if err != nil {
return nil, fmt.Errorf("error creating AlloyDB cluster: %w", err)
}
return resp, nil
return source.CreateCluster(ctx, project, location, network, user, password, clusterID, string(accessToken))
}
// ParseParams parses the parameters for the tool.

View File

@@ -22,7 +22,6 @@ import (
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/alloydb/v1"
)
const kind string = "alloydb-create-instance"
@@ -44,7 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
GetDefaultProject() string
UseClientAuthorization() bool
GetService(context.Context, string) (*alloydb.Service, error)
CreateInstance(context.Context, string, string, string, string, string, string, int, string) (any, error)
}
// Configuration for the create-instance tool.
@@ -155,45 +154,17 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid 'instanceType' parameter; expected 'PRIMARY' or 'READ_POOL'")
}
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
urlString := fmt.Sprintf("projects/%s/locations/%s/clusters/%s", project, location, cluster)
// Build the request body using the type-safe Instance struct.
instance := &alloydb.Instance{
InstanceType: instanceType,
NetworkConfig: &alloydb.InstanceNetworkConfig{
EnablePublicIp: true,
},
DatabaseFlags: map[string]string{
"password.enforce_complexity": "on",
},
}
if displayName, ok := paramsMap["displayName"].(string); ok && displayName != "" {
instance.DisplayName = displayName
}
displayName, _ := paramsMap["displayName"].(string)
var nodeCount int
if instanceType == "READ_POOL" {
nodeCount, ok := paramsMap["nodeCount"].(int)
nodeCount, ok = paramsMap["nodeCount"].(int)
if !ok {
return nil, fmt.Errorf("invalid 'nodeCount' parameter; expected an integer for READ_POOL")
}
instance.ReadPoolConfig = &alloydb.ReadPoolConfig{
NodeCount: int64(nodeCount),
}
}
// The Create API returns a long-running operation.
resp, err := service.Projects.Locations.Clusters.Instances.Create(urlString, instance).InstanceId(instanceID).Do()
if err != nil {
return nil, fmt.Errorf("error creating AlloyDB instance: %w", err)
}
return resp, nil
return source.CreateInstance(ctx, project, location, cluster, instanceID, instanceType, displayName, nodeCount, string(accessToken))
}
// ParseParams parses the parameters for the tool.

View File

@@ -22,7 +22,6 @@ import (
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/alloydb/v1"
)
const kind string = "alloydb-create-user"
@@ -44,7 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
GetDefaultProject() string
UseClientAuthorization() bool
GetService(context.Context, string) (*alloydb.Service, error)
CreateUser(context.Context, string, string, []string, string, string, string, string, string) (any, error)
}
// Configuration for the create-user tool.
@@ -153,46 +152,24 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if !ok || (userType != "ALLOYDB_BUILT_IN" && userType != "ALLOYDB_IAM_USER") {
return nil, fmt.Errorf("invalid or missing 'userType' parameter; expected 'ALLOYDB_BUILT_IN' or 'ALLOYDB_IAM_USER'")
}
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
urlString := fmt.Sprintf("projects/%s/locations/%s/clusters/%s", project, location, cluster)
// Build the request body using the type-safe User struct.
user := &alloydb.User{
UserType: userType,
}
var password string
if userType == "ALLOYDB_BUILT_IN" {
password, ok := paramsMap["password"].(string)
password, ok = paramsMap["password"].(string)
if !ok || password == "" {
return nil, fmt.Errorf("password is required when userType is ALLOYDB_BUILT_IN")
}
user.Password = password
}
var roles []string
if dbRolesRaw, ok := paramsMap["databaseRoles"].([]any); ok && len(dbRolesRaw) > 0 {
var roles []string
for _, r := range dbRolesRaw {
if role, ok := r.(string); ok {
roles = append(roles, role)
}
}
if len(roles) > 0 {
user.DatabaseRoles = roles
}
}
// The Create API returns a long-running operation.
resp, err := service.Projects.Locations.Clusters.Users.Create(urlString, user).UserId(userID).Do()
if err != nil {
return nil, fmt.Errorf("error creating AlloyDB user: %w", err)
}
return resp, nil
return source.CreateUser(ctx, userType, password, roles, string(accessToken), project, location, cluster, userID)
}
// ParseParams parses the parameters for the tool.

View File

@@ -22,7 +22,6 @@ import (
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/alloydb/v1"
)
const kind string = "alloydb-get-cluster"
@@ -44,7 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
GetDefaultProject() string
UseClientAuthorization() bool
GetService(context.Context, string) (*alloydb.Service, error)
GetCluster(context.Context, string, string, string, string) (any, error)
}
// Configuration for the get-cluster tool.
@@ -141,19 +140,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid 'cluster' parameter; expected a string")
}
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
urlString := fmt.Sprintf("projects/%s/locations/%s/clusters/%s", project, location, cluster)
resp, err := service.Projects.Locations.Clusters.Get(urlString).Do()
if err != nil {
return nil, fmt.Errorf("error getting AlloyDB cluster: %w", err)
}
return resp, nil
return source.GetCluster(ctx, project, location, cluster, string(accessToken))
}
// ParseParams parses the parameters for the tool.

View File

@@ -22,7 +22,6 @@ import (
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/alloydb/v1"
)
const kind string = "alloydb-get-instance"
@@ -44,7 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
GetDefaultProject() string
UseClientAuthorization() bool
GetService(context.Context, string) (*alloydb.Service, error)
GetInstance(context.Context, string, string, string, string, string) (any, error)
}
// Configuration for the get-instance tool.
@@ -145,19 +144,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid 'instance' parameter; expected a string")
}
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
urlString := fmt.Sprintf("projects/%s/locations/%s/clusters/%s/instances/%s", project, location, cluster, instance)
resp, err := service.Projects.Locations.Clusters.Instances.Get(urlString).Do()
if err != nil {
return nil, fmt.Errorf("error getting AlloyDB instance: %w", err)
}
return resp, nil
return source.GetInstance(ctx, project, location, cluster, instance, string(accessToken))
}
// ParseParams parses the parameters for the tool.

View File

@@ -22,7 +22,6 @@ import (
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/alloydb/v1"
)
const kind string = "alloydb-get-user"
@@ -44,7 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
GetDefaultProject() string
UseClientAuthorization() bool
GetService(context.Context, string) (*alloydb.Service, error)
GetUsers(context.Context, string, string, string, string, string) (any, error)
}
// Configuration for the get-user tool.
@@ -145,19 +144,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid 'user' parameter; expected a string")
}
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
urlString := fmt.Sprintf("projects/%s/locations/%s/clusters/%s/users/%s", project, location, cluster, user)
resp, err := service.Projects.Locations.Clusters.Users.Get(urlString).Do()
if err != nil {
return nil, fmt.Errorf("error getting AlloyDB user: %w", err)
}
return resp, nil
return source.GetUsers(ctx, project, location, cluster, user, string(accessToken))
}
// ParseParams parses the parameters for the tool.

View File

@@ -22,7 +22,6 @@ import (
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/alloydb/v1"
)
const kind string = "alloydb-list-clusters"
@@ -44,7 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
GetDefaultProject() string
UseClientAuthorization() bool
GetService(context.Context, string) (*alloydb.Service, error)
ListCluster(context.Context, string, string, string) (any, error)
}
// Configuration for the list-clusters tool.
@@ -135,19 +134,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid 'location' parameter; expected a string")
}
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
urlString := fmt.Sprintf("projects/%s/locations/%s", project, location)
resp, err := service.Projects.Locations.Clusters.List(urlString).Do()
if err != nil {
return nil, fmt.Errorf("error listing AlloyDB clusters: %w", err)
}
return resp, nil
return source.ListCluster(ctx, project, location, string(accessToken))
}
// ParseParams parses the parameters for the tool.

View File

@@ -22,7 +22,6 @@ import (
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/alloydb/v1"
)
const kind string = "alloydb-list-instances"
@@ -44,7 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
GetDefaultProject() string
UseClientAuthorization() bool
GetService(context.Context, string) (*alloydb.Service, error)
ListInstance(context.Context, string, string, string, string) (any, error)
}
// Configuration for the list-instances tool.
@@ -140,19 +139,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid 'cluster' parameter; expected a string")
}
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
urlString := fmt.Sprintf("projects/%s/locations/%s/clusters/%s", project, location, cluster)
resp, err := service.Projects.Locations.Clusters.Instances.List(urlString).Do()
if err != nil {
return nil, fmt.Errorf("error listing AlloyDB instances: %w", err)
}
return resp, nil
return source.ListInstance(ctx, project, location, cluster, string(accessToken))
}
// ParseParams parses the parameters for the tool.

View File

@@ -22,7 +22,6 @@ import (
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/alloydb/v1"
)
const kind string = "alloydb-list-users"
@@ -44,7 +43,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
GetDefaultProject() string
UseClientAuthorization() bool
GetService(context.Context, string) (*alloydb.Service, error)
ListUsers(context.Context, string, string, string, string) (any, error)
}
// Configuration for the list-users tool.
@@ -140,19 +139,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("invalid 'cluster' parameter; expected a string")
}
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
urlString := fmt.Sprintf("projects/%s/locations/%s/clusters/%s", project, location, cluster)
resp, err := service.Projects.Locations.Clusters.Users.List(urlString).Do()
if err != nil {
return nil, fmt.Errorf("error listing AlloyDB users: %w", err)
}
return resp, nil
return source.ListUsers(ctx, project, location, cluster, string(accessToken))
}
// ParseParams parses the parameters for the tool.

View File

@@ -16,18 +16,14 @@ package alloydbwaitforoperation
import (
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
"text/template"
"time"
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/api/alloydb/v1"
)
const kind string = "alloydb-wait-for-operation"
@@ -92,7 +88,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
GetDefaultProject() string
UseClientAuthorization() bool
GetService(context.Context, string) (*alloydb.Service, error)
GetOperations(context.Context, string, string, string, string, time.Duration, string) (any, error)
}
// Config defines the configuration for the wait-for-operation tool.
@@ -237,16 +233,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("missing 'operation' parameter")
}
service, err := source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
ctx, cancel := context.WithTimeout(ctx, 30*time.Minute)
defer cancel()
name := fmt.Sprintf("projects/%s/locations/%s/operations/%s", project, location, operation)
delay := t.Delay
maxDelay := t.MaxDelay
multiplier := t.Multiplier
@@ -260,36 +249,11 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
default:
}
op, err := service.Projects.Locations.Operations.Get(name).Do()
op, err := source.GetOperations(ctx, project, location, operation, alloyDBConnectionMessageTemplate, delay, string(accessToken))
if err != nil {
fmt.Printf("error getting operation: %s, retrying in %v\n", err, delay)
} else {
if op.Done {
if op.Error != nil {
var errorBytes []byte
errorBytes, err = json.Marshal(op.Error)
if err != nil {
return nil, fmt.Errorf("operation finished with error but could not marshal error object: %w", err)
}
return nil, fmt.Errorf("operation finished with error: %s", string(errorBytes))
}
var opBytes []byte
opBytes, err = op.MarshalJSON()
if err != nil {
return nil, fmt.Errorf("could not marshal operation: %w", err)
}
var responseData map[string]any
if err := json.Unmarshal(op.Response, &responseData); err == nil && responseData != nil {
if msg, ok := t.generateAlloyDBConnectionMessage(responseData); ok {
return msg, nil
}
}
return string(opBytes), nil
}
fmt.Printf("Operation not complete, retrying in %v\n", delay)
return nil, err
} else if op != nil {
return op, nil
}
time.Sleep(delay)
@@ -302,57 +266,6 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, fmt.Errorf("exceeded max retries waiting for operation")
}
func (t Tool) generateAlloyDBConnectionMessage(responseData map[string]any) (string, bool) {
resourceName, ok := responseData["name"].(string)
if !ok {
return "", false
}
parts := strings.Split(resourceName, "/")
var project, region, cluster, instance string
// Expected format: projects/{project}/locations/{location}/clusters/{cluster}
// or projects/{project}/locations/{location}/clusters/{cluster}/instances/{instance}
if len(parts) < 6 || parts[0] != "projects" || parts[2] != "locations" || parts[4] != "clusters" {
return "", false
}
project = parts[1]
region = parts[3]
cluster = parts[5]
if len(parts) >= 8 && parts[6] == "instances" {
instance = parts[7]
} else {
return "", false
}
tmpl, err := template.New("alloydb-connection").Parse(alloyDBConnectionMessageTemplate)
if err != nil {
// This should not happen with a static template
return fmt.Sprintf("template parsing error: %v", err), false
}
data := struct {
Project string
Region string
Cluster string
Instance string
}{
Project: project,
Region: region,
Cluster: cluster,
Instance: instance,
}
var b strings.Builder
if err := tmpl.Execute(&b, data); err != nil {
return fmt.Sprintf("template execution error: %v", err), false
}
return b.String(), true
}
// ParseParams parses the parameters for the tool.
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
return parameters.ParseParams(t.AllParams, data, claims)

View File

@@ -44,6 +44,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
PostgresPool() *pgxpool.Pool
RunSQL(context.Context, string, []any) (any, error)
}
type Config struct {
@@ -135,7 +136,6 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
if err != nil {
return nil, err
}
pool := source.PostgresPool()
sliceParams := params.AsSlice()
allParamValues := make([]any, len(sliceParams)+1)
@@ -145,31 +145,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
allParamValues[i+2] = fmt.Sprintf("%s", param)
}
results, err := pool.Query(ctx, t.Statement, allParamValues...)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w. Query: %v , Values: %v. Toolbox v0.19.0+ is only compatible with AlloyDB AI NL v1.0.3+. Please ensure that you are using the latest AlloyDB AI NL extension", err, t.Statement, allParamValues)
}
fields := results.FieldDescriptions()
var out []any
for results.Next() {
v, err := results.Values()
if err != nil {
return nil, fmt.Errorf("unable to parse row: %w", err)
}
vMap := make(map[string]any)
for i, f := range fields {
vMap[f.Name] = v[i]
}
out = append(out, vMap)
}
// this will catch actual query execution errors
if err := results.Err(); err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
return out, nil
return source.RunSQL(ctx, t.Statement, allParamValues)
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {