mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-11 16:38:15 -05:00
feat(source/cloud-sql-admin): Add User agent and attach sqldmin in cloud-sql-admin source. (#1441)
## Description --- 1. This change introduces a userAgentRoundTripper that correctly prepends our custom user agent to the existing User-Agent header 2. Moves sqladmin client to source. 3. Updated cloudsql tools for above support. 4. Add test cases to validate User agent. ## PR Checklist --- > Thank you for opening a Pull Request! Before submitting your PR, there are a > few things you can do to make sure it goes smoothly: - [x] Make sure you reviewed [CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md) - [x] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/genai-toolbox/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea - [ ] Ensure the tests and linter pass - [ ] Code coverage does not decrease (if any source code was changed) - [x] Appropriate docs were updated (if necessary) - [ ] Make sure to add `!` if this involve a breaking change 🛠️ Fixes #<issue_number_goes_here>
This commit is contained in:
committed by
GitHub
parent
7d384dc28f
commit
56b6574fc2
@@ -9,7 +9,7 @@ description: >
|
||||
The `cloud-sql-get-instance` tool retrieves a Cloud SQL instance resource using the Cloud SQL Admin API.
|
||||
|
||||
{{< notice info >}}
|
||||
This tool uses a `source` of kind `cloud-sql-admin`. The source automatically generates a bearer token on behalf of the user with the `https://www.googleapis.com/auth/sqlservice.admin` scope to authenticate requests.
|
||||
This tool uses a `source` of kind `cloud-sql-admin`.
|
||||
{{< /notice >}}
|
||||
|
||||
## Example
|
||||
@@ -18,14 +18,14 @@ This tool uses a `source` of kind `cloud-sql-admin`. The source automatically ge
|
||||
tools:
|
||||
get-sql-instance:
|
||||
kind: cloud-sql-get-instance
|
||||
description: "Get a Cloud SQL instance resource."
|
||||
source: my-cloud-sql-source
|
||||
source: my-cloud-sql-admin-source
|
||||
description: "Gets a particular cloud sql instance."
|
||||
```
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
| ----------- | :------: | :----------: | ---------------------------------------------------------------------------------------------------------------- |
|
||||
| kind | string | true | Must be "cloud-sql-get-instance". |
|
||||
| description | string | true | A description of the tool. |
|
||||
| source | string | true | The name of the `cloud-sql-admin` source to use. |
|
||||
| **field** | **type** | **required** | **description** |
|
||||
| ----------- | :------: | :----------: | ------------------------------------------------ |
|
||||
| kind | string | true | Must be "cloud-sql-get-instance". |
|
||||
| source | string | true | The name of the `cloud-sql-admin` source to use. |
|
||||
| description | string | false | A description of the tool. |
|
||||
|
||||
@@ -32,8 +32,8 @@ tools:
|
||||
| ----------- | :------: | :----------: | ---------------------------------------------------------------------------------------------------------------- |
|
||||
| kind | string | true | Must be "cloud-sql-wait-for-operation". |
|
||||
| source | string | true | The name of a `cloud-sql-admin` source to use for authentication. |
|
||||
| description | string | true | A description of the tool. |
|
||||
| delay | duration | false | The initial delay between polling requests (e.g., `3s`). Defaults to 3 seconds. |
|
||||
| maxDelay | duration | false | The maximum delay between polling requests (e.g., `4m`). Defaults to 4 minutes. |
|
||||
| multiplier | float | false | The multiplier for the polling delay. The delay is multiplied by this value after each request. Defaults to 2.0. |
|
||||
| maxRetries | int | false | The maximum number of polling attempts before giving up. Defaults to 10. |
|
||||
| description | string | false | A description of the tool. |
|
||||
| delay | duration | false | The initial delay between polling requests (e.g., `3s`). Defaults to 3 seconds. |
|
||||
| maxDelay | duration | false | The maximum delay between polling requests (e.g., `4m`). Defaults to 4 minutes. |
|
||||
| multiplier | float | false | The multiplier for the polling delay. The delay is multiplied by this value after each request. Defaults to 2.0. |
|
||||
| maxRetries | int | false | The maximum number of polling attempts before giving up. Defaults to 10. |
|
||||
|
||||
@@ -24,11 +24,32 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/google"
|
||||
"google.golang.org/api/option"
|
||||
sqladmin "google.golang.org/api/sqladmin/v1"
|
||||
)
|
||||
|
||||
const SourceKind string = "cloud-sql-admin"
|
||||
|
||||
type userAgentRoundTripper struct {
|
||||
userAgent string
|
||||
next http.RoundTripper
|
||||
}
|
||||
|
||||
func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
newReq := *req
|
||||
newReq.Header = make(http.Header)
|
||||
for k, v := range req.Header {
|
||||
newReq.Header[k] = v
|
||||
}
|
||||
ua := newReq.Header.Get("User-Agent")
|
||||
if ua == "" {
|
||||
newReq.Header.Set("User-Agent", rt.userAgent)
|
||||
} else {
|
||||
newReq.Header.Set("User-Agent", rt.userAgent+" "+ua)
|
||||
}
|
||||
return rt.next.RoundTrip(&newReq)
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
@@ -65,22 +86,36 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
|
||||
|
||||
var client *http.Client
|
||||
if r.UseClientOAuth {
|
||||
client = nil
|
||||
client = &http.Client{
|
||||
Transport: &userAgentRoundTripper{
|
||||
userAgent: ua,
|
||||
next: http.DefaultTransport,
|
||||
},
|
||||
}
|
||||
} else {
|
||||
// Use Application Default Credentials
|
||||
creds, err := google.FindDefaultCredentials(ctx, sqladmin.SqlserviceAdminScope)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to find default credentials: %w", err)
|
||||
}
|
||||
client = oauth2.NewClient(ctx, creds.TokenSource)
|
||||
baseClient := oauth2.NewClient(ctx, creds.TokenSource)
|
||||
baseClient.Transport = &userAgentRoundTripper{
|
||||
userAgent: ua,
|
||||
next: baseClient.Transport,
|
||||
}
|
||||
client = baseClient
|
||||
}
|
||||
|
||||
service, err := sqladmin.NewService(ctx, option.WithHTTPClient(client))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating new sqladmin service: %w", err)
|
||||
}
|
||||
|
||||
s := &Source{
|
||||
Name: r.Name,
|
||||
Kind: SourceKind,
|
||||
BaseURL: "https://sqladmin.googleapis.com",
|
||||
Client: client,
|
||||
UserAgent: ua,
|
||||
Service: service,
|
||||
UseClientOAuth: r.UseClientOAuth,
|
||||
}
|
||||
return s, nil
|
||||
@@ -92,8 +127,7 @@ type Source struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
BaseURL string
|
||||
Client *http.Client
|
||||
UserAgent string
|
||||
Service *sqladmin.Service
|
||||
UseClientOAuth bool
|
||||
}
|
||||
|
||||
@@ -101,15 +135,17 @@ func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Client, error) {
|
||||
func (s *Source) GetService(ctx context.Context, accessToken string) (*sqladmin.Service, error) {
|
||||
if s.UseClientOAuth {
|
||||
if accessToken == "" {
|
||||
return nil, fmt.Errorf("client-side OAuth is enabled but no access token was provided")
|
||||
}
|
||||
token := &oauth2.Token{AccessToken: accessToken}
|
||||
return oauth2.NewClient(ctx, oauth2.StaticTokenSource(token)), nil
|
||||
client := oauth2.NewClient(ctx, oauth2.StaticTokenSource(token))
|
||||
service, err := sqladmin.NewService(ctx, option.WithHTTPClient(client))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating new sqladmin service: %w", err)
|
||||
}
|
||||
return service, nil
|
||||
}
|
||||
return s.Client, nil
|
||||
return s.Service, nil
|
||||
}
|
||||
|
||||
func (s *Source) UseClientAuthorization() bool {
|
||||
|
||||
@@ -22,7 +22,6 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"google.golang.org/api/option"
|
||||
sqladmin "google.golang.org/api/sqladmin/v1"
|
||||
)
|
||||
|
||||
@@ -135,7 +134,7 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
|
||||
iamUser, _ := paramsMap["iamUser"].(bool)
|
||||
|
||||
user := &sqladmin.User{
|
||||
user := sqladmin.User{
|
||||
Name: name,
|
||||
}
|
||||
|
||||
@@ -150,19 +149,12 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
user.Password = password
|
||||
}
|
||||
|
||||
client, err := t.Source.GetClient(ctx, string(accessToken))
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
service, err := sqladmin.NewService(ctx, option.WithHTTPClient(client))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating new sqladmin service: %w", err)
|
||||
}
|
||||
|
||||
service.UserAgent = t.Source.UserAgent
|
||||
|
||||
resp, err := service.Users.Insert(project, instance, user).Do()
|
||||
resp, err := service.Users.Insert(project, instance, &user).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating user: %w", err)
|
||||
}
|
||||
|
||||
@@ -22,8 +22,6 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"google.golang.org/api/option"
|
||||
sqladmin "google.golang.org/api/sqladmin/v1"
|
||||
)
|
||||
|
||||
const kind string = "cloud-sql-get-instance"
|
||||
@@ -46,7 +44,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
Description string `yaml:"description"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
}
|
||||
@@ -80,9 +78,14 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
inputSchema := allParameters.McpManifest()
|
||||
inputSchema.Required = []string{"projectId", "instanceId"}
|
||||
|
||||
description := cfg.Description
|
||||
if description == "" {
|
||||
description = "Gets a particular cloud sql instance."
|
||||
}
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
Description: cfg.Description,
|
||||
Description: description,
|
||||
InputSchema: inputSchema,
|
||||
}
|
||||
|
||||
@@ -123,17 +126,11 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
return nil, fmt.Errorf("missing 'instanceId' parameter")
|
||||
}
|
||||
|
||||
client, err := t.Source.GetClient(ctx, string(accessToken))
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
service, err := sqladmin.NewService(ctx, option.WithHTTPClient(client))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating new sqladmin service: %w", err)
|
||||
}
|
||||
service.UserAgent = t.Source.UserAgent
|
||||
|
||||
resp, err := service.Instances.Get(projectId, instanceId).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting instance: %w", err)
|
||||
|
||||
@@ -16,10 +16,7 @@ package cloudsqllistinstances
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
@@ -123,48 +120,34 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
return nil, fmt.Errorf("missing 'project' parameter")
|
||||
}
|
||||
|
||||
client, err := t.source.GetClient(ctx, string(accessToken))
|
||||
service, err := t.source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
urlString := fmt.Sprintf("%s/v1/projects/%s/instances", t.source.BaseURL, project)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, urlString, nil)
|
||||
resp, err := service.Instances.List(project).Do()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating request: %w", err)
|
||||
return nil, fmt.Errorf("error listing instances: %w", err)
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error making HTTP request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("unexpected status code: %d, response body: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var v struct {
|
||||
Items []struct {
|
||||
Name string `json:"name"`
|
||||
InstanceType string `json:"instanceType"`
|
||||
} `json:"items"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &v); err != nil {
|
||||
return nil, fmt.Errorf("error unmarshaling response body: %w", err)
|
||||
}
|
||||
|
||||
if v.Items == nil {
|
||||
if resp.Items == nil {
|
||||
return []any{}, nil
|
||||
}
|
||||
|
||||
return v.Items, nil
|
||||
type instanceInfo struct {
|
||||
Name string `json:"name"`
|
||||
InstanceType string `json:"instanceType"`
|
||||
}
|
||||
|
||||
var instances []instanceInfo
|
||||
for _, item := range resp.Items {
|
||||
instances = append(instances, instanceInfo{
|
||||
Name: item.Name,
|
||||
InstanceType: item.InstanceType,
|
||||
})
|
||||
}
|
||||
|
||||
return instances, nil
|
||||
}
|
||||
|
||||
// ParseParams parses the parameters for the tool.
|
||||
|
||||
@@ -27,8 +27,6 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqladmin"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"google.golang.org/api/option"
|
||||
sqladmin "google.golang.org/api/sqladmin/v1"
|
||||
)
|
||||
|
||||
const kind string = "cloud-sql-wait-for-operation"
|
||||
@@ -93,7 +91,7 @@ type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
Description string `yaml:"description"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
BaseURL string `yaml:"baseURL"`
|
||||
|
||||
@@ -133,9 +131,14 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
||||
inputSchema := allParameters.McpManifest()
|
||||
inputSchema.Required = []string{"project", "operation"}
|
||||
|
||||
description := cfg.Description
|
||||
if description == "" {
|
||||
description = "This will poll on operations API until the operation is done. For checking operation status we need projectId and operationId. Once instance is created give follow up steps on how to use the variables to bring data plane MCP server up in local and remote setup."
|
||||
}
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
Description: cfg.Description,
|
||||
Description: description,
|
||||
InputSchema: inputSchema,
|
||||
}
|
||||
|
||||
@@ -219,16 +222,11 @@ func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken
|
||||
return nil, fmt.Errorf("missing 'operation' parameter")
|
||||
}
|
||||
|
||||
client, err := t.Source.GetClient(ctx, string(accessToken))
|
||||
service, err := t.Source.GetService(ctx, string(accessToken))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
service, err := sqladmin.NewService(ctx, option.WithHTTPClient(client))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating new sqladmin service: %w", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
@@ -389,14 +387,10 @@ func (t Tool) generateCloudSQLConnectionMessage(opResponse map[string]any) (stri
|
||||
}
|
||||
|
||||
func (t Tool) fetchInstanceData(ctx context.Context, project, instance string) (map[string]any, error) {
|
||||
client, err := t.Source.GetClient(ctx, "")
|
||||
service, err := t.Source.GetService(ctx, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
service, err := sqladmin.NewService(ctx, option.WithHTTPClient(client))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating new sqladmin service: %w", err)
|
||||
}
|
||||
|
||||
resp, err := service.Instances.Get(project, instance).Do()
|
||||
if err != nil {
|
||||
|
||||
@@ -62,6 +62,10 @@ type masterCreateUserHandler struct {
|
||||
}
|
||||
|
||||
func (h *masterCreateUserHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if !strings.Contains(r.UserAgent(), "genai-toolbox/") {
|
||||
h.t.Errorf("User-Agent header not found")
|
||||
}
|
||||
|
||||
var body userCreateRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
||||
h.t.Fatalf("failed to decode request body: %v", err)
|
||||
|
||||
@@ -59,12 +59,17 @@ type instance struct {
|
||||
type handler struct {
|
||||
mu sync.Mutex
|
||||
instances map[string]*instance
|
||||
t *testing.T
|
||||
}
|
||||
|
||||
func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
if !strings.Contains(r.UserAgent(), "genai-toolbox/") {
|
||||
h.t.Errorf("User-Agent header not found")
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(r.URL.Path, "/v1/projects/") {
|
||||
http.Error(w, "unexpected path", http.StatusBadRequest)
|
||||
return
|
||||
@@ -92,6 +97,7 @@ func TestGetInstancesToolEndpoints(t *testing.T) {
|
||||
instances: map[string]*instance{
|
||||
"instance-1": {Name: "instance-1", Kind: "sql#instance"},
|
||||
},
|
||||
t: t,
|
||||
}
|
||||
server := httptest.NewServer(h)
|
||||
defer server.Close()
|
||||
|
||||
@@ -49,6 +49,9 @@ func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
|
||||
func TestListInstance(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !strings.Contains(r.UserAgent(), "genai-toolbox/") {
|
||||
t.Errorf("User-Agent header not found")
|
||||
}
|
||||
if r.URL.Path != "/v1/projects/test-project/instances" {
|
||||
http.Error(w, fmt.Sprintf("unexpected path: got %q", r.URL.Path), http.StatusBadRequest)
|
||||
return
|
||||
|
||||
@@ -75,12 +75,17 @@ type cloudsqlHandler struct {
|
||||
mu sync.Mutex
|
||||
operations map[string]*cloudsqlOperation
|
||||
instances map[string]*cloudsqlInstance
|
||||
t *testing.T
|
||||
}
|
||||
|
||||
func (h *cloudsqlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
if !strings.Contains(r.UserAgent(), "genai-toolbox/") {
|
||||
h.t.Errorf("User-Agent header not found")
|
||||
}
|
||||
|
||||
if match, _ := regexp.MatchString("/v1/projects/p1/operations/.*", r.URL.Path); match {
|
||||
parts := regexp.MustCompile("/").Split(r.URL.Path, -1)
|
||||
opName := parts[len(parts)-1]
|
||||
@@ -139,6 +144,7 @@ func TestCloudSQLWaitToolEndpoints(t *testing.T) {
|
||||
instances: map[string]*cloudsqlInstance{
|
||||
"i1": {Region: "r1", DatabaseVersion: "POSTGRES_13"},
|
||||
},
|
||||
t: t,
|
||||
}
|
||||
server := httptest.NewServer(h)
|
||||
defer server.Close()
|
||||
|
||||
Reference in New Issue
Block a user