mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-11 08:28:11 -05:00
Compare commits
14 Commits
v0.23.0
...
cloud-sql-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3117c78af6 | ||
|
|
738cca7c61 | ||
|
|
9f69b3c8d1 | ||
|
|
47f7c9d918 | ||
|
|
07083e3664 | ||
|
|
a74c1b69ac | ||
|
|
79a6c827f7 | ||
|
|
41ea1c488f | ||
|
|
afe6a2c4d2 | ||
|
|
16f9a418c3 | ||
|
|
669e6a3a07 | ||
|
|
96ba178a1d | ||
|
|
cbe8b13936 | ||
|
|
f4ffb2fe27 |
@@ -531,6 +531,24 @@ steps:
|
||||
utility \
|
||||
utility/alloydbwaitforoperation
|
||||
|
||||
- id: "cloud-sql"
|
||||
name: golang:1
|
||||
waitFor: ["compile-test-binary"]
|
||||
entrypoint: /bin/bash
|
||||
env:
|
||||
- "GOPATH=/gopath"
|
||||
secretEnv: ["CLIENT_ID"]
|
||||
volumes:
|
||||
- name: "go"
|
||||
path: "/gopath"
|
||||
args:
|
||||
- -c
|
||||
- |
|
||||
.ci/test_with_coverage.sh \
|
||||
"Cloud SQL Wait for Operation" \
|
||||
cloudsql \
|
||||
cloudsql
|
||||
|
||||
- id: "tidb"
|
||||
name: golang:1
|
||||
waitFor: ["compile-test-binary"]
|
||||
|
||||
@@ -54,6 +54,7 @@ import (
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/bigtable"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouseexecutesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhousesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlwaitforoperation"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/couchbase"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/dataplex/dataplexlookupentry"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/dataplex/dataplexsearchaspecttypes"
|
||||
|
||||
7
docs/en/resources/tools/cloudsql/_index.md
Normal file
7
docs/en/resources/tools/cloudsql/_index.md
Normal file
@@ -0,0 +1,7 @@
|
||||
---
|
||||
title: "Cloud SQL"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
Tools that work with Cloud SQL Control Plane.
|
||||
---
|
||||
43
docs/en/resources/tools/cloudsql/cloudsqlwaitforoperation.md
Normal file
43
docs/en/resources/tools/cloudsql/cloudsqlwaitforoperation.md
Normal file
@@ -0,0 +1,43 @@
|
||||
---
|
||||
title: "cloud-sql-wait-for-operation"
|
||||
type: docs
|
||||
weight: 10
|
||||
description: >
|
||||
Wait for a long-running Cloud SQL operation to complete.
|
||||
---
|
||||
|
||||
The `cloud-sql-wait-for-operation` tool is a utility tool that waits for a
|
||||
long-running Cloud SQL operation to complete. It does this by polling the Cloud
|
||||
SQL Admin API operation status endpoint until the operation is finished, using
|
||||
exponential backoff.
|
||||
|
||||
{{< notice info >}}
|
||||
This tool is intended for developer assistant workflows with human-in-the-loop
|
||||
and shouldn't be used for production agents.
|
||||
{{< /notice >}}
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
cloudsql-operations-get:
|
||||
kind: cloud-sql-wait-for-operation
|
||||
source: some-http-source
|
||||
description: "This will poll on operations API until the operation is done. For checking operation status we need projectId and operationId. Once instance is created give follow up steps on how to use the variables to bring data plane MCP server up in local and remote setup."
|
||||
delay: 1s
|
||||
maxDelay: 4m
|
||||
multiplier: 2
|
||||
maxRetries: 10
|
||||
```
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
| ----------- | :------: | :----------: | ---------------------------------------------------------------------------------------------------------------- |
|
||||
| kind | string | true | Must be "cloud-sql-wait-for-operation". |
|
||||
| source | string | true | The name of an `http` source to use for authentication. |
|
||||
| description | string | true | A description of the tool. |
|
||||
| delay | duration | false | The initial delay between polling requests (e.g., `3s`). Defaults to 3 seconds. |
|
||||
| maxDelay | duration | false | The maximum delay between polling requests (e.g., `4m`). Defaults to 4 minutes. |
|
||||
| multiplier | float | false | The multiplier for the polling delay. The delay is multiplied by this value after each request. Defaults to 2.0. |
|
||||
| maxRetries | int | false | The maximum number of polling attempts before giving up. Defaults to 10. |
|
||||
@@ -0,0 +1,440 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cloudsqlwaitforoperation
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"text/template"
|
||||
"time"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
httpsrc "github.com/googleapis/genai-toolbox/internal/sources/http"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"golang.org/x/oauth2/google"
|
||||
)
|
||||
|
||||
const kind string = "cloud-sql-wait-for-operation"
|
||||
|
||||
var cloudSQLConnectionMessageTemplate = `Your Cloud SQL resource is ready.
|
||||
|
||||
To connect, please configure your environment. The method depends on how you are running the toolbox:
|
||||
|
||||
**If running locally via stdio:**
|
||||
Update the MCP server configuration with the following environment variables:
|
||||
` + "```json" + `
|
||||
{
|
||||
"mcpServers": {
|
||||
"cloud-sql-{{.DBType}}": {
|
||||
"command": "./PATH/TO/toolbox",
|
||||
"args": ["--prebuilt","cloud-sql-{{.DBType}}","--stdio"],
|
||||
"env": {
|
||||
"CLOUD_SQL_{{.DBTypeUpper}}_PROJECT": "{{.Project}}",
|
||||
"CLOUD_SQL_{{.DBTypeUpper}}_REGION": "{{.Region}}",
|
||||
"CLOUD_SQL_{{.DBTypeUpper}}_INSTANCE": "{{.Instance}}",
|
||||
"CLOUD_SQL_{{.DBTypeUpper}}_DATABASE": "{{.Database}}",
|
||||
"CLOUD_SQL_{{.DBTypeUpper}}_USER": "<your-user>",
|
||||
"CLOUD_SQL_{{.DBTypeUpper}}_PASSWORD": "<your-password>"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
` + "```" + `
|
||||
|
||||
**If running remotely:**
|
||||
For remote deployments, you will need to set the following environment variables in your deployment configuration:
|
||||
` + "```" + `
|
||||
CLOUD_SQL_{{.DBTypeUpper}}_PROJECT={{.Project}}
|
||||
CLOUD_SQL_{{.DBTypeUpper}}_REGION={{.Region}}
|
||||
CLOUD_SQL_{{.DBTypeUpper}}_INSTANCE={{.Instance}}
|
||||
CLOUD_SQL_{{.DBTypeUpper}}_DATABASE={{.Database}}
|
||||
CLOUD_SQL_{{.DBTypeUpper}}_USER=<your-user>
|
||||
CLOUD_SQL_{{.DBTypeUpper}}_PASSWORD=<your-password>
|
||||
` + "```" + `
|
||||
|
||||
Please refer to the official documentation for guidance on deploying the toolbox:
|
||||
- Deploying the Toolbox: https://googleapis.github.io/genai-toolbox/how-to/deploy_toolbox/
|
||||
- Deploying on GKE: https://googleapis.github.io/genai-toolbox/how-to/deploy_gke/
|
||||
`
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
// Config defines the configuration for the wait-for-operation tool.
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
BaseURL string `yaml:"baseURL"`
|
||||
|
||||
// Polling configuration
|
||||
Delay string `yaml:"delay"`
|
||||
MaxDelay string `yaml:"maxDelay"`
|
||||
Multiplier float64 `yaml:"multiplier"`
|
||||
MaxRetries int `yaml:"maxRetries"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
// ToolConfigKind returns the kind of the tool.
|
||||
func (cfg Config) ToolConfigKind() string {
|
||||
return kind
|
||||
}
|
||||
|
||||
// Initialize initializes the tool from the configuration.
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
s, ok := rawS.(*httpsrc.Source)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `http`", kind)
|
||||
}
|
||||
|
||||
if s.BaseURL != "https://sqladmin.googleapis.com" && !strings.HasPrefix(s.BaseURL, "http://127.0.0.1") {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: baseUrl must be `https://sqladmin.googleapis.com`", kind)
|
||||
}
|
||||
|
||||
allParameters := tools.Parameters{
|
||||
tools.NewStringParameter("project", "The project ID"),
|
||||
tools.NewStringParameter("operation", "The operation ID"),
|
||||
}
|
||||
paramManifest := allParameters.Manifest()
|
||||
|
||||
inputSchema := allParameters.McpManifest()
|
||||
inputSchema.Required = []string{"project", "operation"}
|
||||
|
||||
mcpManifest := tools.McpManifest{
|
||||
Name: cfg.Name,
|
||||
Description: cfg.Description,
|
||||
InputSchema: inputSchema,
|
||||
}
|
||||
|
||||
baseURL := cfg.BaseURL
|
||||
if baseURL == "" {
|
||||
baseURL = "https://sqladmin.googleapis.com"
|
||||
}
|
||||
|
||||
var delay time.Duration
|
||||
if cfg.Delay == "" {
|
||||
delay = 3 * time.Second
|
||||
} else {
|
||||
var err error
|
||||
delay, err = time.ParseDuration(cfg.Delay)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid value for delay: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
var maxDelay time.Duration
|
||||
if cfg.MaxDelay == "" {
|
||||
maxDelay = 4 * time.Minute
|
||||
} else {
|
||||
var err error
|
||||
maxDelay, err = time.ParseDuration(cfg.MaxDelay)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid value for maxDelay: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
multiplier := cfg.Multiplier
|
||||
if multiplier == 0 {
|
||||
multiplier = 2.0
|
||||
}
|
||||
|
||||
maxRetries := cfg.MaxRetries
|
||||
if maxRetries == 0 {
|
||||
maxRetries = 10
|
||||
}
|
||||
|
||||
return Tool{
|
||||
Name: cfg.Name,
|
||||
Kind: kind,
|
||||
BaseURL: baseURL,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
Client: s.Client,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
Delay: delay,
|
||||
MaxDelay: maxDelay,
|
||||
Multiplier: multiplier,
|
||||
MaxRetries: maxRetries,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Tool represents the wait-for-operation tool.
|
||||
type Tool struct {
|
||||
Name string `yaml:"name"`
|
||||
Kind string `yaml:"kind"`
|
||||
Description string `yaml:"description"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
|
||||
BaseURL string `yaml:"baseURL"`
|
||||
AllParams tools.Parameters `yaml:"allParams"`
|
||||
|
||||
// Polling configuration
|
||||
Delay time.Duration
|
||||
MaxDelay time.Duration
|
||||
Multiplier float64
|
||||
MaxRetries int
|
||||
|
||||
Client *http.Client
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
// Invoke executes the tool's logic.
|
||||
func (t Tool) Invoke(ctx context.Context, params tools.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
project, ok := paramsMap["project"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("missing 'project' parameter")
|
||||
}
|
||||
operationID, ok := paramsMap["operation"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("missing 'operation' parameter")
|
||||
}
|
||||
|
||||
urlString := fmt.Sprintf("%s/v1/projects/%s/operations/%s", t.BaseURL, project, operationID)
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
delay := t.Delay
|
||||
maxDelay := t.MaxDelay
|
||||
multiplier := t.Multiplier
|
||||
maxRetries := t.MaxRetries
|
||||
retries := 0
|
||||
|
||||
for retries < maxRetries {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, fmt.Errorf("timed out waiting for operation: %w", ctx.Err())
|
||||
default:
|
||||
}
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, urlString, nil)
|
||||
|
||||
tokenSource, err := google.DefaultTokenSource(ctx, "https://www.googleapis.com/auth/sqlservice.admin")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating token source: %w", err)
|
||||
}
|
||||
token, err := tokenSource.Token()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error retrieving token: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token.AccessToken)
|
||||
|
||||
resp, err := t.Client.Do(req)
|
||||
if err != nil {
|
||||
fmt.Printf("error making HTTP request during polling: %s, retrying in %v\n", err, delay)
|
||||
} else {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading response body during polling: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("unexpected status code during polling: %d, response body: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var data map[string]any
|
||||
if err := json.Unmarshal(body, &data); err == nil {
|
||||
if val, ok := data["status"]; ok {
|
||||
if fmt.Sprintf("%v", val) == "DONE" {
|
||||
if _, ok := data["error"]; ok {
|
||||
return nil, fmt.Errorf("operation finished with error: %s", string(body))
|
||||
}
|
||||
|
||||
if msg, ok := t.generateCloudSQLConnectionMessage(data); ok {
|
||||
return msg, nil
|
||||
}
|
||||
return string(body), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
fmt.Printf("Operation not complete, retrying in %v\n", delay)
|
||||
}
|
||||
|
||||
time.Sleep(delay)
|
||||
delay = time.Duration(float64(delay) * multiplier)
|
||||
if delay > maxDelay {
|
||||
delay = maxDelay
|
||||
}
|
||||
retries++
|
||||
}
|
||||
return nil, fmt.Errorf("exceeded max retries waiting for operation")
|
||||
}
|
||||
|
||||
// ParseParams parses the parameters for the tool.
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (tools.ParamValues, error) {
|
||||
return tools.ParseParams(t.AllParams, data, claims)
|
||||
}
|
||||
|
||||
// Manifest returns the tool's manifest.
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
// McpManifest returns the tool's MCP manifest.
|
||||
func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
// Authorized checks if the tool is authorized.
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (t Tool) generateCloudSQLConnectionMessage(opResponse map[string]any) (string, bool) {
|
||||
operationType, ok := opResponse["operationType"].(string)
|
||||
if !ok || operationType != "CREATE_DATABASE" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
targetLink, ok := opResponse["targetLink"].(string)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
|
||||
r := regexp.MustCompile(`/projects/([^/]+)/instances/([^/]+)/databases/([^/]+)`)
|
||||
matches := r.FindStringSubmatch(targetLink)
|
||||
if len(matches) < 4 {
|
||||
return "", false
|
||||
}
|
||||
project := matches[1]
|
||||
instance := matches[2]
|
||||
database := matches[3]
|
||||
|
||||
instanceData, err := t.fetchInstanceData(context.Background(), project, instance)
|
||||
if err != nil {
|
||||
fmt.Printf("error fetching instance data: %v\n", err)
|
||||
return "", false
|
||||
}
|
||||
|
||||
region, ok := instanceData["region"].(string)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
|
||||
databaseVersion, ok := instanceData["databaseVersion"].(string)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
|
||||
var dbType string
|
||||
if strings.Contains(databaseVersion, "POSTGRES") {
|
||||
dbType = "postgres"
|
||||
} else if strings.Contains(databaseVersion, "MYSQL") {
|
||||
dbType = "mysql"
|
||||
} else if strings.Contains(databaseVersion, "SQLSERVER") {
|
||||
dbType = "mssql"
|
||||
} else {
|
||||
return "", false
|
||||
}
|
||||
|
||||
tmpl, err := template.New("cloud-sql-connection").Parse(cloudSQLConnectionMessageTemplate)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("template parsing error: %v", err), false
|
||||
}
|
||||
|
||||
data := struct {
|
||||
Project string
|
||||
Region string
|
||||
Instance string
|
||||
DBType string
|
||||
DBTypeUpper string
|
||||
Database string
|
||||
}{
|
||||
Project: project,
|
||||
Region: region,
|
||||
Instance: instance,
|
||||
DBType: dbType,
|
||||
DBTypeUpper: strings.ToUpper(dbType),
|
||||
Database: database,
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
if err := tmpl.Execute(&b, data); err != nil {
|
||||
return fmt.Sprintf("template execution error: %v", err), false
|
||||
}
|
||||
|
||||
return b.String(), true
|
||||
}
|
||||
|
||||
func (t Tool) fetchInstanceData(ctx context.Context, project, instance string) (map[string]any, error) {
|
||||
urlString := fmt.Sprintf("%s/v1/projects/%s/instances/%s", t.BaseURL, project, instance)
|
||||
req, _ := http.NewRequest(http.MethodGet, urlString, nil)
|
||||
|
||||
tokenSource, err := google.DefaultTokenSource(ctx, "https://www.googleapis.com/auth/sqlservice.admin")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating token source: %w", err)
|
||||
}
|
||||
token, err := tokenSource.Token()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error retrieving token: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token.AccessToken)
|
||||
|
||||
resp, err := t.Client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("unexpected status code fetching instance data: %d, response body: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var data map[string]any
|
||||
if err := json.Unmarshal(body, &data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
@@ -0,0 +1,80 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cloudsqlwaitforoperation_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
cloudsqlwaitforoperation "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlwaitforoperation"
|
||||
)
|
||||
|
||||
func TestParseFromYaml(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
wait-for-thing:
|
||||
kind: cloud-sql-wait-for-operation
|
||||
source: some-source
|
||||
description: some description
|
||||
delay: 1s
|
||||
maxDelay: 5s
|
||||
multiplier: 1.5
|
||||
maxRetries: 5
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"wait-for-thing": cloudsqlwaitforoperation.Config{
|
||||
Name: "wait-for-thing",
|
||||
Kind: "cloud-sql-wait-for-operation",
|
||||
Source: "some-source",
|
||||
Description: "some description",
|
||||
AuthRequired: []string{},
|
||||
Delay: "1s",
|
||||
MaxDelay: "5s",
|
||||
Multiplier: 1.5,
|
||||
MaxRetries: 5,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
284
tests/cloudsql/cloudsql_wait_for_operation_test.go
Normal file
284
tests/cloudsql/cloudsql_wait_for_operation_test.go
Normal file
@@ -0,0 +1,284 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cloudsql
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/tests"
|
||||
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudsql/cloudsqlwaitforoperation"
|
||||
)
|
||||
|
||||
var (
|
||||
cloudsqlWaitToolKind = "cloud-sql-wait-for-operation"
|
||||
)
|
||||
|
||||
type cloudsqlOperation struct {
|
||||
Name string `json:"name"`
|
||||
Status string `json:"status"`
|
||||
TargetLink string `json:"targetLink"`
|
||||
OperationType string `json:"operationType"`
|
||||
Error *struct {
|
||||
Errors []struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
} `json:"errors"`
|
||||
} `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type cloudsqlInstance struct {
|
||||
Region string `json:"region"`
|
||||
DatabaseVersion string `json:"databaseVersion"`
|
||||
}
|
||||
|
||||
type cloudsqlHandler struct {
|
||||
mu sync.Mutex
|
||||
operations map[string]*cloudsqlOperation
|
||||
instances map[string]*cloudsqlInstance
|
||||
}
|
||||
|
||||
func (h *cloudsqlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
if match, _ := regexp.MatchString("/v1/projects/p1/operations/.*", r.URL.Path); match {
|
||||
parts := regexp.MustCompile("/").Split(r.URL.Path, -1)
|
||||
opName := parts[len(parts)-1]
|
||||
|
||||
op, ok := h.operations[opName]
|
||||
if !ok {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if op.Status != "DONE" {
|
||||
op.Status = "DONE"
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(op); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
} else if match, _ := regexp.MatchString("/v1/projects/p1/instances/.*", r.URL.Path); match {
|
||||
parts := regexp.MustCompile("/").Split(r.URL.Path, -1)
|
||||
instanceName := parts[len(parts)-1]
|
||||
|
||||
instance, ok := h.instances[instanceName]
|
||||
if !ok {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(instance); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
} else {
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloudSQLWaitToolEndpoints(t *testing.T) {
|
||||
h := &cloudsqlHandler{
|
||||
operations: map[string]*cloudsqlOperation{
|
||||
"op1": {Name: "op1", Status: "PENDING", OperationType: "CREATE_DATABASE"},
|
||||
"op2": {Name: "op2", Status: "PENDING", OperationType: "CREATE_DATABASE", Error: &struct {
|
||||
Errors []struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
} `json:"errors"`
|
||||
}{
|
||||
Errors: []struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}{
|
||||
{Code: "ERROR_CODE", Message: "failed"},
|
||||
},
|
||||
}},
|
||||
"op3": {Name: "op3", Status: "PENDING", OperationType: "CREATE"},
|
||||
},
|
||||
instances: map[string]*cloudsqlInstance{
|
||||
"i1": {Region: "r1", DatabaseVersion: "POSTGRES_13"},
|
||||
},
|
||||
}
|
||||
server := httptest.NewServer(h)
|
||||
defer server.Close()
|
||||
|
||||
h.operations["op1"].TargetLink = fmt.Sprintf("%s/v1/projects/p1/instances/i1/databases/d1", server.URL)
|
||||
h.operations["op2"].TargetLink = fmt.Sprintf("%s/v1/projects/p1/instances/i2/databases/d2", server.URL)
|
||||
h.operations["op3"].TargetLink = fmt.Sprintf("%s/v1/projects/p1/instances/i1", server.URL)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
|
||||
var args []string
|
||||
|
||||
toolsFile := getCloudSQLWaitToolsConfig(server.URL)
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
t.Fatalf("command initialization returned an error: %s", err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||
if err != nil {
|
||||
t.Logf("toolbox command logs: \n%s", out)
|
||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||
}
|
||||
|
||||
tcs := []struct {
|
||||
name string
|
||||
toolName string
|
||||
body string
|
||||
want string
|
||||
expectError bool
|
||||
wantSubstring bool
|
||||
}{
|
||||
{
|
||||
name: "successful operation",
|
||||
toolName: "wait-for-op1",
|
||||
body: `{"project": "p1", "operation": "op1"}`,
|
||||
want: "Your Cloud SQL resource is ready",
|
||||
wantSubstring: true,
|
||||
},
|
||||
{
|
||||
name: "failed operation",
|
||||
toolName: "wait-for-op2",
|
||||
body: `{"project": "p1", "operation": "op2"}`,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "non-database create operation",
|
||||
toolName: "wait-for-op3",
|
||||
body: `{"project": "p1", "operation": "op3"}`,
|
||||
want: `{"name":"op3","status":"DONE","targetLink":"` + h.operations["op3"].TargetLink + `","operationType":"CREATE"}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
api := fmt.Sprintf("http://127.0.0.1:5000/api/tool/%s/invoke", tc.toolName)
|
||||
req, err := http.NewRequest(http.MethodPost, api, bytes.NewBufferString(tc.body))
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create request: %s", err)
|
||||
}
|
||||
req.Header.Add("Content-type", "application/json")
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to send request: %s", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if tc.expectError {
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
t.Fatal("expected error but got status 200")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("response status code is not 200, got %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
if tc.wantSubstring {
|
||||
var result struct {
|
||||
Result string `json:"result"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if !bytes.Contains([]byte(result.Result), []byte(tc.want)) {
|
||||
t.Fatalf("unexpected result: got %q, want substring %q", result.Result, tc.want)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Result string `json:"result"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
var tempString string
|
||||
if err := json.Unmarshal([]byte(result.Result), &tempString); err != nil {
|
||||
t.Fatalf("failed to unmarshal outer JSON string: %v", err)
|
||||
}
|
||||
|
||||
var got, want map[string]any
|
||||
if err := json.Unmarshal([]byte(tempString), &got); err != nil {
|
||||
t.Fatalf("failed to unmarshal inner JSON object: %v", err)
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(tc.want), &want); err != nil {
|
||||
t.Fatalf("failed to unmarshal want: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Fatalf("unexpected result: got %+v, want %+v", got, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func getCloudSQLWaitToolsConfig(baseURL string) map[string]any {
|
||||
return map[string]any{
|
||||
"sources": map[string]any{
|
||||
"test-source": map[string]any{
|
||||
"kind": "http",
|
||||
"baseUrl": baseURL,
|
||||
},
|
||||
},
|
||||
"tools": map[string]any{
|
||||
"wait-for-op1": map[string]any{
|
||||
"kind": cloudsqlWaitToolKind,
|
||||
"source": "test-source",
|
||||
"description": "wait for op1",
|
||||
"baseURL": baseURL,
|
||||
"authRequired": []string{},
|
||||
},
|
||||
"wait-for-op2": map[string]any{
|
||||
"kind": cloudsqlWaitToolKind,
|
||||
"source": "test-source",
|
||||
"description": "wait for op2",
|
||||
"baseURL": baseURL,
|
||||
"authRequired": []string{},
|
||||
},
|
||||
"wait-for-op3": map[string]any{
|
||||
"kind": cloudsqlWaitToolKind,
|
||||
"source": "test-source",
|
||||
"description": "wait for op3",
|
||||
"baseURL": baseURL,
|
||||
"authRequired": []string{},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user