mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-05-02 03:00:36 -04:00
feat(source/alloydb-admin): Add user agent and attach alloydb api in alloydb-admin source (#1448)
## Description --- - This PR introduces a userAgentRoundTripper that prepends our custom user agent to the existing User-Agent header - Moves alloydb api client to `alloydb-admin` source - Updates alloydb control plane tools (`alloydb-get-cluster`, `alloydb-list-clusters`, `alloydb-list-instances`, `alloydb-list-users`) accordingly. ## PR Checklist --- > Thank you for opening a Pull Request! Before submitting your PR, there are a > few things you can do to make sure it goes smoothly: - [x] Make sure you reviewed [CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md) - [ ] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/genai-toolbox/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea - [x] Ensure the tests and linter pass - [x] Code coverage does not decrease (if any source code was changed) - [x] Appropriate docs were updated (if necessary) - [ ] Make sure to add `!` if this involve a breaking change 🛠️ Fixes #<issue_number_goes_here> --------- Co-authored-by: Averi Kitsch <akitsch@google.com>
This commit is contained in:
@@ -35,5 +35,5 @@ tools:
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|-------------|:------------------------------------------:|:------------:|--------------------------------------------------------------------------------------------------|
|
||||
| kind | string | true | Must be alloydb-list-instances. | |
|
||||
| source | string | true | The name of an alloydb-admin source. |
|
||||
| source | string | true | The name of an `alloydb-admin` source. |
|
||||
| description | string | true | Description of the tool that is passed to the agent. |
|
||||
@@ -34,5 +34,5 @@ tools:
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|-------------|:------------------------------------------:|:------------:|--------------------------------------------------------------------------------------------------|
|
||||
| kind | string | true | Must be alloydb-list-users. | |
|
||||
| source | string | true | The name of an alloydb-admin source. |
|
||||
| source | string | true | The name of an `alloydb-admin` source. |
|
||||
| description | string | true | Description of the tool that is passed to the agent. |
|
||||
@@ -24,11 +24,32 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/google"
|
||||
"google.golang.org/api/option"
|
||||
alloydbrestapi "google.golang.org/api/alloydb/v1"
|
||||
)
|
||||
|
||||
const SourceKind string = "alloydb-admin"
|
||||
|
||||
type userAgentRoundTripper struct {
|
||||
userAgent string
|
||||
next http.RoundTripper
|
||||
}
|
||||
|
||||
func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
newReq := *req
|
||||
newReq.Header = make(http.Header)
|
||||
for k, v := range req.Header {
|
||||
newReq.Header[k] = v
|
||||
}
|
||||
ua := newReq.Header.Get("User-Agent")
|
||||
if ua == "" {
|
||||
newReq.Header.Set("User-Agent", rt.userAgent)
|
||||
} else {
|
||||
newReq.Header.Set("User-Agent", rt.userAgent+" "+ua)
|
||||
}
|
||||
return rt.next.RoundTrip(&newReq)
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
@@ -64,22 +85,36 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
|
||||
|
||||
var client *http.Client
|
||||
if r.UseClientOAuth {
|
||||
client = nil
|
||||
client = &http.Client{
|
||||
Transport: &userAgentRoundTripper{
|
||||
userAgent: ua,
|
||||
next: http.DefaultTransport,
|
||||
},
|
||||
}
|
||||
} else {
|
||||
// Use Application Default Credentials
|
||||
creds, err := google.FindDefaultCredentials(ctx, alloydbrestapi.CloudPlatformScope)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to find default credentials: %w", err)
|
||||
}
|
||||
client = oauth2.NewClient(ctx, creds.TokenSource)
|
||||
baseClient := oauth2.NewClient(ctx, creds.TokenSource)
|
||||
baseClient.Transport = &userAgentRoundTripper{
|
||||
userAgent: ua,
|
||||
next: baseClient.Transport,
|
||||
}
|
||||
client = baseClient
|
||||
}
|
||||
|
||||
service, err := alloydbrestapi.NewService(ctx, option.WithHTTPClient(client))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating new alloydb service: %w", err)
|
||||
}
|
||||
|
||||
s := &Source{
|
||||
Name: r.Name,
|
||||
Kind: SourceKind,
|
||||
BaseURL: "https://alloydb.googleapis.com",
|
||||
Client: client,
|
||||
UserAgent: ua,
|
||||
Service: service,
|
||||
UseClientOAuth: r.UseClientOAuth,
|
||||
}
|
||||
|
||||
@@ -92,8 +127,7 @@ type Source struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
BaseURL string
|
||||
Client *http.Client
|
||||
UserAgent string
|
||||
Service *alloydbrestapi.Service
|
||||
UseClientOAuth bool
|
||||
}
|
||||
|
||||
@@ -101,15 +135,17 @@ func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Client, error) {
|
||||
func (s *Source) GetService(ctx context.Context, accessToken string) (*alloydbrestapi.Service, error) {
|
||||
if s.UseClientOAuth {
|
||||
if accessToken == "" {
|
||||
return nil, fmt.Errorf("client-side OAuth is enabled but no access token was provided")
|
||||
}
|
||||
token := &oauth2.Token{AccessToken: accessToken}
|
||||
return oauth2.NewClient(ctx, oauth2.StaticTokenSource(token)), nil
|
||||
client := oauth2.NewClient(ctx, oauth2.StaticTokenSource(token))
|
||||
service, err := alloydbrestapi.NewService(ctx, option.WithHTTPClient(client))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating new alloydb service: %w", err)
|
||||
}
|
||||
return service, nil
|
||||
}
|
||||
return s.Client, nil
|
||||
return s.Service, nil
|
||||
}
|
||||
|
||||
func (s *Source) UseClientAuthorization() bool {
|
||||
|
||||
@@ -22,8 +22,6 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"google.golang.org/api/alloydb/v1"
|
||||
"google.golang.org/api/option"
|
||||
)
|
||||
|
||||
const kind string = "alloydb-get-cluster"
|
||||
@@ -127,21 +125,14 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
return nil, fmt.Errorf("invalid 'clusterId' parameter; expected a string")
|
||||
}
|
||||
|
||||
// Get an authenticated HTTP client from the source
|
||||
client, err := t.Source.GetClient(ctx, string(accessToken))
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting authorized client: %w", err)
|
||||
}
|
||||
|
||||
// Create a new AlloyDB service client using the authorized client
|
||||
alloydbService, err := alloydb.NewService(ctx, option.WithHTTPClient(client))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating AlloyDB service: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
urlString := fmt.Sprintf("projects/%s/locations/%s/clusters/%s", projectId, locationId, clusterId)
|
||||
|
||||
resp, err := alloydbService.Projects.Locations.Clusters.Get(urlString).Do()
|
||||
resp, err := service.Projects.Locations.Clusters.Get(urlString).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting AlloyDB cluster: %w", err)
|
||||
}
|
||||
|
||||
@@ -22,8 +22,6 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"google.golang.org/api/alloydb/v1"
|
||||
"google.golang.org/api/option"
|
||||
)
|
||||
|
||||
const kind string = "alloydb-list-clusters"
|
||||
@@ -123,21 +121,14 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
return nil, fmt.Errorf("invalid 'locationId' parameter; expected a string")
|
||||
}
|
||||
|
||||
// Get an authenticated HTTP client from the source
|
||||
client, err := t.Source.GetClient(ctx, string(accessToken))
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting authorized client: %w", err)
|
||||
}
|
||||
|
||||
// Create a new AlloyDB service client using the authorized client
|
||||
alloydbService, err := alloydb.NewService(ctx, option.WithHTTPClient(client))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating AlloyDB service: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
urlString := fmt.Sprintf("projects/%s/locations/%s", projectId, locationId)
|
||||
|
||||
resp, err := alloydbService.Projects.Locations.Clusters.List(urlString).Do()
|
||||
resp, err := service.Projects.Locations.Clusters.List(urlString).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error listing AlloyDB clusters: %w", err)
|
||||
}
|
||||
|
||||
@@ -22,8 +22,6 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"google.golang.org/api/alloydb/v1"
|
||||
"google.golang.org/api/option"
|
||||
)
|
||||
|
||||
const kind string = "alloydb-list-instances"
|
||||
@@ -128,21 +126,14 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
return nil, fmt.Errorf("invalid 'clusterId' parameter; expected a string")
|
||||
}
|
||||
|
||||
// Get an authenticated HTTP client from the source
|
||||
client, err := t.Source.GetClient(ctx, string(accessToken))
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting authorized client: %w", err)
|
||||
}
|
||||
|
||||
// Create a new AlloyDB service client using the authorized client
|
||||
alloydbService, err := alloydb.NewService(ctx, option.WithHTTPClient(client))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating AlloyDB service: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
urlString := fmt.Sprintf("projects/%s/locations/%s/clusters/%s", projectId, locationId, clusterId)
|
||||
|
||||
resp, err := alloydbService.Projects.Locations.Clusters.Instances.List(urlString).Do()
|
||||
resp, err := service.Projects.Locations.Clusters.Instances.List(urlString).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error listing AlloyDB instances: %w", err)
|
||||
}
|
||||
|
||||
@@ -22,8 +22,6 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
alloydbadmin "github.com/googleapis/genai-toolbox/internal/sources/alloydbadmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"google.golang.org/api/alloydb/v1"
|
||||
"google.golang.org/api/option"
|
||||
)
|
||||
|
||||
const kind string = "alloydb-list-users"
|
||||
@@ -128,21 +126,14 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
return nil, fmt.Errorf("invalid 'clusterId' parameter; expected a string")
|
||||
}
|
||||
|
||||
// Get an authenticated HTTP client from the source
|
||||
client, err := t.Source.GetClient(ctx, string(accessToken))
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting authorized client: %w", err)
|
||||
}
|
||||
|
||||
// Create a new AlloyDB service client using the authorized client
|
||||
alloydbService, err := alloydb.NewService(ctx, option.WithHTTPClient(client))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating AlloyDB service: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
urlString := fmt.Sprintf("projects/%s/locations/%s/clusters/%s", projectId, locationId, clusterId)
|
||||
|
||||
resp, err := alloydbService.Projects.Locations.Clusters.Users.List(urlString).Do()
|
||||
resp, err := service.Projects.Locations.Clusters.Users.List(urlString).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error listing AlloyDB users: %w", err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user