feat(prebuilt/alloydb)!: Add bearer token support for alloydb-wait-for-operation (#1183)

## Description
---
The *alloydb-wait-for-operation* tool now automatically obtains a Google
Cloud Platform OAuth2 token and uses it as a bearer token to
authenticate requests to the AlloyDB Admin API. This ensures secure and
proper communication with the API.
<img width="3840" height="2084" alt="image"
src="https://github.com/user-attachments/assets/e756255f-83f9-4719-8d8b-596a628ca1e3"
/>
## PR Checklist
---
> Thank you for opening a Pull Request! Before submitting your PR, there
are a
> few things you can do to make sure it goes smoothly:
- [x] Make sure you reviewed

[CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md)
- [x] Make sure to open an issue as a

[bug/issue](https://github.com/googleapis/langchain-google-alloydb-pg-python/issues/new/choose)
before writing your code! That way we can discuss the change, evaluate
  designs, and agree on the general idea
- [x] Ensure the tests and linter pass
- [x] Code coverage does not decrease (if any source code was changed)
- [x] Appropriate docs were updated (if necessary)
- [x] Make sure to add `!` if this involve a breaking change

🛠️ Fixes
https://github.com/googleapis/langchain-google-alloydb-pg-python/issues/456
This commit is contained in:
prernakakkar-google
2025-08-21 06:28:17 +00:00
committed by GitHub
parent 2bcac71647
commit f8f09818c7
5 changed files with 70 additions and 76 deletions

View File

@@ -16,21 +16,17 @@ This tool is intended for developer assistant workflows with human-in-the-loop
and shouldn't be used for production agents.
{{< /notice >}}
{{< notice info >}}
This tool does not have a `source` and authenticates using the environment's
[Application Default Credentials](https://cloud.google.com/docs/authentication/application-default-credentials).
{{< /notice >}}
## Example
```yaml
sources:
alloydb-api-source:
kind: http
baseUrl: https://alloydb.googleapis.com
headers:
Authorization: Bearer ${API_KEY}
Content-Type: application/json
tools:
alloydb-operations-get:
kind: alloydb-wait-for-operation
source: alloydb-api-source
description: "This will poll on operations API until the operation is done. For checking operation status we need projectId, locationID and operationId. Once instance is created give follow up steps on how to use the variables to bring data plane MCP server up in local and remote setup."
delay: 1s
maxDelay: 4m
@@ -43,7 +39,6 @@ tools:
| **field** | **type** | **required** | **description** |
| ----------- | :------: | :----------: | ---------------------------------------------------------------------------------------------------------------- |
| kind | string | true | Must be "alloydb-wait-for-operation". |
| source | string | true | Name of the source the HTTP request should be sent to. |
| description | string | true | A description of the tool. |
| delay | duration | false | The initial delay between polling requests (e.g., `3s`). Defaults to 3 seconds. |
| maxDelay | duration | false | The maximum delay between polling requests (e.g., `4m`). Defaults to 4 minutes. |

View File

@@ -50,7 +50,6 @@ tools:
description: "The name for the initial superuser. If not provided, it defaults to 'postgres'. The initial database will always be named 'postgres'."
alloydb-operations-get:
kind: alloydb-wait-for-operation
source: alloydb-api-source
description: "This will poll on operations API until the operation is done. For checking operation status we need projectId, locationID and operationId. Once instance is created give follow up steps on how to use the variables to bring data plane MCP server up in local and remote setup."
delay: 1s
maxDelay: 4m

View File

@@ -24,12 +24,10 @@ import (
"text/template"
"time"
"maps"
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
httpsrc "github.com/googleapis/genai-toolbox/internal/sources/http"
"github.com/googleapis/genai-toolbox/internal/tools"
"golang.org/x/oauth2/google"
)
const kind string = "alloydb-wait-for-operation"
@@ -93,12 +91,11 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
// Config defines the configuration for the wait-for-operation tool.
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Source string `yaml:"source" validate:"required"`
Description string `yaml:"description" validate:"required"`
AuthRequired []string `yaml:"authRequired"`
Headers map[string]string `yaml:"headers"`
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Description string `yaml:"description" validate:"required"`
AuthRequired []string `yaml:"authRequired"`
BaseURL string `yaml:"baseURL"`
// Polling configuration
Delay string `yaml:"delay"`
@@ -117,15 +114,6 @@ func (cfg Config) ToolConfigKind() string {
// Initialize initializes the tool from the configuration.
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
s, ok := srcs[cfg.Source].(*httpsrc.Source)
if !ok {
return nil, fmt.Errorf("invalid or missing source for %q tool: source kind must be `http`", kind)
}
combinedHeaders := make(map[string]string)
maps.Copy(combinedHeaders, s.DefaultHeaders)
maps.Copy(combinedHeaders, cfg.Headers)
allParameters := tools.Parameters{
tools.NewStringParameter("project", "The project ID"),
tools.NewStringParameter("location", "The location ID"),
@@ -142,6 +130,11 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
InputSchema: inputSchema,
}
baseURL := cfg.BaseURL
if baseURL == "" {
baseURL = "https://alloydb.googleapis.com"
}
var delay time.Duration
if cfg.Delay == "" {
delay = 3 * time.Second
@@ -174,13 +167,12 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
maxRetries = 10
}
return &Tool{
return Tool{
Name: cfg.Name,
Kind: kind,
BaseURL: s.BaseURL,
Headers: combinedHeaders,
BaseURL: baseURL,
AuthRequired: cfg.AuthRequired,
Client: s.Client,
Client: &http.Client{},
AllParams: allParameters,
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
@@ -198,9 +190,8 @@ type Tool struct {
Description string `yaml:"description"`
AuthRequired []string `yaml:"authRequired"`
BaseURL string `yaml:"baseURL"`
Headers map[string]string `yaml:"headers"`
AllParams tools.Parameters `yaml:"allParams"`
BaseURL string `yaml:"baseURL"`
AllParams tools.Parameters `yaml:"allParams"`
// Polling configuration
Delay time.Duration
@@ -214,7 +205,7 @@ type Tool struct {
}
// Invoke executes the tool's logic.
func (t *Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error) {
paramsMap := params.AsMap()
project, ok := paramsMap["project"].(string)
@@ -251,9 +242,19 @@ func (t *Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error
req, _ := http.NewRequest(http.MethodGet, urlString, nil)
for k, v := range t.Headers {
req.Header.Set(k, v)
// This request is authenticated using Google Application Default Credentials (ADC).
// The ADC are discovered automatically from the environment.
// For more details, see: https://cloud.google.com/docs/authentication/application-default-credentials
// The "cloud-platform" scope provides broad access to Google Cloud services, there is no specific scope for AlloyDB.
tokenSource, err := google.DefaultTokenSource(ctx, "https://www.googleapis.com/auth/cloud-platform")
if err != nil {
return nil, fmt.Errorf("error creating token source: %w", err)
}
token, err := tokenSource.Token()
if err != nil {
return nil, fmt.Errorf("error retrieving token: %w", err)
}
req.Header.Set("Authorization", "Bearer "+token.AccessToken)
resp, err := t.Client.Do(req)
if err != nil {
@@ -280,7 +281,6 @@ func (t *Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error
if msg, ok := t.generateAlloyDBConnectionMessage(data); ok {
return msg, nil
}
return string(body), nil
}
}
@@ -298,7 +298,7 @@ func (t *Tool) Invoke(ctx context.Context, params tools.ParamValues) (any, error
return nil, fmt.Errorf("exceeded max retries waiting for operation")
}
func (t *Tool) generateAlloyDBConnectionMessage(opResponse map[string]any) (string, bool) {
func (t Tool) generateAlloyDBConnectionMessage(opResponse map[string]any) (string, bool) {
responseData, ok := opResponse["response"].(map[string]any)
if !ok {
return "", false
@@ -355,21 +355,21 @@ func (t *Tool) generateAlloyDBConnectionMessage(opResponse map[string]any) (stri
}
// ParseParams parses the parameters for the tool.
func (t *Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
return tools.ParseParams(t.AllParams, data, claims)
}
// Manifest returns the tool's manifest.
func (t *Tool) Manifest() tools.Manifest {
func (t Tool) Manifest() tools.Manifest {
return t.manifest
}
// McpManifest returns the tool's MCP manifest.
func (t *Tool) McpManifest() tools.McpManifest {
func (t Tool) McpManifest() tools.McpManifest {
return t.mcpManifest
}
// Authorized checks if the tool is authorized.
func (t *Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
func (t Tool) Authorized(verifiedAuthServices []string) bool {
return true
}

View File

@@ -40,7 +40,6 @@ func TestParseFromYaml(t *testing.T) {
tools:
wait-for-thing:
kind: alloydb-wait-for-operation
source: my-source
description: some description
delay: 1s
maxDelay: 5s
@@ -51,7 +50,6 @@ func TestParseFromYaml(t *testing.T) {
"wait-for-thing": alloydbwaitforoperation.Config{
Name: "wait-for-thing",
Kind: "alloydb-wait-for-operation",
Source: "my-source",
Description: "some description",
AuthRequired: []string{},
Delay: "1s",

View File

@@ -33,15 +33,14 @@ import (
)
var (
httpSourceKind = "http"
waitToolKind = "alloydb-wait-for-operation"
waitToolKind = "alloydb-wait-for-operation"
)
type operation struct {
Name string `json:"name"`
Done bool `json:"done"`
Result string `json:"result,omitempty"`
Error *struct {
Name string `json:"name"`
Done bool `json:"done"`
Response any `json:"response,omitempty"`
Error *struct {
Code int `json:"code"`
Message string `json:"message"`
} `json:"error,omitempty"`
@@ -80,7 +79,7 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func TestWaitToolEndpoints(t *testing.T) {
h := &handler{
operations: map[string]*operation{
"op1": {Name: "op1", Done: false, Result: "success"},
"op1": {Name: "op1", Done: false, Response: "success"},
"op2": {Name: "op2", Done: false, Error: &struct {
Code int `json:"code"`
Message string `json:"message"`
@@ -90,16 +89,12 @@ func TestWaitToolEndpoints(t *testing.T) {
server := httptest.NewServer(h)
defer server.Close()
sourceConfig := map[string]any{
"kind": httpSourceKind,
"baseUrl": server.URL,
}
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
var args []string
toolsFile := getWaitToolsConfig(sourceConfig)
toolsFile := getWaitToolsConfig(server.URL)
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
if err != nil {
t.Fatalf("command initialization returned an error: %s", err)
@@ -115,17 +110,18 @@ func TestWaitToolEndpoints(t *testing.T) {
}
tcs := []struct {
name string
toolName string
body string
want string
expectError bool
name string
toolName string
body string
want string
expectError bool
wantSubstring bool
}{
{
name: "successful operation",
toolName: "wait-for-op1",
body: `{"project": "p1", "location": "l1", "operation_id": "op1"}`,
want: `{"name":"op1","done":true,"result":"success"}`,
want: `{"name":"op1","done":true,"response":"success"}`,
},
{
name: "failed operation",
@@ -168,6 +164,13 @@ func TestWaitToolEndpoints(t *testing.T) {
t.Fatalf("failed to decode response: %v", err)
}
if tc.wantSubstring {
if !bytes.Contains([]byte(result.Result), []byte(tc.want)) {
t.Fatalf("unexpected result: got %q, want substring %q", result.Result, tc.want)
}
return
}
// The result is a JSON-encoded string, so we need to unmarshal it twice.
var unmarshaledResult string
if err := json.Unmarshal([]byte(result.Result), &unmarshaledResult); err != nil {
@@ -189,21 +192,20 @@ func TestWaitToolEndpoints(t *testing.T) {
}
}
func getWaitToolsConfig(sourceConfig map[string]any) map[string]any {
func getWaitToolsConfig(baseURL string) map[string]any {
return map[string]any{
"sources": map[string]any{
"my-instance": sourceConfig,
},
"tools": map[string]any{
"wait-for-op1": map[string]any{
"kind": waitToolKind,
"description": "wait for op1",
"source": "my-instance",
"kind": waitToolKind,
"description": "wait for op1",
"baseURL": baseURL,
"authRequired": []string{},
},
"wait-for-op2": map[string]any{
"kind": waitToolKind,
"description": "wait for op2",
"source": "my-instance",
"kind": waitToolKind,
"description": "wait for op2",
"baseURL": baseURL,
"authRequired": []string{},
},
},
}