Merge branch 'main' into fix_dashboard_filter

This commit is contained in:
Dr. Strangelove
2025-12-18 13:07:11 -05:00
committed by GitHub
14 changed files with 1471 additions and 6 deletions

View File

@@ -212,6 +212,26 @@ steps:
bigquery \
bigquery
- id: "cloud-gda"
name: golang:1
waitFor: ["compile-test-binary"]
entrypoint: /bin/bash
env:
- "GOPATH=/gopath"
- "CLOUD_GDA_PROJECT=$PROJECT_ID"
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
secretEnv: ["CLIENT_ID"]
volumes:
- name: "go"
path: "/gopath"
args:
- -c
- |
.ci/test_with_coverage.sh \
"Cloud Gemini Data Analytics" \
cloudgda \
cloudgda
- id: "dataplex"
name: golang:1
waitFor: ["compile-test-binary"]

View File

@@ -73,6 +73,7 @@ import (
_ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouselistdatabases"
_ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouselisttables"
_ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhousesql"
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudgda"
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcarefhirfetchpage"
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcarefhirpatienteverything"
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcarefhirpatientsearch"

View File

@@ -0,0 +1,40 @@
---
title: "Gemini Data Analytics"
type: docs
weight: 1
description: >
A "cloud-gemini-data-analytics" source provides a client for the Gemini Data Analytics API.
aliases:
- /resources/sources/cloud-gemini-data-analytics
---
## About
The `cloud-gemini-data-analytics` source provides a client to interact with the [Gemini Data Analytics API](https://docs.cloud.google.com/gemini/docs/conversational-analytics-api/reference/rest). This allows tools to send natural language queries to the API.
Authentication can be handled in two ways:
1. **Application Default Credentials (ADC) (Recommended):** By default, the source uses ADC to authenticate with the API. The Toolbox server will fetch the credentials from its running environment (server-side authentication). This is the recommended method.
2. **Client-side OAuth:** If `useClientOAuth` is set to `true`, the source expects the authentication token to be provided by the caller when making a request to the Toolbox server (typically via an HTTP Bearer token). The Toolbox server will then forward this token to the underlying Gemini Data Analytics API calls.
## Example
```yaml
sources:
my-gda-source:
kind: cloud-gemini-data-analytics
projectId: my-project-id
my-oauth-gda-source:
kind: cloud-gemini-data-analytics
projectId: my-project-id
useClientOAuth: true
```
## Reference
| **field** | **type** | **required** | **description** |
| -------------- | :------: | :----------: | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| kind | string | true | Must be "cloud-gemini-data-analytics". |
| projectId | string | true | The Google Cloud Project ID where the API is enabled. |
| useClientOAuth | boolean | false | If true, the source uses the token provided by the caller (forwarded to the API). Otherwise, it uses server-side Application Default Credentials (ADC). Defaults to `false`. |

View File

@@ -0,0 +1,7 @@
---
title: "Gemini Data Analytics"
type: docs
weight: 1
description: >
Tools for Gemini Data Analytics.
---

View File

@@ -0,0 +1,92 @@
---
title: "Gemini Data Analytics QueryData"
type: docs
weight: 1
description: >
A tool to convert natural language queries into SQL statements using the Gemini Data Analytics QueryData API.
aliases:
- /resources/tools/cloud-gemini-data-analytics-query
---
## About
The `cloud-gemini-data-analytics-query` tool allows you to send natural language questions to the Gemini Data Analytics API and receive structured responses containing SQL queries, natural language answers, and explanations. For details on defining data agent context for database data sources, see the official [documentation](https://docs.cloud.google.com/gemini/docs/conversational-analytics-api/data-agent-authored-context-databases).
## Example
```yaml
tools:
my-gda-query-tool:
kind: cloud-gemini-data-analytics-query
source: my-gda-source
description: "Use this tool to send natural language queries to the Gemini Data Analytics API and receive SQL, natural language answers, and explanations."
location: ${your_database_location}
context:
datasourceReferences:
cloudSqlReference:
databaseReference:
projectId: "${your_project_id}"
region: "${your_database_instance_region}"
instanceId: "${your_database_instance_id}"
databaseId: "${your_database_name}"
engine: "POSTGRESQL"
agentContextReference:
contextSetId: "${your_context_set_id}" # E.g. projects/${project_id}/locations/${context_set_location}/contextSets/${context_set_id}
generationOptions:
generateQueryResult: true
generateNaturalLanguageAnswer: true
generateExplanation: true
generateDisambiguationQuestion: true
```
### Usage Flow
When using this tool, a `prompt` parameter containing a natural language query is provided to the tool (typically by an agent). The tool then interacts with the Gemini Data Analytics API using the context defined in your configuration.
The structure of the response depends on the `generationOptions` configured in your tool definition (e.g., enabling `generateQueryResult` will include the SQL query results).
See [Data Analytics API REST documentation](https://clouddocs.devsite.corp.google.com/gemini/docs/conversational-analytics-api/reference/rest/v1alpha/projects.locations/queryData?rep_location=global) for details.
**Example Input Prompt:**
```text
How many accounts who have region in Prague are eligible for loans? A3 contains the data of region.
```
**Example API Response:**
```json
{
"generatedQuery": "SELECT COUNT(T1.account_id) FROM account AS T1 INNER JOIN loan AS T2 ON T1.account_id = T2.account_id INNER JOIN district AS T3 ON T1.district_id = T3.district_id WHERE T3.A3 = 'Prague'",
"intentExplanation": "I found a template that matches the user's question. The template asks about the number of accounts who have region in a given city and are eligible for loans. The question asks about the number of accounts who have region in Prague and are eligible for loans. The template's parameterized SQL is 'SELECT COUNT(T1.account_id) FROM account AS T1 INNER JOIN loan AS T2 ON T1.account_id = T2.account_id INNER JOIN district AS T3 ON T1.district_id = T3.district_id WHERE T3.A3 = ?'. I will replace the named parameter '?' with 'Prague'.",
"naturalLanguageAnswer": "There are 84 accounts from the Prague region that are eligible for loans.",
"queryResult": {
"columns": [
{
"type": "INT64"
}
],
"rows": [
{
"values": [
{
"value": "84"
}
]
}
],
"totalRowCount": "1"
}
}
```
## Reference
| **field** | **type** | **required** | **description** |
| ----------------- | :------: | :----------: | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| kind | string | true | Must be "cloud-gemini-data-analytics-query". |
| source | string | true | The name of the `cloud-gemini-data-analytics` source to use. |
| description | string | true | A description of the tool's purpose. |
| location | string | true | The Google Cloud location of the target database resource (e.g., "us-central1"). This is used to construct the parent resource name in the API call. |
| context | object | true | The context for the query, including datasource references. See [QueryDataContext](https://github.com/googleapis/googleapis/blob/b32495a713a68dd0dff90cf0b24021debfca048a/google/cloud/geminidataanalytics/v1beta/data_chat_service.proto#L156) for details. |
| generationOptions | object | false | Options for generating the response. See [GenerationOptions](https://github.com/googleapis/googleapis/blob/b32495a713a68dd0dff90cf0b24021debfca048a/google/cloud/geminidataanalytics/v1beta/data_chat_service.proto#L135) for details. |

4
go.mod
View File

@@ -12,7 +12,7 @@ require (
cloud.google.com/go/dataplex v1.28.0
cloud.google.com/go/dataproc/v2 v2.15.0
cloud.google.com/go/firestore v1.20.0
cloud.google.com/go/geminidataanalytics v0.2.1
cloud.google.com/go/geminidataanalytics v0.3.0
cloud.google.com/go/longrunning v0.7.0
cloud.google.com/go/spanner v1.86.1
github.com/ClickHouse/clickhouse-go/v2 v2.40.3
@@ -181,7 +181,7 @@ require (
golang.org/x/time v0.14.0 // indirect
golang.org/x/tools v0.38.0 // indirect
golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20251014184007-4626949a642f // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20251111163417-95abcf5c77ba // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20251103181224-f26f9409b101 // indirect
google.golang.org/grpc v1.76.0 // indirect
gopkg.in/inf.v0 v0.9.1 // indirect

8
go.sum
View File

@@ -311,8 +311,8 @@ cloud.google.com/go/gaming v1.6.0/go.mod h1:YMU1GEvA39Qt3zWGyAVA9bpYz/yAhTvaQ1t2
cloud.google.com/go/gaming v1.7.0/go.mod h1:LrB8U7MHdGgFG851iHAfqUdLcKBdQ55hzXy9xBJz0+w=
cloud.google.com/go/gaming v1.8.0/go.mod h1:xAqjS8b7jAVW0KFYeRUxngo9My3f33kFmua++Pi+ggM=
cloud.google.com/go/gaming v1.9.0/go.mod h1:Fc7kEmCObylSWLO334NcO+O9QMDyz+TKC4v1D7X+Bc0=
cloud.google.com/go/geminidataanalytics v0.2.1 h1:gtG/9VlUJpL67yukFen/twkAEHliYvW7610Rlnn5rpQ=
cloud.google.com/go/geminidataanalytics v0.2.1/go.mod h1:gIsj/ELDCzVbw24185zwjXgbzYiqdGe7TSSK2HrdtA0=
cloud.google.com/go/geminidataanalytics v0.3.0 h1:2Wi/kqFb5OLuEGH7q+/miE19VTqK1MYHjBEHENap9HI=
cloud.google.com/go/geminidataanalytics v0.3.0/go.mod h1:QRc0b6ywyc3Z7S3etFgslz7hippkW/jRvtops5rKqIg=
cloud.google.com/go/gkebackup v0.2.0/go.mod h1:XKvv/4LfG829/B8B7xRkk8zRrOEbKtEam6yNfuQNH60=
cloud.google.com/go/gkebackup v0.3.0/go.mod h1:n/E671i1aOQvUxT541aTkCwExO/bTer2HDlj4TsBRAo=
cloud.google.com/go/gkebackup v0.4.0/go.mod h1:byAyBGUwYGEEww7xsbnUTBHIYcOPy/PgUWUtOeRm9Vg=
@@ -1990,8 +1990,8 @@ google.golang.org/genproto v0.0.0-20230331144136-dcfb400f0633/go.mod h1:UUQDJDOl
google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1/go.mod h1:nKE/iIaLqn2bQwXBg8f1g2Ylh6r5MN5CmZvuzZCgsCU=
google.golang.org/genproto v0.0.0-20251022142026-3a174f9686a8 h1:a12a2/BiVRxRWIqBbfqoSK6tgq8cyUgMnEI81QlPge0=
google.golang.org/genproto v0.0.0-20251022142026-3a174f9686a8/go.mod h1:1Ic78BnpzY8OaTCmzxJDP4qC9INZPbGZl+54RKjtyeI=
google.golang.org/genproto/googleapis/api v0.0.0-20251014184007-4626949a642f h1:OiFuztEyBivVKDvguQJYWq1yDcfAHIID/FVrPR4oiI0=
google.golang.org/genproto/googleapis/api v0.0.0-20251014184007-4626949a642f/go.mod h1:kprOiu9Tr0JYyD6DORrc4Hfyk3RFXqkQ3ctHEum3ZbM=
google.golang.org/genproto/googleapis/api v0.0.0-20251111163417-95abcf5c77ba h1:B14OtaXuMaCQsl2deSvNkyPKIzq3BjfxQp8d00QyWx4=
google.golang.org/genproto/googleapis/api v0.0.0-20251111163417-95abcf5c77ba/go.mod h1:G5IanEx8/PgI9w6CFcYQf7jMtHQhZruvfM1i3qOqk5U=
google.golang.org/genproto/googleapis/rpc v0.0.0-20251103181224-f26f9409b101 h1:tRPGkdGHuewF4UisLzzHHr1spKw92qLM98nIzxbC0wY=
google.golang.org/genproto/googleapis/rpc v0.0.0-20251103181224-f26f9409b101/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk=
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=

View File

@@ -0,0 +1,154 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cloudgda
import (
"context"
"fmt"
"net/http"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/util"
"go.opentelemetry.io/otel/trace"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
)
const SourceKind string = "cloud-gemini-data-analytics"
const Endpoint string = "https://geminidataanalytics.googleapis.com"
type userAgentRoundTripper struct {
userAgent string
next http.RoundTripper
}
func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
newReq := *req
newReq.Header = make(http.Header)
for k, v := range req.Header {
newReq.Header[k] = v
}
ua := newReq.Header.Get("User-Agent")
if ua == "" {
newReq.Header.Set("User-Agent", rt.userAgent)
} else {
newReq.Header.Set("User-Agent", ua+" "+rt.userAgent)
}
return rt.next.RoundTrip(&newReq)
}
// validate interface
var _ sources.SourceConfig = Config{}
func init() {
if !sources.Register(SourceKind, newConfig) {
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
}
}
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources.SourceConfig, error) {
actual := Config{Name: name}
if err := decoder.DecodeContext(ctx, &actual); err != nil {
return nil, err
}
return actual, nil
}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
ProjectID string `yaml:"projectId" validate:"required"`
UseClientOAuth bool `yaml:"useClientOAuth"`
}
func (r Config) SourceConfigKind() string {
return SourceKind
}
// Initialize initializes a Gemini Data Analytics Source instance.
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
ua, err := util.UserAgentFromContext(ctx)
if err != nil {
return nil, fmt.Errorf("error in User Agent retrieval: %s", err)
}
var client *http.Client
if r.UseClientOAuth {
client = &http.Client{
Transport: &userAgentRoundTripper{
userAgent: ua,
next: http.DefaultTransport,
},
}
} else {
// Use Application Default Credentials
// Scope: "https://www.googleapis.com/auth/cloud-platform" is generally sufficient for GDA
creds, err := google.FindDefaultCredentials(ctx, "https://www.googleapis.com/auth/cloud-platform")
if err != nil {
return nil, fmt.Errorf("failed to find default credentials: %w", err)
}
baseClient := oauth2.NewClient(ctx, creds.TokenSource)
baseClient.Transport = &userAgentRoundTripper{
userAgent: ua,
next: baseClient.Transport,
}
client = baseClient
}
s := &Source{
Config: r,
Client: client,
BaseURL: Endpoint,
userAgent: ua,
}
return s, nil
}
var _ sources.Source = &Source{}
type Source struct {
Config
Client *http.Client
BaseURL string
userAgent string
}
func (s *Source) SourceKind() string {
return SourceKind
}
func (s *Source) ToConfig() sources.SourceConfig {
return s.Config
}
func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Client, error) {
if s.UseClientOAuth {
if accessToken == "" {
return nil, fmt.Errorf("client-side OAuth is enabled but no access token was provided")
}
token := &oauth2.Token{AccessToken: accessToken}
baseClient := oauth2.NewClient(ctx, oauth2.StaticTokenSource(token))
baseClient.Transport = &userAgentRoundTripper{
userAgent: s.userAgent,
next: baseClient.Transport,
}
return baseClient, nil
}
return s.Client, nil
}
func (s *Source) UseClientAuthorization() bool {
return s.UseClientOAuth
}

View File

@@ -0,0 +1,213 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cloudgda_test
import (
"context"
"os"
"path/filepath"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/cloudgda"
"github.com/googleapis/genai-toolbox/internal/testutils"
"go.opentelemetry.io/otel/trace/noop"
)
func TestParseFromYamlCloudGDA(t *testing.T) {
t.Parallel()
tcs := []struct {
desc string
in string
want server.SourceConfigs
}{
{
desc: "basic example",
in: `
sources:
my-gda-instance:
kind: cloud-gemini-data-analytics
projectId: test-project-id
`,
want: map[string]sources.SourceConfig{
"my-gda-instance": cloudgda.Config{
Name: "my-gda-instance",
Kind: cloudgda.SourceKind,
ProjectID: "test-project-id",
UseClientOAuth: false,
},
},
},
{
desc: "use client auth example",
in: `
sources:
my-gda-instance:
kind: cloud-gemini-data-analytics
projectId: another-project
useClientOAuth: true
`,
want: map[string]sources.SourceConfig{
"my-gda-instance": cloudgda.Config{
Name: "my-gda-instance",
Kind: cloudgda.SourceKind,
ProjectID: "another-project",
UseClientOAuth: true,
},
},
},
}
for _, tc := range tcs {
tc := tc
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got.Sources) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
}
})
}
}
func TestFailParseFromYaml(t *testing.T) {
t.Parallel()
tcs := []struct {
desc string
in string
err string
}{
{
desc: "missing projectId",
in: `
sources:
my-gda-instance:
kind: cloud-gemini-data-analytics
`,
err: "unable to parse source \"my-gda-instance\" as \"cloud-gemini-data-analytics\": Key: 'Config.ProjectID' Error:Field validation for 'ProjectID' failed on the 'required' tag",
},
}
for _, tc := range tcs {
tc := tc
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
got := struct {
Sources server.SourceConfigs `yaml:"sources"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err == nil {
t.Fatalf("expect parsing to fail")
}
errStr := err.Error()
if errStr != tc.err {
t.Fatalf("unexpected error: got %q, want %q", errStr, tc.err)
}
})
}
}
func TestInitialize(t *testing.T) {
// Create a dummy credentials file for testing ADC
credFile := filepath.Join(t.TempDir(), "application_default_credentials.json")
dummyCreds := `{
"client_id": "foo",
"client_secret": "bar",
"refresh_token": "baz",
"type": "authorized_user"
}`
if err := os.WriteFile(credFile, []byte(dummyCreds), 0644); err != nil {
t.Fatalf("failed to write dummy credentials file: %v", err)
}
t.Setenv("GOOGLE_APPLICATION_CREDENTIALS", credFile)
// Use ContextWithUserAgent to avoid "unable to retrieve user agent" error
ctx := testutils.ContextWithUserAgent(context.Background(), "test-user-agent")
tracer := noop.NewTracerProvider().Tracer("test")
tcs := []struct {
desc string
cfg cloudgda.Config
wantClientOAuth bool
}{
{
desc: "initialize with ADC",
cfg: cloudgda.Config{Name: "test-gda", Kind: cloudgda.SourceKind, ProjectID: "test-proj"},
wantClientOAuth: false,
},
{
desc: "initialize with client OAuth",
cfg: cloudgda.Config{Name: "test-gda-oauth", Kind: cloudgda.SourceKind, ProjectID: "test-proj", UseClientOAuth: true},
wantClientOAuth: true,
},
}
for _, tc := range tcs {
tc := tc
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
src, err := tc.cfg.Initialize(ctx, tracer)
if err != nil {
t.Fatalf("failed to initialize source: %v", err)
}
gdaSrc, ok := src.(*cloudgda.Source)
if !ok {
t.Fatalf("expected *cloudgda.Source, got %T", src)
}
// Check that the client is non-nil
if gdaSrc.Client == nil && !tc.wantClientOAuth {
t.Fatal("expected non-nil HTTP client for ADC, got nil")
}
// When client OAuth is true, the source's client should be initialized with a base HTTP client
// that includes the user agent round tripper, but not the OAuth token. The token-aware
// client is created by GetClient.
if gdaSrc.Client == nil && tc.wantClientOAuth {
t.Fatal("expected non-nil HTTP client for client OAuth config, got nil")
}
// Test UseClientAuthorization method
if gdaSrc.UseClientAuthorization() != tc.wantClientOAuth {
t.Errorf("UseClientAuthorization mismatch: want %t, got %t", tc.wantClientOAuth, gdaSrc.UseClientAuthorization())
}
// Test GetClient with accessToken for client OAuth scenarios
if tc.wantClientOAuth {
client, err := gdaSrc.GetClient(ctx, "dummy-token")
if err != nil {
t.Fatalf("GetClient with token failed: %v", err)
}
if client == nil {
t.Fatal("expected non-nil HTTP client from GetClient with token, got nil")
}
// Ensure passing empty token with UseClientOAuth enabled returns error
_, err = gdaSrc.GetClient(ctx, "")
if err == nil || err.Error() != "client-side OAuth is enabled but no access token was provided" {
t.Errorf("expected 'client-side OAuth is enabled but no access token was provided' error, got: %v", err)
}
}
})
}
}

View File

@@ -46,6 +46,11 @@ func ContextWithNewLogger() (context.Context, error) {
return util.WithLogger(ctx, logger), nil
}
// ContextWithUserAgent creates a new context with a specified user agent string.
func ContextWithUserAgent(ctx context.Context, userAgent string) context.Context {
return util.WithUserAgent(ctx, userAgent)
}
// WaitForString waits until the server logs a single line that matches the provided regex.
// returns the output of whatever the server sent so far.
func WaitForString(ctx context.Context, re *regexp.Regexp, pr io.ReadCloser) (string, error) {

View File

@@ -0,0 +1,205 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cloudgda
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
cloudgdasrc "github.com/googleapis/genai-toolbox/internal/sources/cloudgda"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
)
const kind string = "cloud-gemini-data-analytics-query"
func init() {
if !tools.Register(kind, newConfig) {
panic(fmt.Sprintf("tool kind %q already registered", kind))
}
}
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
actual := Config{Name: name}
if err := decoder.DecodeContext(ctx, &actual); err != nil {
return nil, err
}
return actual, nil
}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Source string `yaml:"source" validate:"required"`
Description string `yaml:"description" validate:"required"`
Location string `yaml:"location" validate:"required"`
Context *QueryDataContext `yaml:"context" validate:"required"`
GenerationOptions *GenerationOptions `yaml:"generationOptions,omitempty"`
AuthRequired []string `yaml:"authRequired"`
}
// validate interface
var _ tools.ToolConfig = Config{}
func (cfg Config) ToolConfigKind() string {
return kind
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(*cloudgdasrc.Source)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-gemini-data-analytics`", kind)
}
// Define the parameters for the Gemini Data Analytics Query API
// The prompt is the only input parameter.
allParameters := parameters.Parameters{
parameters.NewStringParameterWithRequired("prompt", "The natural language question to ask.", true),
}
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil)
return Tool{
Config: cfg,
AllParams: allParameters,
Source: s,
manifest: tools.Manifest{Description: cfg.Description, Parameters: allParameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}, nil
}
// validate interface
var _ tools.Tool = Tool{}
type Tool struct {
Config
AllParams parameters.Parameters
Source *cloudgdasrc.Source
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
return t.Config
}
// Invoke executes the tool logic
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
paramsMap := params.AsMap()
prompt, ok := paramsMap["prompt"].(string)
if !ok {
return nil, fmt.Errorf("prompt parameter not found or not a string")
}
// The API endpoint itself always uses the "global" location.
apiLocation := "global"
apiParent := fmt.Sprintf("projects/%s/locations/%s", t.Source.ProjectID, apiLocation)
apiURL := fmt.Sprintf("%s/v1beta/%s:queryData", t.Source.BaseURL, apiParent)
// The parent in the request payload uses the tool's configured location.
payloadParent := fmt.Sprintf("projects/%s/locations/%s", t.Source.ProjectID, t.Location)
payload := &QueryDataRequest{
Parent: payloadParent,
Prompt: prompt,
Context: t.Context,
GenerationOptions: t.GenerationOptions,
}
bodyBytes, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("failed to marshal request payload: %w", err)
}
// Parse the access token if provided
var tokenStr string
if t.RequiresClientAuthorization(resourceMgr) {
var err error
tokenStr, err = accessToken.ParseBearerToken()
if err != nil {
return nil, fmt.Errorf("error parsing access token: %w", err)
}
}
client, err := t.Source.GetClient(ctx, tokenStr)
if err != nil {
return nil, fmt.Errorf("failed to get HTTP client: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(bodyBytes))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to execute request: %w", err)
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(respBody))
}
var result map[string]any
if err := json.Unmarshal(respBody, &result); err != nil {
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
}
return result, nil
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
return parameters.ParseParams(t.AllParams, data, claims)
}
func (t Tool) Manifest() tools.Manifest {
return t.manifest
}
func (t Tool) McpManifest() tools.McpManifest {
return t.mcpManifest
}
func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
return t.Source.UseClientAuthorization()
}
func (t Tool) GetAuthTokenHeaderName() string {
return "Authorization"
}

View File

@@ -0,0 +1,379 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cloudgda_test
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/sources"
cloudgdasrc "github.com/googleapis/genai-toolbox/internal/sources/cloudgda"
"github.com/googleapis/genai-toolbox/internal/testutils"
"github.com/googleapis/genai-toolbox/internal/tools"
cloudgdatool "github.com/googleapis/genai-toolbox/internal/tools/cloudgda"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
)
func TestParseFromYaml(t *testing.T) {
t.Parallel()
tcs := []struct {
desc string
in string
want server.ToolConfigs
}{
{
desc: "basic example",
in: `
tools:
my-gda-query-tool:
kind: cloud-gemini-data-analytics-query
source: gda-api-source
description: Test Description
location: us-central1
context:
datasourceReferences:
spannerReference:
databaseReference:
projectId: "cloud-db-nl2sql"
region: "us-central1"
instanceId: "evalbench"
databaseId: "financial"
engine: "GOOGLE_SQL"
agentContextReference:
contextSetId: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates"
generationOptions:
generateQueryResult: true
`,
want: map[string]tools.ToolConfig{
"my-gda-query-tool": cloudgdatool.Config{
Name: "my-gda-query-tool",
Kind: "cloud-gemini-data-analytics-query",
Source: "gda-api-source",
Description: "Test Description",
Location: "us-central1",
AuthRequired: []string{},
Context: &cloudgdatool.QueryDataContext{
DatasourceReferences: &cloudgdatool.DatasourceReferences{
SpannerReference: &cloudgdatool.SpannerReference{
DatabaseReference: &cloudgdatool.SpannerDatabaseReference{
ProjectID: "cloud-db-nl2sql",
Region: "us-central1",
InstanceID: "evalbench",
DatabaseID: "financial",
Engine: cloudgdatool.SpannerEngineGoogleSQL,
},
AgentContextReference: &cloudgdatool.AgentContextReference{
ContextSetID: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates",
},
},
},
},
GenerationOptions: &cloudgdatool.GenerationOptions{
GenerateQueryResult: true,
},
},
},
},
}
for _, tc := range tcs {
tc := tc
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
got := struct {
Tools server.ToolConfigs `yaml:"tools"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got.Tools) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Tools)
}
})
}
}
// authRoundTripper is a mock http.RoundTripper that adds a dummy Authorization header.
type authRoundTripper struct {
Token string
Next http.RoundTripper
}
func (rt *authRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
newReq := *req
newReq.Header = make(http.Header)
for k, v := range req.Header {
newReq.Header[k] = v
}
newReq.Header.Set("Authorization", rt.Token)
if rt.Next == nil {
return http.DefaultTransport.RoundTrip(&newReq)
}
return rt.Next.RoundTrip(&newReq)
}
type mockSource struct {
kind string
client *http.Client // Can be used to inject a specific client
baseURL string // BaseURL is needed to implement sources.Source.BaseURL
config cloudgdasrc.Config // to return from ToConfig
}
func (m *mockSource) SourceKind() string { return m.kind }
func (m *mockSource) ToConfig() sources.SourceConfig { return m.config }
func (m *mockSource) GetClient(ctx context.Context, token string) (*http.Client, error) {
if m.client != nil {
return m.client, nil
}
// Default client for testing if not explicitly set
transport := &http.Transport{}
authTransport := &authRoundTripper{
Token: "Bearer test-access-token", // Dummy token
Next: transport,
}
return &http.Client{Transport: authTransport}, nil
}
func (m *mockSource) UseClientAuthorization() bool { return false }
func (m *mockSource) Initialize(ctx context.Context, tracer interface{}) (sources.Source, error) {
return m, nil
}
func (m *mockSource) BaseURL() string { return m.baseURL }
func TestInitialize(t *testing.T) {
t.Parallel()
srcs := map[string]sources.Source{
"gda-api-source": &cloudgdasrc.Source{
Config: cloudgdasrc.Config{Name: "gda-api-source", Kind: cloudgdasrc.SourceKind, ProjectID: "test-project"},
Client: &http.Client{},
BaseURL: cloudgdasrc.Endpoint,
},
}
tcs := []struct {
desc string
cfg cloudgdatool.Config
expectErr bool
}{
{
desc: "successful initialization",
cfg: cloudgdatool.Config{
Name: "my-gda-query-tool",
Kind: "cloud-gemini-data-analytics-query",
Source: "gda-api-source",
Description: "Test Description",
Location: "us-central1",
},
expectErr: false,
},
{
desc: "missing source",
cfg: cloudgdatool.Config{
Name: "my-gda-query-tool",
Kind: "cloud-gemini-data-analytics-query",
Source: "non-existent-source",
Description: "Test Description",
Location: "us-central1",
},
expectErr: true,
},
{
desc: "incompatible source kind",
cfg: cloudgdatool.Config{
Name: "my-gda-query-tool",
Kind: "cloud-gemini-data-analytics-query",
Source: "incompatible-source",
Description: "Test Description",
Location: "us-central1",
},
expectErr: true,
},
}
// Add an incompatible source for testing
srcs["incompatible-source"] = &mockSource{kind: "another-kind"}
for _, tc := range tcs {
tc := tc
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
tool, err := tc.cfg.Initialize(srcs)
if tc.expectErr && err == nil {
t.Fatalf("expected an error but got none")
}
if !tc.expectErr && err != nil {
t.Fatalf("did not expect an error but got: %v", err)
}
if !tc.expectErr {
// Basic sanity check on the returned tool
_ = tool // Avoid unused variable error
}
})
}
}
func TestInvoke(t *testing.T) {
t.Parallel()
// Mock the HTTP client and server for Invoke testing
serverMux := http.NewServeMux()
// Update expected URL path to include the location "us-central1"
serverMux.HandleFunc("/v1beta/projects/test-project/locations/global:queryData", func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("expected POST method, got %s", r.Method)
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
if r.Header.Get("Content-Type") != "application/json" {
t.Errorf("expected Content-Type application/json, got %s", r.Header.Get("Content-Type"))
http.Error(w, "Bad request", http.StatusBadRequest)
return
}
// Read and unmarshal the request body
bodyBytes, err := io.ReadAll(r.Body)
if err != nil {
t.Errorf("failed to read request body: %v", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
var reqPayload cloudgdatool.QueryDataRequest
if err := json.Unmarshal(bodyBytes, &reqPayload); err != nil {
t.Errorf("failed to unmarshal request payload: %v", err)
http.Error(w, "Bad request", http.StatusBadRequest)
return
}
// Verify expected fields
if r.Header.Get("Authorization") == "" {
t.Errorf("expected Authorization header, got empty")
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
if reqPayload.Prompt != "How many accounts who have region in Prague are eligible for loans?" {
t.Errorf("unexpected prompt: %s", reqPayload.Prompt)
}
// Verify payload's parent uses the tool's configured location
if reqPayload.Parent != fmt.Sprintf("projects/%s/locations/%s", "test-project", "us-central1") {
t.Errorf("unexpected payload parent: got %q, want %q", reqPayload.Parent, fmt.Sprintf("projects/%s/locations/%s", "test-project", "us-central1"))
}
// Verify context from config
if reqPayload.Context == nil ||
reqPayload.Context.DatasourceReferences == nil ||
reqPayload.Context.DatasourceReferences.SpannerReference == nil ||
reqPayload.Context.DatasourceReferences.SpannerReference.DatabaseReference == nil ||
reqPayload.Context.DatasourceReferences.SpannerReference.DatabaseReference.ProjectID != "cloud-db-nl2sql" {
t.Errorf("unexpected context: %v", reqPayload.Context)
}
// Verify generation options from config
if reqPayload.GenerationOptions == nil || !reqPayload.GenerationOptions.GenerateQueryResult {
t.Errorf("unexpected generation options: %v", reqPayload.GenerationOptions)
}
// Simulate a successful response
resp := map[string]any{
"queryResult": "SELECT count(*) FROM accounts WHERE region = 'Prague' AND eligible_for_loans = true;",
"naturalLanguageAnswer": "There are 5 accounts in Prague eligible for loans.",
}
_ = json.NewEncoder(w).Encode(resp)
})
mockServer := httptest.NewServer(serverMux)
defer mockServer.Close()
ctx := testutils.ContextWithUserAgent(context.Background(), "test-user-agent")
// Create an authenticated client that uses the mock server
authTransport := &authRoundTripper{
Token: "Bearer test-access-token",
Next: mockServer.Client().Transport,
}
authClient := &http.Client{Transport: authTransport}
// Create a real cloudgdasrc.Source but inject the authenticated client
mockGdaSource := &cloudgdasrc.Source{
Config: cloudgdasrc.Config{Name: "mock-gda-source", Kind: cloudgdasrc.SourceKind, ProjectID: "test-project"},
Client: authClient,
BaseURL: mockServer.URL,
}
srcs := map[string]sources.Source{
"mock-gda-source": mockGdaSource,
}
// Initialize the tool config with context
toolCfg := cloudgdatool.Config{
Name: "query-data-tool",
Kind: "cloud-gemini-data-analytics-query",
Source: "mock-gda-source",
Description: "Query Gemini Data Analytics",
Location: "us-central1", // Set location for the test
Context: &cloudgdatool.QueryDataContext{
DatasourceReferences: &cloudgdatool.DatasourceReferences{
SpannerReference: &cloudgdatool.SpannerReference{
DatabaseReference: &cloudgdatool.SpannerDatabaseReference{
ProjectID: "cloud-db-nl2sql",
Region: "us-central1",
InstanceID: "evalbench",
DatabaseID: "financial",
Engine: cloudgdatool.SpannerEngineGoogleSQL,
},
AgentContextReference: &cloudgdatool.AgentContextReference{
ContextSetID: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates",
},
},
},
},
GenerationOptions: &cloudgdatool.GenerationOptions{
GenerateQueryResult: true,
},
}
tool, err := toolCfg.Initialize(srcs)
if err != nil {
t.Fatalf("failed to initialize tool: %v", err)
}
// Prepare parameters for invocation - ONLY prompt
params := parameters.ParamValues{
{Name: "prompt", Value: "How many accounts who have region in Prague are eligible for loans?"},
}
// Invoke the tool
result, err := tool.Invoke(ctx, nil, params, "") // No accessToken needed for ADC client
if err != nil {
t.Fatalf("tool invocation failed: %v", err)
}
// Validate the result
expectedResult := map[string]any{
"queryResult": "SELECT count(*) FROM accounts WHERE region = 'Prague' AND eligible_for_loans = true;",
"naturalLanguageAnswer": "There are 5 accounts in Prague eligible for loans.",
}
if !cmp.Equal(expectedResult, result) {
t.Errorf("unexpected result: got %v, want %v", result, expectedResult)
}
}

View File

@@ -0,0 +1,116 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cloudgda
// See full service definition at: https://github.com/googleapis/googleapis/blob/master/google/cloud/geminidataanalytics/v1beta/data_chat_service.proto
// QueryDataRequest represents the JSON body for the queryData API
type QueryDataRequest struct {
Parent string `json:"parent"`
Prompt string `json:"prompt"`
Context *QueryDataContext `json:"context,omitempty"`
GenerationOptions *GenerationOptions `json:"generationOptions,omitempty"`
}
// QueryDataContext reflects the proto definition for the query context.
type QueryDataContext struct {
DatasourceReferences *DatasourceReferences `json:"datasourceReferences,omitempty" yaml:"datasourceReferences,omitempty"`
}
// DatasourceReferences reflects the proto definition for datasource references, using a oneof.
type DatasourceReferences struct {
SpannerReference *SpannerReference `json:"spannerReference,omitempty" yaml:"spannerReference,omitempty"`
AlloyDBReference *AlloyDBReference `json:"alloydb,omitempty" yaml:"alloydb,omitempty"`
CloudSQLReference *CloudSQLReference `json:"cloudSqlReference,omitempty" yaml:"cloudSqlReference,omitempty"`
}
// SpannerReference reflects the proto definition for Spanner database reference.
type SpannerReference struct {
DatabaseReference *SpannerDatabaseReference `json:"databaseReference,omitempty" yaml:"databaseReference,omitempty"`
AgentContextReference *AgentContextReference `json:"agentContextReference,omitempty" yaml:"agentContextReference,omitempty"`
}
// SpannerDatabaseReference reflects the proto definition for a Spanner database reference.
type SpannerDatabaseReference struct {
Engine SpannerEngine `json:"engine,omitempty" yaml:"engine,omitempty"`
ProjectID string `json:"projectId,omitempty" yaml:"projectId,omitempty"`
Region string `json:"region,omitempty" yaml:"region,omitempty"`
InstanceID string `json:"instanceId,omitempty" yaml:"instanceId,omitempty"`
DatabaseID string `json:"databaseId,omitempty" yaml:"databaseId,omitempty"`
TableIDs []string `json:"tableIds,omitempty" yaml:"tableIds,omitempty"`
}
// SpannerEngine represents the engine of the Spanner instance.
type SpannerEngine string
const (
SpannerEngineUnspecified SpannerEngine = "ENGINE_UNSPECIFIED"
SpannerEngineGoogleSQL SpannerEngine = "GOOGLE_SQL"
SpannerEnginePostgreSQL SpannerEngine = "POSTGRESQL"
)
// AlloyDBReference reflects the proto definition for an AlloyDB database reference.
type AlloyDBReference struct {
DatabaseReference *AlloyDBDatabaseReference `json:"databaseReference,omitempty" yaml:"databaseReference,omitempty"`
AgentContextReference *AgentContextReference `json:"agentContextReference,omitempty" yaml:"agentContextReference,omitempty"`
}
// AlloyDBDatabaseReference reflects the proto definition for an AlloyDB database reference.
type AlloyDBDatabaseReference struct {
ProjectID string `json:"projectId,omitempty" yaml:"projectId,omitempty"`
Region string `json:"region,omitempty" yaml:"region,omitempty"`
ClusterID string `json:"clusterId,omitempty" yaml:"clusterId,omitempty"`
InstanceID string `json:"instanceId,omitempty" yaml:"instanceId,omitempty"`
DatabaseID string `json:"databaseId,omitempty" yaml:"databaseId,omitempty"`
TableIDs []string `json:"tableIds,omitempty" yaml:"tableIds,omitempty"`
}
// CloudSQLReference reflects the proto definition for a Cloud SQL database reference.
type CloudSQLReference struct {
DatabaseReference *CloudSQLDatabaseReference `json:"databaseReference,omitempty" yaml:"databaseReference,omitempty"`
AgentContextReference *AgentContextReference `json:"agentContextReference,omitempty" yaml:"agentContextReference,omitempty"`
}
// CloudSQLDatabaseReference reflects the proto definition for a Cloud SQL database reference.
type CloudSQLDatabaseReference struct {
Engine CloudSQLEngine `json:"engine,omitempty" yaml:"engine,omitempty"`
ProjectID string `json:"projectId,omitempty" yaml:"projectId,omitempty"`
Region string `json:"region,omitempty" yaml:"region,omitempty"`
InstanceID string `json:"instanceId,omitempty" yaml:"instanceId,omitempty"`
DatabaseID string `json:"databaseId,omitempty" yaml:"databaseId,omitempty"`
TableIDs []string `json:"tableIds,omitempty" yaml:"tableIds,omitempty"`
}
// CloudSQLEngine represents the engine of the Cloud SQL instance.
type CloudSQLEngine string
const (
CloudSQLEngineUnspecified CloudSQLEngine = "ENGINE_UNSPECIFIED"
CloudSQLEnginePostgreSQL CloudSQLEngine = "POSTGRESQL"
CloudSQLEngineMySQL CloudSQLEngine = "MYSQL"
)
// AgentContextReference reflects the proto definition for agent context.
type AgentContextReference struct {
ContextSetID string `json:"contextSetId,omitempty" yaml:"contextSetId,omitempty"`
}
// GenerationOptions reflects the proto definition for generation options.
type GenerationOptions struct {
GenerateQueryResult bool `json:"generateQueryResult" yaml:"generateQueryResult"`
GenerateNaturalLanguageAnswer bool `json:"generateNaturalLanguageAnswer" yaml:"generateNaturalLanguageAnswer"`
GenerateExplanation bool `json:"generateExplanation" yaml:"generateExplanation"`
GenerateDisambiguationQuestion bool `json:"generateDisambiguationQuestion" yaml:"generateDisambiguationQuestion"`
}

View File

@@ -0,0 +1,233 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cloudgda_test
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"regexp"
"strings"
"testing"
"time"
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
"github.com/googleapis/genai-toolbox/internal/testutils"
"github.com/googleapis/genai-toolbox/internal/tools/cloudgda"
"github.com/googleapis/genai-toolbox/tests"
)
var (
cloudGdaToolKind = "cloud-gemini-data-analytics-query"
)
type cloudGdaTransport struct {
transport http.RoundTripper
url *url.URL
}
func (t *cloudGdaTransport) RoundTrip(req *http.Request) (*http.Response, error) {
if strings.HasPrefix(req.URL.String(), "https://geminidataanalytics.googleapis.com") {
req.URL.Scheme = t.url.Scheme
req.URL.Host = t.url.Host
}
return t.transport.RoundTrip(req)
}
type masterHandler struct {
t *testing.T
}
func (h *masterHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if !strings.Contains(r.UserAgent(), "genai-toolbox/") {
h.t.Errorf("User-Agent header not found")
}
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
// Verify URL structure
// Expected: /v1beta/projects/{project}/locations/global:queryData
if !strings.Contains(r.URL.Path, ":queryData") || !strings.Contains(r.URL.Path, "locations/global") {
h.t.Errorf("unexpected URL path: %s", r.URL.Path)
http.Error(w, "Not found", http.StatusNotFound)
return
}
var reqBody cloudgda.QueryDataRequest
if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
h.t.Fatalf("failed to decode request body: %v", err)
}
if reqBody.Prompt == "" {
http.Error(w, "missing prompt", http.StatusBadRequest)
return
}
response := map[string]any{
"queryResult": "SELECT * FROM table;",
"naturalLanguageAnswer": "Here is the answer.",
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(response); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}
func TestCloudGdaToolEndpoints(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
handler := &masterHandler{t: t}
server := httptest.NewServer(handler)
defer server.Close()
serverURL, err := url.Parse(server.URL)
if err != nil {
t.Fatalf("failed to parse server URL: %v", err)
}
originalTransport := http.DefaultClient.Transport
if originalTransport == nil {
originalTransport = http.DefaultTransport
}
http.DefaultClient.Transport = &cloudGdaTransport{
transport: originalTransport,
url: serverURL,
}
t.Cleanup(func() {
http.DefaultClient.Transport = originalTransport
})
var args []string
toolsFile := getCloudGdaToolsConfig()
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
if err != nil {
t.Fatalf("command initialization returned an error: %s", err)
}
defer cleanup()
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
if err != nil {
t.Logf("toolbox command logs: \n%s", out)
t.Fatalf("toolbox didn't start successfully: %s", err)
}
toolName := "cloud-gda-query"
// 1. RunToolGetTestByName
expectedManifest := map[string]any{
toolName: map[string]any{
"description": "Test GDA Tool",
"parameters": []any{
map[string]any{
"name": "prompt",
"type": "string",
"description": "The natural language question to ask.",
"required": true,
"authSources": []any{},
},
},
"authRequired": []any{},
},
}
tests.RunToolGetTestByName(t, toolName, expectedManifest)
// 2. RunToolInvokeParametersTest
params := []byte(`{"prompt": "test question"}`)
tests.RunToolInvokeParametersTest(t, toolName, params, "\"queryResult\":\"SELECT * FROM table;\"")
// 3. Manual MCP Tool Call Test
// Initialize MCP session
sessionId := tests.RunInitialize(t, "2024-11-05")
// Construct MCP Request
mcpReq := jsonrpc.JSONRPCRequest{
Jsonrpc: "2.0",
Id: "test-mcp-call",
Request: jsonrpc.Request{
Method: "tools/call",
},
Params: map[string]any{
"name": toolName,
"arguments": map[string]any{
"prompt": "test question",
},
},
}
reqBytes, _ := json.Marshal(mcpReq)
headers := map[string]string{}
if sessionId != "" {
headers["Mcp-Session-Id"] = sessionId
}
// Send Request
resp, respBody := tests.RunRequest(t, http.MethodPost, "http://127.0.0.1:5000/mcp", bytes.NewBuffer(reqBytes), headers)
if resp.StatusCode != http.StatusOK {
t.Fatalf("MCP request failed with status %d: %s", resp.StatusCode, string(respBody))
}
// Check Response
respStr := string(respBody)
if !strings.Contains(respStr, "SELECT * FROM table;") {
t.Errorf("MCP response does not contain expected query result: %s", respStr)
}
}
func getCloudGdaToolsConfig() map[string]any {
// Mocked responses and a dummy `projectId` are used in this integration
// test due to limited project-specific allowlisting. API functionality is
// verified via internal monitoring; this test specifically validates the
// integration flow between the source and the tool.
return map[string]any{
"sources": map[string]any{
"my-gda-source": map[string]any{
"kind": "cloud-gemini-data-analytics",
"projectId": "test-project",
},
},
"tools": map[string]any{
"cloud-gda-query": map[string]any{
"kind": cloudGdaToolKind,
"source": "my-gda-source",
"description": "Test GDA Tool",
"location": "us-central1",
"context": map[string]any{
"datasourceReferences": map[string]any{
"spannerReference": map[string]any{
"databaseReference": map[string]any{
"projectId": "test-project",
"instanceId": "test-instance",
"databaseId": "test-db",
"engine": "GOOGLE_SQL",
},
},
},
},
},
},
}
}