feat(source/cloud-sql-admin): Add User agent and attach sqldmin in cloud-sql-admin source. (#1441)

## Description

---
1. This change introduces a userAgentRoundTripper that correctly
prepends our custom user agent to the existing User-Agent header
2. Moves sqladmin client to source.
3. Updated cloudsql tools for above support.
4. Add test cases to validate User agent.

## PR Checklist

---
> Thank you for opening a Pull Request! Before submitting your PR, there
are a
> few things you can do to make sure it goes smoothly:

- [x] Make sure you reviewed

[CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md)
- [x] Make sure to open an issue as a

[bug/issue](https://github.com/googleapis/genai-toolbox/issues/new/choose)
before writing your code! That way we can discuss the change, evaluate
  designs, and agree on the general idea
- [ ] Ensure the tests and linter pass
- [ ] Code coverage does not decrease (if any source code was changed)
- [x] Appropriate docs were updated (if necessary)
- [ ] Make sure to add `!` if this involve a breaking change

🛠️ Fixes #<issue_number_goes_here>
This commit is contained in:
prernakakkar-google
2025-09-15 16:31:38 +00:00
committed by GitHub
parent 7d384dc28f
commit 56b6574fc2
11 changed files with 118 additions and 97 deletions

View File

@@ -9,7 +9,7 @@ description: >
The `cloud-sql-get-instance` tool retrieves a Cloud SQL instance resource using the Cloud SQL Admin API.
{{< notice info >}}
This tool uses a `source` of kind `cloud-sql-admin`. The source automatically generates a bearer token on behalf of the user with the `https://www.googleapis.com/auth/sqlservice.admin` scope to authenticate requests.
This tool uses a `source` of kind `cloud-sql-admin`.
{{< /notice >}}
## Example
@@ -18,14 +18,14 @@ This tool uses a `source` of kind `cloud-sql-admin`. The source automatically ge
tools:
get-sql-instance:
kind: cloud-sql-get-instance
description: "Get a Cloud SQL instance resource."
source: my-cloud-sql-source
source: my-cloud-sql-admin-source
description: "Gets a particular cloud sql instance."
```
## Reference
| **field** | **type** | **required** | **description** |
| ----------- | :------: | :----------: | ---------------------------------------------------------------------------------------------------------------- |
| kind | string | true | Must be "cloud-sql-get-instance". |
| description | string | true | A description of the tool. |
| source | string | true | The name of the `cloud-sql-admin` source to use. |
| **field** | **type** | **required** | **description** |
| ----------- | :------: | :----------: | ------------------------------------------------ |
| kind | string | true | Must be "cloud-sql-get-instance". |
| source | string | true | The name of the `cloud-sql-admin` source to use. |
| description | string | false | A description of the tool. |

View File

@@ -32,8 +32,8 @@ tools:
| ----------- | :------: | :----------: | ---------------------------------------------------------------------------------------------------------------- |
| kind | string | true | Must be "cloud-sql-wait-for-operation". |
| source | string | true | The name of a `cloud-sql-admin` source to use for authentication. |
| description | string | true | A description of the tool. |
| delay | duration | false | The initial delay between polling requests (e.g., `3s`). Defaults to 3 seconds. |
| maxDelay | duration | false | The maximum delay between polling requests (e.g., `4m`). Defaults to 4 minutes. |
| multiplier | float | false | The multiplier for the polling delay. The delay is multiplied by this value after each request. Defaults to 2.0. |
| maxRetries | int | false | The maximum number of polling attempts before giving up. Defaults to 10. |
| description | string | false | A description of the tool. |
| delay | duration | false | The initial delay between polling requests (e.g., `3s`). Defaults to 3 seconds. |
| maxDelay | duration | false | The maximum delay between polling requests (e.g., `4m`). Defaults to 4 minutes. |
| multiplier | float | false | The multiplier for the polling delay. The delay is multiplied by this value after each request. Defaults to 2.0. |
| maxRetries | int | false | The maximum number of polling attempts before giving up. Defaults to 10. |

View File

@@ -24,11 +24,32 @@ import (
"go.opentelemetry.io/otel/trace"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
"google.golang.org/api/option"
sqladmin "google.golang.org/api/sqladmin/v1"
)
const SourceKind string = "cloud-sql-admin"
type userAgentRoundTripper struct {
userAgent string
next http.RoundTripper
}
func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
newReq := *req
newReq.Header = make(http.Header)
for k, v := range req.Header {
newReq.Header[k] = v
}
ua := newReq.Header.Get("User-Agent")
if ua == "" {
newReq.Header.Set("User-Agent", rt.userAgent)
} else {
newReq.Header.Set("User-Agent", rt.userAgent+" "+ua)
}
return rt.next.RoundTrip(&newReq)
}
// validate interface
var _ sources.SourceConfig = Config{}
@@ -65,22 +86,36 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
var client *http.Client
if r.UseClientOAuth {
client = nil
client = &http.Client{
Transport: &userAgentRoundTripper{
userAgent: ua,
next: http.DefaultTransport,
},
}
} else {
// Use Application Default Credentials
creds, err := google.FindDefaultCredentials(ctx, sqladmin.SqlserviceAdminScope)
if err != nil {
return nil, fmt.Errorf("failed to find default credentials: %w", err)
}
client = oauth2.NewClient(ctx, creds.TokenSource)
baseClient := oauth2.NewClient(ctx, creds.TokenSource)
baseClient.Transport = &userAgentRoundTripper{
userAgent: ua,
next: baseClient.Transport,
}
client = baseClient
}
service, err := sqladmin.NewService(ctx, option.WithHTTPClient(client))
if err != nil {
return nil, fmt.Errorf("error creating new sqladmin service: %w", err)
}
s := &Source{
Name: r.Name,
Kind: SourceKind,
BaseURL: "https://sqladmin.googleapis.com",
Client: client,
UserAgent: ua,
Service: service,
UseClientOAuth: r.UseClientOAuth,
}
return s, nil
@@ -92,8 +127,7 @@ type Source struct {
Name string `yaml:"name"`
Kind string `yaml:"kind"`
BaseURL string
Client *http.Client
UserAgent string
Service *sqladmin.Service
UseClientOAuth bool
}
@@ -101,15 +135,17 @@ func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Client, error) {
func (s *Source) GetService(ctx context.Context, accessToken string) (*sqladmin.Service, error) {
if s.UseClientOAuth {
if accessToken == "" {
return nil, fmt.Errorf("client-side OAuth is enabled but no access token was provided")
}
token := &oauth2.Token{AccessToken: accessToken}
return oauth2.NewClient(ctx, oauth2.StaticTokenSource(token)), nil
client := oauth2.NewClient(ctx, oauth2.StaticTokenSource(token))
service, err := sqladmin.NewService(ctx, option.WithHTTPClient(client))
if err != nil {
return nil, fmt.Errorf("error creating new sqladmin service: %w", err)
}
return service, nil
}
return s.Client, nil
return s.Service, nil
}
func (s *Source) UseClientAuthorization() bool {

View File

@@ -22,7 +22,6 @@ import (
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
"github.com/googleapis/genai-toolbox/internal/tools"
"google.golang.org/api/option"
sqladmin "google.golang.org/api/sqladmin/v1"
)
@@ -135,7 +134,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
iamUser, _ := paramsMap["iamUser"].(bool)
user := &sqladmin.User{
user := sqladmin.User{
Name: name,
}
@@ -150,19 +149,12 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
user.Password = password
}
client, err := t.Source.GetClient(ctx, string(accessToken))
service, err := t.Source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
service, err := sqladmin.NewService(ctx, option.WithHTTPClient(client))
if err != nil {
return nil, fmt.Errorf("error creating new sqladmin service: %w", err)
}
service.UserAgent = t.Source.UserAgent
resp, err := service.Users.Insert(project, instance, user).Do()
resp, err := service.Users.Insert(project, instance, &user).Do()
if err != nil {
return nil, fmt.Errorf("error creating user: %w", err)
}

View File

@@ -22,8 +22,6 @@ import (
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
"github.com/googleapis/genai-toolbox/internal/tools"
"google.golang.org/api/option"
sqladmin "google.golang.org/api/sqladmin/v1"
)
const kind string = "cloud-sql-get-instance"
@@ -46,7 +44,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Description string `yaml:"description" validate:"required"`
Description string `yaml:"description"`
Source string `yaml:"source" validate:"required"`
AuthRequired []string `yaml:"authRequired"`
}
@@ -80,9 +78,14 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
inputSchema := allParameters.McpManifest()
inputSchema.Required = []string{"projectId", "instanceId"}
description := cfg.Description
if description == "" {
description = "Gets a particular cloud sql instance."
}
mcpManifest := tools.McpManifest{
Name: cfg.Name,
Description: cfg.Description,
Description: description,
InputSchema: inputSchema,
}
@@ -123,17 +126,11 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
return nil, fmt.Errorf("missing 'instanceId' parameter")
}
client, err := t.Source.GetClient(ctx, string(accessToken))
service, err := t.Source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
service, err := sqladmin.NewService(ctx, option.WithHTTPClient(client))
if err != nil {
return nil, fmt.Errorf("error creating new sqladmin service: %w", err)
}
service.UserAgent = t.Source.UserAgent
resp, err := service.Instances.Get(projectId, instanceId).Do()
if err != nil {
return nil, fmt.Errorf("error getting instance: %w", err)

View File

@@ -16,10 +16,7 @@ package cloudsqllistinstances
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
@@ -123,48 +120,34 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
return nil, fmt.Errorf("missing 'project' parameter")
}
client, err := t.source.GetClient(ctx, string(accessToken))
service, err := t.source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
urlString := fmt.Sprintf("%s/v1/projects/%s/instances", t.source.BaseURL, project)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, urlString, nil)
resp, err := service.Instances.List(project).Do()
if err != nil {
return nil, fmt.Errorf("error creating request: %w", err)
return nil, fmt.Errorf("error listing instances: %w", err)
}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("error making HTTP request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("error reading response body: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected status code: %d, response body: %s", resp.StatusCode, string(body))
}
var v struct {
Items []struct {
Name string `json:"name"`
InstanceType string `json:"instanceType"`
} `json:"items"`
}
if err := json.Unmarshal(body, &v); err != nil {
return nil, fmt.Errorf("error unmarshaling response body: %w", err)
}
if v.Items == nil {
if resp.Items == nil {
return []any{}, nil
}
return v.Items, nil
type instanceInfo struct {
Name string `json:"name"`
InstanceType string `json:"instanceType"`
}
var instances []instanceInfo
for _, item := range resp.Items {
instances = append(instances, instanceInfo{
Name: item.Name,
InstanceType: item.InstanceType,
})
}
return instances, nil
}
// ParseParams parses the parameters for the tool.

View File

@@ -27,8 +27,6 @@ import (
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
"github.com/googleapis/genai-toolbox/internal/tools"
"google.golang.org/api/option"
sqladmin "google.golang.org/api/sqladmin/v1"
)
const kind string = "cloud-sql-wait-for-operation"
@@ -93,7 +91,7 @@ type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Source string `yaml:"source" validate:"required"`
Description string `yaml:"description" validate:"required"`
Description string `yaml:"description"`
AuthRequired []string `yaml:"authRequired"`
BaseURL string `yaml:"baseURL"`
@@ -133,9 +131,14 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
inputSchema := allParameters.McpManifest()
inputSchema.Required = []string{"project", "operation"}
description := cfg.Description
if description == "" {
description = "This will poll on operations API until the operation is done. For checking operation status we need projectId and operationId. Once instance is created give follow up steps on how to use the variables to bring data plane MCP server up in local and remote setup."
}
mcpManifest := tools.McpManifest{
Name: cfg.Name,
Description: cfg.Description,
Description: description,
InputSchema: inputSchema,
}
@@ -219,16 +222,11 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
return nil, fmt.Errorf("missing 'operation' parameter")
}
client, err := t.Source.GetClient(ctx, string(accessToken))
service, err := t.Source.GetService(ctx, string(accessToken))
if err != nil {
return nil, err
}
service, err := sqladmin.NewService(ctx, option.WithHTTPClient(client))
if err != nil {
return nil, fmt.Errorf("error creating new sqladmin service: %w", err)
}
ctx, cancel := context.WithTimeout(ctx, 30*time.Minute)
defer cancel()
@@ -389,14 +387,10 @@ func (t Tool) generateCloudSQLConnectionMessage(opResponse map[string]any) (stri
}
func (t Tool) fetchInstanceData(ctx context.Context, project, instance string) (map[string]any, error) {
client, err := t.Source.GetClient(ctx, "")
service, err := t.Source.GetService(ctx, "")
if err != nil {
return nil, err
}
service, err := sqladmin.NewService(ctx, option.WithHTTPClient(client))
if err != nil {
return nil, fmt.Errorf("error creating new sqladmin service: %w", err)
}
resp, err := service.Instances.Get(project, instance).Do()
if err != nil {

View File

@@ -62,6 +62,10 @@ type masterCreateUserHandler struct {
}
func (h *masterCreateUserHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if !strings.Contains(r.UserAgent(), "genai-toolbox/") {
h.t.Errorf("User-Agent header not found")
}
var body userCreateRequest
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
h.t.Fatalf("failed to decode request body: %v", err)

View File

@@ -59,12 +59,17 @@ type instance struct {
type handler struct {
mu sync.Mutex
instances map[string]*instance
t *testing.T
}
func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.mu.Lock()
defer h.mu.Unlock()
if !strings.Contains(r.UserAgent(), "genai-toolbox/") {
h.t.Errorf("User-Agent header not found")
}
if !strings.HasPrefix(r.URL.Path, "/v1/projects/") {
http.Error(w, "unexpected path", http.StatusBadRequest)
return
@@ -92,6 +97,7 @@ func TestGetInstancesToolEndpoints(t *testing.T) {
instances: map[string]*instance{
"instance-1": {Name: "instance-1", Kind: "sql#instance"},
},
t: t,
}
server := httptest.NewServer(h)
defer server.Close()

View File

@@ -49,6 +49,9 @@ func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
func TestListInstance(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !strings.Contains(r.UserAgent(), "genai-toolbox/") {
t.Errorf("User-Agent header not found")
}
if r.URL.Path != "/v1/projects/test-project/instances" {
http.Error(w, fmt.Sprintf("unexpected path: got %q", r.URL.Path), http.StatusBadRequest)
return

View File

@@ -75,12 +75,17 @@ type cloudsqlHandler struct {
mu sync.Mutex
operations map[string]*cloudsqlOperation
instances map[string]*cloudsqlInstance
t *testing.T
}
func (h *cloudsqlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.mu.Lock()
defer h.mu.Unlock()
if !strings.Contains(r.UserAgent(), "genai-toolbox/") {
h.t.Errorf("User-Agent header not found")
}
if match, _ := regexp.MatchString("/v1/projects/p1/operations/.*", r.URL.Path); match {
parts := regexp.MustCompile("/").Split(r.URL.Path, -1)
opName := parts[len(parts)-1]
@@ -139,6 +144,7 @@ func TestCloudSQLWaitToolEndpoints(t *testing.T) {
instances: map[string]*cloudsqlInstance{
"i1": {Region: "r1", DatabaseVersion: "POSTGRES_13"},
},
t: t,
}
server := httptest.NewServer(h)
defer server.Close()