mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-09 07:28:05 -05:00
Merge branch 'main' into fix_dashboard_filter
This commit is contained in:
@@ -212,6 +212,26 @@ steps:
|
||||
bigquery \
|
||||
bigquery
|
||||
|
||||
- id: "cloud-gda"
|
||||
name: golang:1
|
||||
waitFor: ["compile-test-binary"]
|
||||
entrypoint: /bin/bash
|
||||
env:
|
||||
- "GOPATH=/gopath"
|
||||
- "CLOUD_GDA_PROJECT=$PROJECT_ID"
|
||||
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
|
||||
secretEnv: ["CLIENT_ID"]
|
||||
volumes:
|
||||
- name: "go"
|
||||
path: "/gopath"
|
||||
args:
|
||||
- -c
|
||||
- |
|
||||
.ci/test_with_coverage.sh \
|
||||
"Cloud Gemini Data Analytics" \
|
||||
cloudgda \
|
||||
cloudgda
|
||||
|
||||
- id: "dataplex"
|
||||
name: golang:1
|
||||
waitFor: ["compile-test-binary"]
|
||||
|
||||
@@ -73,6 +73,7 @@ import (
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouselistdatabases"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouselisttables"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhousesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudgda"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcarefhirfetchpage"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcarefhirpatienteverything"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcarefhirpatientsearch"
|
||||
|
||||
40
docs/en/resources/sources/cloud-gda.md
Normal file
40
docs/en/resources/sources/cloud-gda.md
Normal file
@@ -0,0 +1,40 @@
|
||||
---
|
||||
title: "Gemini Data Analytics"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
A "cloud-gemini-data-analytics" source provides a client for the Gemini Data Analytics API.
|
||||
aliases:
|
||||
- /resources/sources/cloud-gemini-data-analytics
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
The `cloud-gemini-data-analytics` source provides a client to interact with the [Gemini Data Analytics API](https://docs.cloud.google.com/gemini/docs/conversational-analytics-api/reference/rest). This allows tools to send natural language queries to the API.
|
||||
|
||||
Authentication can be handled in two ways:
|
||||
|
||||
1. **Application Default Credentials (ADC) (Recommended):** By default, the source uses ADC to authenticate with the API. The Toolbox server will fetch the credentials from its running environment (server-side authentication). This is the recommended method.
|
||||
2. **Client-side OAuth:** If `useClientOAuth` is set to `true`, the source expects the authentication token to be provided by the caller when making a request to the Toolbox server (typically via an HTTP Bearer token). The Toolbox server will then forward this token to the underlying Gemini Data Analytics API calls.
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
sources:
|
||||
my-gda-source:
|
||||
kind: cloud-gemini-data-analytics
|
||||
projectId: my-project-id
|
||||
|
||||
my-oauth-gda-source:
|
||||
kind: cloud-gemini-data-analytics
|
||||
projectId: my-project-id
|
||||
useClientOAuth: true
|
||||
```
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
| -------------- | :------: | :----------: | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| kind | string | true | Must be "cloud-gemini-data-analytics". |
|
||||
| projectId | string | true | The Google Cloud Project ID where the API is enabled. |
|
||||
| useClientOAuth | boolean | false | If true, the source uses the token provided by the caller (forwarded to the API). Otherwise, it uses server-side Application Default Credentials (ADC). Defaults to `false`. |
|
||||
7
docs/en/resources/tools/cloudgda/_index.md
Normal file
7
docs/en/resources/tools/cloudgda/_index.md
Normal file
@@ -0,0 +1,7 @@
|
||||
---
|
||||
title: "Gemini Data Analytics"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
Tools for Gemini Data Analytics.
|
||||
---
|
||||
92
docs/en/resources/tools/cloudgda/cloud-gda-query.md
Normal file
92
docs/en/resources/tools/cloudgda/cloud-gda-query.md
Normal file
@@ -0,0 +1,92 @@
|
||||
---
|
||||
title: "Gemini Data Analytics QueryData"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
A tool to convert natural language queries into SQL statements using the Gemini Data Analytics QueryData API.
|
||||
aliases:
|
||||
- /resources/tools/cloud-gemini-data-analytics-query
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
The `cloud-gemini-data-analytics-query` tool allows you to send natural language questions to the Gemini Data Analytics API and receive structured responses containing SQL queries, natural language answers, and explanations. For details on defining data agent context for database data sources, see the official [documentation](https://docs.cloud.google.com/gemini/docs/conversational-analytics-api/data-agent-authored-context-databases).
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
my-gda-query-tool:
|
||||
kind: cloud-gemini-data-analytics-query
|
||||
source: my-gda-source
|
||||
description: "Use this tool to send natural language queries to the Gemini Data Analytics API and receive SQL, natural language answers, and explanations."
|
||||
location: ${your_database_location}
|
||||
context:
|
||||
datasourceReferences:
|
||||
cloudSqlReference:
|
||||
databaseReference:
|
||||
projectId: "${your_project_id}"
|
||||
region: "${your_database_instance_region}"
|
||||
instanceId: "${your_database_instance_id}"
|
||||
databaseId: "${your_database_name}"
|
||||
engine: "POSTGRESQL"
|
||||
agentContextReference:
|
||||
contextSetId: "${your_context_set_id}" # E.g. projects/${project_id}/locations/${context_set_location}/contextSets/${context_set_id}
|
||||
generationOptions:
|
||||
generateQueryResult: true
|
||||
generateNaturalLanguageAnswer: true
|
||||
generateExplanation: true
|
||||
generateDisambiguationQuestion: true
|
||||
```
|
||||
|
||||
### Usage Flow
|
||||
|
||||
When using this tool, a `prompt` parameter containing a natural language query is provided to the tool (typically by an agent). The tool then interacts with the Gemini Data Analytics API using the context defined in your configuration.
|
||||
|
||||
The structure of the response depends on the `generationOptions` configured in your tool definition (e.g., enabling `generateQueryResult` will include the SQL query results).
|
||||
|
||||
See [Data Analytics API REST documentation](https://clouddocs.devsite.corp.google.com/gemini/docs/conversational-analytics-api/reference/rest/v1alpha/projects.locations/queryData?rep_location=global) for details.
|
||||
|
||||
**Example Input Prompt:**
|
||||
|
||||
```text
|
||||
How many accounts who have region in Prague are eligible for loans? A3 contains the data of region.
|
||||
```
|
||||
|
||||
**Example API Response:**
|
||||
|
||||
```json
|
||||
{
|
||||
"generatedQuery": "SELECT COUNT(T1.account_id) FROM account AS T1 INNER JOIN loan AS T2 ON T1.account_id = T2.account_id INNER JOIN district AS T3 ON T1.district_id = T3.district_id WHERE T3.A3 = 'Prague'",
|
||||
"intentExplanation": "I found a template that matches the user's question. The template asks about the number of accounts who have region in a given city and are eligible for loans. The question asks about the number of accounts who have region in Prague and are eligible for loans. The template's parameterized SQL is 'SELECT COUNT(T1.account_id) FROM account AS T1 INNER JOIN loan AS T2 ON T1.account_id = T2.account_id INNER JOIN district AS T3 ON T1.district_id = T3.district_id WHERE T3.A3 = ?'. I will replace the named parameter '?' with 'Prague'.",
|
||||
"naturalLanguageAnswer": "There are 84 accounts from the Prague region that are eligible for loans.",
|
||||
"queryResult": {
|
||||
"columns": [
|
||||
{
|
||||
"type": "INT64"
|
||||
}
|
||||
],
|
||||
"rows": [
|
||||
{
|
||||
"values": [
|
||||
{
|
||||
"value": "84"
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"totalRowCount": "1"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
| ----------------- | :------: | :----------: | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| kind | string | true | Must be "cloud-gemini-data-analytics-query". |
|
||||
| source | string | true | The name of the `cloud-gemini-data-analytics` source to use. |
|
||||
| description | string | true | A description of the tool's purpose. |
|
||||
| location | string | true | The Google Cloud location of the target database resource (e.g., "us-central1"). This is used to construct the parent resource name in the API call. |
|
||||
| context | object | true | The context for the query, including datasource references. See [QueryDataContext](https://github.com/googleapis/googleapis/blob/b32495a713a68dd0dff90cf0b24021debfca048a/google/cloud/geminidataanalytics/v1beta/data_chat_service.proto#L156) for details. |
|
||||
| generationOptions | object | false | Options for generating the response. See [GenerationOptions](https://github.com/googleapis/googleapis/blob/b32495a713a68dd0dff90cf0b24021debfca048a/google/cloud/geminidataanalytics/v1beta/data_chat_service.proto#L135) for details. |
|
||||
4
go.mod
4
go.mod
@@ -12,7 +12,7 @@ require (
|
||||
cloud.google.com/go/dataplex v1.28.0
|
||||
cloud.google.com/go/dataproc/v2 v2.15.0
|
||||
cloud.google.com/go/firestore v1.20.0
|
||||
cloud.google.com/go/geminidataanalytics v0.2.1
|
||||
cloud.google.com/go/geminidataanalytics v0.3.0
|
||||
cloud.google.com/go/longrunning v0.7.0
|
||||
cloud.google.com/go/spanner v1.86.1
|
||||
github.com/ClickHouse/clickhouse-go/v2 v2.40.3
|
||||
@@ -181,7 +181,7 @@ require (
|
||||
golang.org/x/time v0.14.0 // indirect
|
||||
golang.org/x/tools v0.38.0 // indirect
|
||||
golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20251014184007-4626949a642f // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20251111163417-95abcf5c77ba // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20251103181224-f26f9409b101 // indirect
|
||||
google.golang.org/grpc v1.76.0 // indirect
|
||||
gopkg.in/inf.v0 v0.9.1 // indirect
|
||||
|
||||
8
go.sum
8
go.sum
@@ -311,8 +311,8 @@ cloud.google.com/go/gaming v1.6.0/go.mod h1:YMU1GEvA39Qt3zWGyAVA9bpYz/yAhTvaQ1t2
|
||||
cloud.google.com/go/gaming v1.7.0/go.mod h1:LrB8U7MHdGgFG851iHAfqUdLcKBdQ55hzXy9xBJz0+w=
|
||||
cloud.google.com/go/gaming v1.8.0/go.mod h1:xAqjS8b7jAVW0KFYeRUxngo9My3f33kFmua++Pi+ggM=
|
||||
cloud.google.com/go/gaming v1.9.0/go.mod h1:Fc7kEmCObylSWLO334NcO+O9QMDyz+TKC4v1D7X+Bc0=
|
||||
cloud.google.com/go/geminidataanalytics v0.2.1 h1:gtG/9VlUJpL67yukFen/twkAEHliYvW7610Rlnn5rpQ=
|
||||
cloud.google.com/go/geminidataanalytics v0.2.1/go.mod h1:gIsj/ELDCzVbw24185zwjXgbzYiqdGe7TSSK2HrdtA0=
|
||||
cloud.google.com/go/geminidataanalytics v0.3.0 h1:2Wi/kqFb5OLuEGH7q+/miE19VTqK1MYHjBEHENap9HI=
|
||||
cloud.google.com/go/geminidataanalytics v0.3.0/go.mod h1:QRc0b6ywyc3Z7S3etFgslz7hippkW/jRvtops5rKqIg=
|
||||
cloud.google.com/go/gkebackup v0.2.0/go.mod h1:XKvv/4LfG829/B8B7xRkk8zRrOEbKtEam6yNfuQNH60=
|
||||
cloud.google.com/go/gkebackup v0.3.0/go.mod h1:n/E671i1aOQvUxT541aTkCwExO/bTer2HDlj4TsBRAo=
|
||||
cloud.google.com/go/gkebackup v0.4.0/go.mod h1:byAyBGUwYGEEww7xsbnUTBHIYcOPy/PgUWUtOeRm9Vg=
|
||||
@@ -1990,8 +1990,8 @@ google.golang.org/genproto v0.0.0-20230331144136-dcfb400f0633/go.mod h1:UUQDJDOl
|
||||
google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1/go.mod h1:nKE/iIaLqn2bQwXBg8f1g2Ylh6r5MN5CmZvuzZCgsCU=
|
||||
google.golang.org/genproto v0.0.0-20251022142026-3a174f9686a8 h1:a12a2/BiVRxRWIqBbfqoSK6tgq8cyUgMnEI81QlPge0=
|
||||
google.golang.org/genproto v0.0.0-20251022142026-3a174f9686a8/go.mod h1:1Ic78BnpzY8OaTCmzxJDP4qC9INZPbGZl+54RKjtyeI=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20251014184007-4626949a642f h1:OiFuztEyBivVKDvguQJYWq1yDcfAHIID/FVrPR4oiI0=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20251014184007-4626949a642f/go.mod h1:kprOiu9Tr0JYyD6DORrc4Hfyk3RFXqkQ3ctHEum3ZbM=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20251111163417-95abcf5c77ba h1:B14OtaXuMaCQsl2deSvNkyPKIzq3BjfxQp8d00QyWx4=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20251111163417-95abcf5c77ba/go.mod h1:G5IanEx8/PgI9w6CFcYQf7jMtHQhZruvfM1i3qOqk5U=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20251103181224-f26f9409b101 h1:tRPGkdGHuewF4UisLzzHHr1spKw92qLM98nIzxbC0wY=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20251103181224-f26f9409b101/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk=
|
||||
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
|
||||
|
||||
154
internal/sources/cloudgda/cloud_gda.go
Normal file
154
internal/sources/cloudgda/cloud_gda.go
Normal file
@@ -0,0 +1,154 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
package cloudgda
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/google"
|
||||
)
|
||||
|
||||
const SourceKind string = "cloud-gemini-data-analytics"
|
||||
const Endpoint string = "https://geminidataanalytics.googleapis.com"
|
||||
|
||||
type userAgentRoundTripper struct {
|
||||
userAgent string
|
||||
next http.RoundTripper
|
||||
}
|
||||
|
||||
func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
newReq := *req
|
||||
newReq.Header = make(http.Header)
|
||||
for k, v := range req.Header {
|
||||
newReq.Header[k] = v
|
||||
}
|
||||
ua := newReq.Header.Get("User-Agent")
|
||||
if ua == "" {
|
||||
newReq.Header.Set("User-Agent", rt.userAgent)
|
||||
} else {
|
||||
newReq.Header.Set("User-Agent", ua+" "+rt.userAgent)
|
||||
}
|
||||
return rt.next.RoundTrip(&newReq)
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources.SourceConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
ProjectID string `yaml:"projectId" validate:"required"`
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
// Initialize initializes a Gemini Data Analytics Source instance.
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
ua, err := util.UserAgentFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error in User Agent retrieval: %s", err)
|
||||
}
|
||||
|
||||
var client *http.Client
|
||||
if r.UseClientOAuth {
|
||||
client = &http.Client{
|
||||
Transport: &userAgentRoundTripper{
|
||||
userAgent: ua,
|
||||
next: http.DefaultTransport,
|
||||
},
|
||||
}
|
||||
} else {
|
||||
// Use Application Default Credentials
|
||||
// Scope: "https://www.googleapis.com/auth/cloud-platform" is generally sufficient for GDA
|
||||
creds, err := google.FindDefaultCredentials(ctx, "https://www.googleapis.com/auth/cloud-platform")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to find default credentials: %w", err)
|
||||
}
|
||||
baseClient := oauth2.NewClient(ctx, creds.TokenSource)
|
||||
baseClient.Transport = &userAgentRoundTripper{
|
||||
userAgent: ua,
|
||||
next: baseClient.Transport,
|
||||
}
|
||||
client = baseClient
|
||||
}
|
||||
|
||||
s := &Source{
|
||||
Config: r,
|
||||
Client: client,
|
||||
BaseURL: Endpoint,
|
||||
userAgent: ua,
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
var _ sources.Source = &Source{}
|
||||
|
||||
type Source struct {
|
||||
Config
|
||||
Client *http.Client
|
||||
BaseURL string
|
||||
userAgent string
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
return s.Config
|
||||
}
|
||||
|
||||
func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Client, error) {
|
||||
if s.UseClientOAuth {
|
||||
if accessToken == "" {
|
||||
return nil, fmt.Errorf("client-side OAuth is enabled but no access token was provided")
|
||||
}
|
||||
token := &oauth2.Token{AccessToken: accessToken}
|
||||
baseClient := oauth2.NewClient(ctx, oauth2.StaticTokenSource(token))
|
||||
baseClient.Transport = &userAgentRoundTripper{
|
||||
userAgent: s.userAgent,
|
||||
next: baseClient.Transport,
|
||||
}
|
||||
return baseClient, nil
|
||||
}
|
||||
return s.Client, nil
|
||||
}
|
||||
|
||||
func (s *Source) UseClientAuthorization() bool {
|
||||
return s.UseClientOAuth
|
||||
}
|
||||
213
internal/sources/cloudgda/cloud_gda_test.go
Normal file
213
internal/sources/cloudgda/cloud_gda_test.go
Normal file
@@ -0,0 +1,213 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cloudgda_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cloudgda"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"go.opentelemetry.io/otel/trace/noop"
|
||||
)
|
||||
|
||||
func TestParseFromYamlCloudGDA(t *testing.T) {
|
||||
t.Parallel()
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.SourceConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
sources:
|
||||
my-gda-instance:
|
||||
kind: cloud-gemini-data-analytics
|
||||
projectId: test-project-id
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
"my-gda-instance": cloudgda.Config{
|
||||
Name: "my-gda-instance",
|
||||
Kind: cloudgda.SourceKind,
|
||||
ProjectID: "test-project-id",
|
||||
UseClientOAuth: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "use client auth example",
|
||||
in: `
|
||||
sources:
|
||||
my-gda-instance:
|
||||
kind: cloud-gemini-data-analytics
|
||||
projectId: another-project
|
||||
useClientOAuth: true
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
"my-gda-instance": cloudgda.Config{
|
||||
Name: "my-gda-instance",
|
||||
Kind: cloudgda.SourceKind,
|
||||
ProjectID: "another-project",
|
||||
UseClientOAuth: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
tc := tc
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailParseFromYaml(t *testing.T) {
|
||||
t.Parallel()
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
err string
|
||||
}{
|
||||
{
|
||||
desc: "missing projectId",
|
||||
in: `
|
||||
sources:
|
||||
my-gda-instance:
|
||||
kind: cloud-gemini-data-analytics
|
||||
`,
|
||||
err: "unable to parse source \"my-gda-instance\" as \"cloud-gemini-data-analytics\": Key: 'Config.ProjectID' Error:Field validation for 'ProjectID' failed on the 'required' tag",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
tc := tc
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
errStr := err.Error()
|
||||
if errStr != tc.err {
|
||||
t.Fatalf("unexpected error: got %q, want %q", errStr, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInitialize(t *testing.T) {
|
||||
// Create a dummy credentials file for testing ADC
|
||||
credFile := filepath.Join(t.TempDir(), "application_default_credentials.json")
|
||||
dummyCreds := `{
|
||||
"client_id": "foo",
|
||||
"client_secret": "bar",
|
||||
"refresh_token": "baz",
|
||||
"type": "authorized_user"
|
||||
}`
|
||||
if err := os.WriteFile(credFile, []byte(dummyCreds), 0644); err != nil {
|
||||
t.Fatalf("failed to write dummy credentials file: %v", err)
|
||||
}
|
||||
t.Setenv("GOOGLE_APPLICATION_CREDENTIALS", credFile)
|
||||
|
||||
// Use ContextWithUserAgent to avoid "unable to retrieve user agent" error
|
||||
ctx := testutils.ContextWithUserAgent(context.Background(), "test-user-agent")
|
||||
tracer := noop.NewTracerProvider().Tracer("test")
|
||||
|
||||
tcs := []struct {
|
||||
desc string
|
||||
cfg cloudgda.Config
|
||||
wantClientOAuth bool
|
||||
}{
|
||||
{
|
||||
desc: "initialize with ADC",
|
||||
cfg: cloudgda.Config{Name: "test-gda", Kind: cloudgda.SourceKind, ProjectID: "test-proj"},
|
||||
wantClientOAuth: false,
|
||||
},
|
||||
{
|
||||
desc: "initialize with client OAuth",
|
||||
cfg: cloudgda.Config{Name: "test-gda-oauth", Kind: cloudgda.SourceKind, ProjectID: "test-proj", UseClientOAuth: true},
|
||||
wantClientOAuth: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tcs {
|
||||
tc := tc
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
src, err := tc.cfg.Initialize(ctx, tracer)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to initialize source: %v", err)
|
||||
}
|
||||
|
||||
gdaSrc, ok := src.(*cloudgda.Source)
|
||||
if !ok {
|
||||
t.Fatalf("expected *cloudgda.Source, got %T", src)
|
||||
}
|
||||
|
||||
// Check that the client is non-nil
|
||||
if gdaSrc.Client == nil && !tc.wantClientOAuth {
|
||||
t.Fatal("expected non-nil HTTP client for ADC, got nil")
|
||||
}
|
||||
// When client OAuth is true, the source's client should be initialized with a base HTTP client
|
||||
// that includes the user agent round tripper, but not the OAuth token. The token-aware
|
||||
// client is created by GetClient.
|
||||
if gdaSrc.Client == nil && tc.wantClientOAuth {
|
||||
t.Fatal("expected non-nil HTTP client for client OAuth config, got nil")
|
||||
}
|
||||
|
||||
// Test UseClientAuthorization method
|
||||
if gdaSrc.UseClientAuthorization() != tc.wantClientOAuth {
|
||||
t.Errorf("UseClientAuthorization mismatch: want %t, got %t", tc.wantClientOAuth, gdaSrc.UseClientAuthorization())
|
||||
}
|
||||
|
||||
// Test GetClient with accessToken for client OAuth scenarios
|
||||
if tc.wantClientOAuth {
|
||||
client, err := gdaSrc.GetClient(ctx, "dummy-token")
|
||||
if err != nil {
|
||||
t.Fatalf("GetClient with token failed: %v", err)
|
||||
}
|
||||
if client == nil {
|
||||
t.Fatal("expected non-nil HTTP client from GetClient with token, got nil")
|
||||
}
|
||||
// Ensure passing empty token with UseClientOAuth enabled returns error
|
||||
_, err = gdaSrc.GetClient(ctx, "")
|
||||
if err == nil || err.Error() != "client-side OAuth is enabled but no access token was provided" {
|
||||
t.Errorf("expected 'client-side OAuth is enabled but no access token was provided' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -46,6 +46,11 @@ func ContextWithNewLogger() (context.Context, error) {
|
||||
return util.WithLogger(ctx, logger), nil
|
||||
}
|
||||
|
||||
// ContextWithUserAgent creates a new context with a specified user agent string.
|
||||
func ContextWithUserAgent(ctx context.Context, userAgent string) context.Context {
|
||||
return util.WithUserAgent(ctx, userAgent)
|
||||
}
|
||||
|
||||
// WaitForString waits until the server logs a single line that matches the provided regex.
|
||||
// returns the output of whatever the server sent so far.
|
||||
func WaitForString(ctx context.Context, re *regexp.Regexp, pr io.ReadCloser) (string, error) {
|
||||
|
||||
205
internal/tools/cloudgda/cloudgda.go
Normal file
205
internal/tools/cloudgda/cloudgda.go
Normal file
@@ -0,0 +1,205 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cloudgda
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
cloudgdasrc "github.com/googleapis/genai-toolbox/internal/sources/cloudgda"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
|
||||
const kind string = "cloud-gemini-data-analytics-query"
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
Location string `yaml:"location" validate:"required"`
|
||||
Context *QueryDataContext `yaml:"context" validate:"required"`
|
||||
GenerationOptions *GenerationOptions `yaml:"generationOptions,omitempty"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
func (cfg Config) ToolConfigKind() string {
|
||||
return kind
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(*cloudgdasrc.Source)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-gemini-data-analytics`", kind)
|
||||
}
|
||||
|
||||
// Define the parameters for the Gemini Data Analytics Query API
|
||||
// The prompt is the only input parameter.
|
||||
allParameters := parameters.Parameters{
|
||||
parameters.NewStringParameterWithRequired("prompt", "The natural language question to ask.", true),
|
||||
}
|
||||
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
AllParams: allParameters,
|
||||
Source: s,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: allParameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
AllParams parameters.Parameters
|
||||
Source *cloudgdasrc.Source
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
return t.Config
|
||||
}
|
||||
|
||||
// Invoke executes the tool logic
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
paramsMap := params.AsMap()
|
||||
prompt, ok := paramsMap["prompt"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("prompt parameter not found or not a string")
|
||||
}
|
||||
|
||||
// The API endpoint itself always uses the "global" location.
|
||||
apiLocation := "global"
|
||||
apiParent := fmt.Sprintf("projects/%s/locations/%s", t.Source.ProjectID, apiLocation)
|
||||
apiURL := fmt.Sprintf("%s/v1beta/%s:queryData", t.Source.BaseURL, apiParent)
|
||||
|
||||
// The parent in the request payload uses the tool's configured location.
|
||||
payloadParent := fmt.Sprintf("projects/%s/locations/%s", t.Source.ProjectID, t.Location)
|
||||
|
||||
payload := &QueryDataRequest{
|
||||
Parent: payloadParent,
|
||||
Prompt: prompt,
|
||||
Context: t.Context,
|
||||
GenerationOptions: t.GenerationOptions,
|
||||
}
|
||||
|
||||
bodyBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request payload: %w", err)
|
||||
}
|
||||
|
||||
// Parse the access token if provided
|
||||
var tokenStr string
|
||||
if t.RequiresClientAuthorization(resourceMgr) {
|
||||
var err error
|
||||
tokenStr, err = accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
client, err := t.Source.GetClient(ctx, tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get HTTP client: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(bodyBytes))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to execute request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
return parameters.ParseParams(t.AllParams, data, claims)
|
||||
}
|
||||
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
}
|
||||
379
internal/tools/cloudgda/cloudgda_test.go
Normal file
379
internal/tools/cloudgda/cloudgda_test.go
Normal file
@@ -0,0 +1,379 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cloudgda_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
cloudgdasrc "github.com/googleapis/genai-toolbox/internal/sources/cloudgda"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
cloudgdatool "github.com/googleapis/genai-toolbox/internal/tools/cloudgda"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
|
||||
func TestParseFromYaml(t *testing.T) {
|
||||
t.Parallel()
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
my-gda-query-tool:
|
||||
kind: cloud-gemini-data-analytics-query
|
||||
source: gda-api-source
|
||||
description: Test Description
|
||||
location: us-central1
|
||||
context:
|
||||
datasourceReferences:
|
||||
spannerReference:
|
||||
databaseReference:
|
||||
projectId: "cloud-db-nl2sql"
|
||||
region: "us-central1"
|
||||
instanceId: "evalbench"
|
||||
databaseId: "financial"
|
||||
engine: "GOOGLE_SQL"
|
||||
agentContextReference:
|
||||
contextSetId: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates"
|
||||
generationOptions:
|
||||
generateQueryResult: true
|
||||
`,
|
||||
want: map[string]tools.ToolConfig{
|
||||
"my-gda-query-tool": cloudgdatool.Config{
|
||||
Name: "my-gda-query-tool",
|
||||
Kind: "cloud-gemini-data-analytics-query",
|
||||
Source: "gda-api-source",
|
||||
Description: "Test Description",
|
||||
Location: "us-central1",
|
||||
AuthRequired: []string{},
|
||||
Context: &cloudgdatool.QueryDataContext{
|
||||
DatasourceReferences: &cloudgdatool.DatasourceReferences{
|
||||
SpannerReference: &cloudgdatool.SpannerReference{
|
||||
DatabaseReference: &cloudgdatool.SpannerDatabaseReference{
|
||||
ProjectID: "cloud-db-nl2sql",
|
||||
Region: "us-central1",
|
||||
InstanceID: "evalbench",
|
||||
DatabaseID: "financial",
|
||||
Engine: cloudgdatool.SpannerEngineGoogleSQL,
|
||||
},
|
||||
AgentContextReference: &cloudgdatool.AgentContextReference{
|
||||
ContextSetID: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
GenerationOptions: &cloudgdatool.GenerationOptions{
|
||||
GenerateQueryResult: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
tc := tc
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got.Tools) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Tools)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// authRoundTripper is a mock http.RoundTripper that adds a dummy Authorization header.
|
||||
type authRoundTripper struct {
|
||||
Token string
|
||||
Next http.RoundTripper
|
||||
}
|
||||
|
||||
func (rt *authRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
newReq := *req
|
||||
newReq.Header = make(http.Header)
|
||||
for k, v := range req.Header {
|
||||
newReq.Header[k] = v
|
||||
}
|
||||
newReq.Header.Set("Authorization", rt.Token)
|
||||
if rt.Next == nil {
|
||||
return http.DefaultTransport.RoundTrip(&newReq)
|
||||
}
|
||||
return rt.Next.RoundTrip(&newReq)
|
||||
}
|
||||
|
||||
type mockSource struct {
|
||||
kind string
|
||||
client *http.Client // Can be used to inject a specific client
|
||||
baseURL string // BaseURL is needed to implement sources.Source.BaseURL
|
||||
config cloudgdasrc.Config // to return from ToConfig
|
||||
}
|
||||
|
||||
func (m *mockSource) SourceKind() string { return m.kind }
|
||||
func (m *mockSource) ToConfig() sources.SourceConfig { return m.config }
|
||||
func (m *mockSource) GetClient(ctx context.Context, token string) (*http.Client, error) {
|
||||
if m.client != nil {
|
||||
return m.client, nil
|
||||
}
|
||||
// Default client for testing if not explicitly set
|
||||
transport := &http.Transport{}
|
||||
authTransport := &authRoundTripper{
|
||||
Token: "Bearer test-access-token", // Dummy token
|
||||
Next: transport,
|
||||
}
|
||||
return &http.Client{Transport: authTransport}, nil
|
||||
}
|
||||
func (m *mockSource) UseClientAuthorization() bool { return false }
|
||||
func (m *mockSource) Initialize(ctx context.Context, tracer interface{}) (sources.Source, error) {
|
||||
return m, nil
|
||||
}
|
||||
func (m *mockSource) BaseURL() string { return m.baseURL }
|
||||
|
||||
func TestInitialize(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srcs := map[string]sources.Source{
|
||||
"gda-api-source": &cloudgdasrc.Source{
|
||||
Config: cloudgdasrc.Config{Name: "gda-api-source", Kind: cloudgdasrc.SourceKind, ProjectID: "test-project"},
|
||||
Client: &http.Client{},
|
||||
BaseURL: cloudgdasrc.Endpoint,
|
||||
},
|
||||
}
|
||||
|
||||
tcs := []struct {
|
||||
desc string
|
||||
cfg cloudgdatool.Config
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
desc: "successful initialization",
|
||||
cfg: cloudgdatool.Config{
|
||||
Name: "my-gda-query-tool",
|
||||
Kind: "cloud-gemini-data-analytics-query",
|
||||
Source: "gda-api-source",
|
||||
Description: "Test Description",
|
||||
Location: "us-central1",
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
desc: "missing source",
|
||||
cfg: cloudgdatool.Config{
|
||||
Name: "my-gda-query-tool",
|
||||
Kind: "cloud-gemini-data-analytics-query",
|
||||
Source: "non-existent-source",
|
||||
Description: "Test Description",
|
||||
Location: "us-central1",
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
desc: "incompatible source kind",
|
||||
cfg: cloudgdatool.Config{
|
||||
Name: "my-gda-query-tool",
|
||||
Kind: "cloud-gemini-data-analytics-query",
|
||||
Source: "incompatible-source",
|
||||
Description: "Test Description",
|
||||
Location: "us-central1",
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
// Add an incompatible source for testing
|
||||
srcs["incompatible-source"] = &mockSource{kind: "another-kind"}
|
||||
|
||||
for _, tc := range tcs {
|
||||
tc := tc
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tool, err := tc.cfg.Initialize(srcs)
|
||||
if tc.expectErr && err == nil {
|
||||
t.Fatalf("expected an error but got none")
|
||||
}
|
||||
if !tc.expectErr && err != nil {
|
||||
t.Fatalf("did not expect an error but got: %v", err)
|
||||
}
|
||||
if !tc.expectErr {
|
||||
// Basic sanity check on the returned tool
|
||||
_ = tool // Avoid unused variable error
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvoke(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Mock the HTTP client and server for Invoke testing
|
||||
serverMux := http.NewServeMux()
|
||||
// Update expected URL path to include the location "us-central1"
|
||||
serverMux.HandleFunc("/v1beta/projects/test-project/locations/global:queryData", func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Errorf("expected POST method, got %s", r.Method)
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if r.Header.Get("Content-Type") != "application/json" {
|
||||
t.Errorf("expected Content-Type application/json, got %s", r.Header.Get("Content-Type"))
|
||||
http.Error(w, "Bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Read and unmarshal the request body
|
||||
bodyBytes, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Errorf("failed to read request body: %v", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
var reqPayload cloudgdatool.QueryDataRequest
|
||||
if err := json.Unmarshal(bodyBytes, &reqPayload); err != nil {
|
||||
t.Errorf("failed to unmarshal request payload: %v", err)
|
||||
http.Error(w, "Bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify expected fields
|
||||
if r.Header.Get("Authorization") == "" {
|
||||
t.Errorf("expected Authorization header, got empty")
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
if reqPayload.Prompt != "How many accounts who have region in Prague are eligible for loans?" {
|
||||
t.Errorf("unexpected prompt: %s", reqPayload.Prompt)
|
||||
}
|
||||
|
||||
// Verify payload's parent uses the tool's configured location
|
||||
if reqPayload.Parent != fmt.Sprintf("projects/%s/locations/%s", "test-project", "us-central1") {
|
||||
t.Errorf("unexpected payload parent: got %q, want %q", reqPayload.Parent, fmt.Sprintf("projects/%s/locations/%s", "test-project", "us-central1"))
|
||||
}
|
||||
|
||||
// Verify context from config
|
||||
if reqPayload.Context == nil ||
|
||||
reqPayload.Context.DatasourceReferences == nil ||
|
||||
reqPayload.Context.DatasourceReferences.SpannerReference == nil ||
|
||||
reqPayload.Context.DatasourceReferences.SpannerReference.DatabaseReference == nil ||
|
||||
reqPayload.Context.DatasourceReferences.SpannerReference.DatabaseReference.ProjectID != "cloud-db-nl2sql" {
|
||||
t.Errorf("unexpected context: %v", reqPayload.Context)
|
||||
}
|
||||
|
||||
// Verify generation options from config
|
||||
if reqPayload.GenerationOptions == nil || !reqPayload.GenerationOptions.GenerateQueryResult {
|
||||
t.Errorf("unexpected generation options: %v", reqPayload.GenerationOptions)
|
||||
}
|
||||
|
||||
// Simulate a successful response
|
||||
resp := map[string]any{
|
||||
"queryResult": "SELECT count(*) FROM accounts WHERE region = 'Prague' AND eligible_for_loans = true;",
|
||||
"naturalLanguageAnswer": "There are 5 accounts in Prague eligible for loans.",
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
})
|
||||
|
||||
mockServer := httptest.NewServer(serverMux)
|
||||
defer mockServer.Close()
|
||||
|
||||
ctx := testutils.ContextWithUserAgent(context.Background(), "test-user-agent")
|
||||
|
||||
// Create an authenticated client that uses the mock server
|
||||
authTransport := &authRoundTripper{
|
||||
Token: "Bearer test-access-token",
|
||||
Next: mockServer.Client().Transport,
|
||||
}
|
||||
authClient := &http.Client{Transport: authTransport}
|
||||
|
||||
// Create a real cloudgdasrc.Source but inject the authenticated client
|
||||
mockGdaSource := &cloudgdasrc.Source{
|
||||
Config: cloudgdasrc.Config{Name: "mock-gda-source", Kind: cloudgdasrc.SourceKind, ProjectID: "test-project"},
|
||||
Client: authClient,
|
||||
BaseURL: mockServer.URL,
|
||||
}
|
||||
srcs := map[string]sources.Source{
|
||||
"mock-gda-source": mockGdaSource,
|
||||
}
|
||||
|
||||
// Initialize the tool config with context
|
||||
toolCfg := cloudgdatool.Config{
|
||||
Name: "query-data-tool",
|
||||
Kind: "cloud-gemini-data-analytics-query",
|
||||
Source: "mock-gda-source",
|
||||
Description: "Query Gemini Data Analytics",
|
||||
Location: "us-central1", // Set location for the test
|
||||
Context: &cloudgdatool.QueryDataContext{
|
||||
DatasourceReferences: &cloudgdatool.DatasourceReferences{
|
||||
SpannerReference: &cloudgdatool.SpannerReference{
|
||||
DatabaseReference: &cloudgdatool.SpannerDatabaseReference{
|
||||
ProjectID: "cloud-db-nl2sql",
|
||||
Region: "us-central1",
|
||||
InstanceID: "evalbench",
|
||||
DatabaseID: "financial",
|
||||
Engine: cloudgdatool.SpannerEngineGoogleSQL,
|
||||
},
|
||||
AgentContextReference: &cloudgdatool.AgentContextReference{
|
||||
ContextSetID: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
GenerationOptions: &cloudgdatool.GenerationOptions{
|
||||
GenerateQueryResult: true,
|
||||
},
|
||||
}
|
||||
|
||||
tool, err := toolCfg.Initialize(srcs)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to initialize tool: %v", err)
|
||||
}
|
||||
|
||||
// Prepare parameters for invocation - ONLY prompt
|
||||
params := parameters.ParamValues{
|
||||
{Name: "prompt", Value: "How many accounts who have region in Prague are eligible for loans?"},
|
||||
}
|
||||
|
||||
// Invoke the tool
|
||||
result, err := tool.Invoke(ctx, nil, params, "") // No accessToken needed for ADC client
|
||||
if err != nil {
|
||||
t.Fatalf("tool invocation failed: %v", err)
|
||||
}
|
||||
|
||||
// Validate the result
|
||||
expectedResult := map[string]any{
|
||||
"queryResult": "SELECT count(*) FROM accounts WHERE region = 'Prague' AND eligible_for_loans = true;",
|
||||
"naturalLanguageAnswer": "There are 5 accounts in Prague eligible for loans.",
|
||||
}
|
||||
|
||||
if !cmp.Equal(expectedResult, result) {
|
||||
t.Errorf("unexpected result: got %v, want %v", result, expectedResult)
|
||||
}
|
||||
}
|
||||
116
internal/tools/cloudgda/types.go
Normal file
116
internal/tools/cloudgda/types.go
Normal file
@@ -0,0 +1,116 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cloudgda
|
||||
|
||||
// See full service definition at: https://github.com/googleapis/googleapis/blob/master/google/cloud/geminidataanalytics/v1beta/data_chat_service.proto
|
||||
|
||||
// QueryDataRequest represents the JSON body for the queryData API
|
||||
type QueryDataRequest struct {
|
||||
Parent string `json:"parent"`
|
||||
Prompt string `json:"prompt"`
|
||||
Context *QueryDataContext `json:"context,omitempty"`
|
||||
GenerationOptions *GenerationOptions `json:"generationOptions,omitempty"`
|
||||
}
|
||||
|
||||
// QueryDataContext reflects the proto definition for the query context.
|
||||
type QueryDataContext struct {
|
||||
DatasourceReferences *DatasourceReferences `json:"datasourceReferences,omitempty" yaml:"datasourceReferences,omitempty"`
|
||||
}
|
||||
|
||||
// DatasourceReferences reflects the proto definition for datasource references, using a oneof.
|
||||
type DatasourceReferences struct {
|
||||
SpannerReference *SpannerReference `json:"spannerReference,omitempty" yaml:"spannerReference,omitempty"`
|
||||
AlloyDBReference *AlloyDBReference `json:"alloydb,omitempty" yaml:"alloydb,omitempty"`
|
||||
CloudSQLReference *CloudSQLReference `json:"cloudSqlReference,omitempty" yaml:"cloudSqlReference,omitempty"`
|
||||
}
|
||||
|
||||
// SpannerReference reflects the proto definition for Spanner database reference.
|
||||
type SpannerReference struct {
|
||||
DatabaseReference *SpannerDatabaseReference `json:"databaseReference,omitempty" yaml:"databaseReference,omitempty"`
|
||||
AgentContextReference *AgentContextReference `json:"agentContextReference,omitempty" yaml:"agentContextReference,omitempty"`
|
||||
}
|
||||
|
||||
// SpannerDatabaseReference reflects the proto definition for a Spanner database reference.
|
||||
type SpannerDatabaseReference struct {
|
||||
Engine SpannerEngine `json:"engine,omitempty" yaml:"engine,omitempty"`
|
||||
ProjectID string `json:"projectId,omitempty" yaml:"projectId,omitempty"`
|
||||
Region string `json:"region,omitempty" yaml:"region,omitempty"`
|
||||
InstanceID string `json:"instanceId,omitempty" yaml:"instanceId,omitempty"`
|
||||
DatabaseID string `json:"databaseId,omitempty" yaml:"databaseId,omitempty"`
|
||||
TableIDs []string `json:"tableIds,omitempty" yaml:"tableIds,omitempty"`
|
||||
}
|
||||
|
||||
// SpannerEngine represents the engine of the Spanner instance.
|
||||
type SpannerEngine string
|
||||
|
||||
const (
|
||||
SpannerEngineUnspecified SpannerEngine = "ENGINE_UNSPECIFIED"
|
||||
SpannerEngineGoogleSQL SpannerEngine = "GOOGLE_SQL"
|
||||
SpannerEnginePostgreSQL SpannerEngine = "POSTGRESQL"
|
||||
)
|
||||
|
||||
// AlloyDBReference reflects the proto definition for an AlloyDB database reference.
|
||||
type AlloyDBReference struct {
|
||||
DatabaseReference *AlloyDBDatabaseReference `json:"databaseReference,omitempty" yaml:"databaseReference,omitempty"`
|
||||
AgentContextReference *AgentContextReference `json:"agentContextReference,omitempty" yaml:"agentContextReference,omitempty"`
|
||||
}
|
||||
|
||||
// AlloyDBDatabaseReference reflects the proto definition for an AlloyDB database reference.
|
||||
type AlloyDBDatabaseReference struct {
|
||||
ProjectID string `json:"projectId,omitempty" yaml:"projectId,omitempty"`
|
||||
Region string `json:"region,omitempty" yaml:"region,omitempty"`
|
||||
ClusterID string `json:"clusterId,omitempty" yaml:"clusterId,omitempty"`
|
||||
InstanceID string `json:"instanceId,omitempty" yaml:"instanceId,omitempty"`
|
||||
DatabaseID string `json:"databaseId,omitempty" yaml:"databaseId,omitempty"`
|
||||
TableIDs []string `json:"tableIds,omitempty" yaml:"tableIds,omitempty"`
|
||||
}
|
||||
|
||||
// CloudSQLReference reflects the proto definition for a Cloud SQL database reference.
|
||||
type CloudSQLReference struct {
|
||||
DatabaseReference *CloudSQLDatabaseReference `json:"databaseReference,omitempty" yaml:"databaseReference,omitempty"`
|
||||
AgentContextReference *AgentContextReference `json:"agentContextReference,omitempty" yaml:"agentContextReference,omitempty"`
|
||||
}
|
||||
|
||||
// CloudSQLDatabaseReference reflects the proto definition for a Cloud SQL database reference.
|
||||
type CloudSQLDatabaseReference struct {
|
||||
Engine CloudSQLEngine `json:"engine,omitempty" yaml:"engine,omitempty"`
|
||||
ProjectID string `json:"projectId,omitempty" yaml:"projectId,omitempty"`
|
||||
Region string `json:"region,omitempty" yaml:"region,omitempty"`
|
||||
InstanceID string `json:"instanceId,omitempty" yaml:"instanceId,omitempty"`
|
||||
DatabaseID string `json:"databaseId,omitempty" yaml:"databaseId,omitempty"`
|
||||
TableIDs []string `json:"tableIds,omitempty" yaml:"tableIds,omitempty"`
|
||||
}
|
||||
|
||||
// CloudSQLEngine represents the engine of the Cloud SQL instance.
|
||||
type CloudSQLEngine string
|
||||
|
||||
const (
|
||||
CloudSQLEngineUnspecified CloudSQLEngine = "ENGINE_UNSPECIFIED"
|
||||
CloudSQLEnginePostgreSQL CloudSQLEngine = "POSTGRESQL"
|
||||
CloudSQLEngineMySQL CloudSQLEngine = "MYSQL"
|
||||
)
|
||||
|
||||
// AgentContextReference reflects the proto definition for agent context.
|
||||
type AgentContextReference struct {
|
||||
ContextSetID string `json:"contextSetId,omitempty" yaml:"contextSetId,omitempty"`
|
||||
}
|
||||
|
||||
// GenerationOptions reflects the proto definition for generation options.
|
||||
type GenerationOptions struct {
|
||||
GenerateQueryResult bool `json:"generateQueryResult" yaml:"generateQueryResult"`
|
||||
GenerateNaturalLanguageAnswer bool `json:"generateNaturalLanguageAnswer" yaml:"generateNaturalLanguageAnswer"`
|
||||
GenerateExplanation bool `json:"generateExplanation" yaml:"generateExplanation"`
|
||||
GenerateDisambiguationQuestion bool `json:"generateDisambiguationQuestion" yaml:"generateDisambiguationQuestion"`
|
||||
}
|
||||
233
tests/cloudgda/cloud_gda_integration_test.go
Normal file
233
tests/cloudgda/cloud_gda_integration_test.go
Normal file
@@ -0,0 +1,233 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cloudgda_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/cloudgda"
|
||||
"github.com/googleapis/genai-toolbox/tests"
|
||||
)
|
||||
|
||||
var (
|
||||
cloudGdaToolKind = "cloud-gemini-data-analytics-query"
|
||||
)
|
||||
|
||||
type cloudGdaTransport struct {
|
||||
transport http.RoundTripper
|
||||
url *url.URL
|
||||
}
|
||||
|
||||
func (t *cloudGdaTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
if strings.HasPrefix(req.URL.String(), "https://geminidataanalytics.googleapis.com") {
|
||||
req.URL.Scheme = t.url.Scheme
|
||||
req.URL.Host = t.url.Host
|
||||
}
|
||||
return t.transport.RoundTrip(req)
|
||||
}
|
||||
|
||||
type masterHandler struct {
|
||||
t *testing.T
|
||||
}
|
||||
|
||||
func (h *masterHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if !strings.Contains(r.UserAgent(), "genai-toolbox/") {
|
||||
h.t.Errorf("User-Agent header not found")
|
||||
}
|
||||
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify URL structure
|
||||
// Expected: /v1beta/projects/{project}/locations/global:queryData
|
||||
if !strings.Contains(r.URL.Path, ":queryData") || !strings.Contains(r.URL.Path, "locations/global") {
|
||||
h.t.Errorf("unexpected URL path: %s", r.URL.Path)
|
||||
http.Error(w, "Not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
var reqBody cloudgda.QueryDataRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
|
||||
h.t.Fatalf("failed to decode request body: %v", err)
|
||||
}
|
||||
|
||||
if reqBody.Prompt == "" {
|
||||
http.Error(w, "missing prompt", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
response := map[string]any{
|
||||
"queryResult": "SELECT * FROM table;",
|
||||
"naturalLanguageAnswer": "Here is the answer.",
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloudGdaToolEndpoints(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
|
||||
handler := &masterHandler{t: t}
|
||||
server := httptest.NewServer(handler)
|
||||
defer server.Close()
|
||||
|
||||
serverURL, err := url.Parse(server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse server URL: %v", err)
|
||||
}
|
||||
|
||||
originalTransport := http.DefaultClient.Transport
|
||||
if originalTransport == nil {
|
||||
originalTransport = http.DefaultTransport
|
||||
}
|
||||
http.DefaultClient.Transport = &cloudGdaTransport{
|
||||
transport: originalTransport,
|
||||
url: serverURL,
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
http.DefaultClient.Transport = originalTransport
|
||||
})
|
||||
|
||||
var args []string
|
||||
toolsFile := getCloudGdaToolsConfig()
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
t.Fatalf("command initialization returned an error: %s", err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||
if err != nil {
|
||||
t.Logf("toolbox command logs: \n%s", out)
|
||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||
}
|
||||
|
||||
toolName := "cloud-gda-query"
|
||||
|
||||
// 1. RunToolGetTestByName
|
||||
expectedManifest := map[string]any{
|
||||
toolName: map[string]any{
|
||||
"description": "Test GDA Tool",
|
||||
"parameters": []any{
|
||||
map[string]any{
|
||||
"name": "prompt",
|
||||
"type": "string",
|
||||
"description": "The natural language question to ask.",
|
||||
"required": true,
|
||||
"authSources": []any{},
|
||||
},
|
||||
},
|
||||
"authRequired": []any{},
|
||||
},
|
||||
}
|
||||
tests.RunToolGetTestByName(t, toolName, expectedManifest)
|
||||
|
||||
// 2. RunToolInvokeParametersTest
|
||||
params := []byte(`{"prompt": "test question"}`)
|
||||
tests.RunToolInvokeParametersTest(t, toolName, params, "\"queryResult\":\"SELECT * FROM table;\"")
|
||||
|
||||
// 3. Manual MCP Tool Call Test
|
||||
// Initialize MCP session
|
||||
sessionId := tests.RunInitialize(t, "2024-11-05")
|
||||
|
||||
// Construct MCP Request
|
||||
mcpReq := jsonrpc.JSONRPCRequest{
|
||||
Jsonrpc: "2.0",
|
||||
Id: "test-mcp-call",
|
||||
Request: jsonrpc.Request{
|
||||
Method: "tools/call",
|
||||
},
|
||||
Params: map[string]any{
|
||||
"name": toolName,
|
||||
"arguments": map[string]any{
|
||||
"prompt": "test question",
|
||||
},
|
||||
},
|
||||
}
|
||||
reqBytes, _ := json.Marshal(mcpReq)
|
||||
|
||||
headers := map[string]string{}
|
||||
if sessionId != "" {
|
||||
headers["Mcp-Session-Id"] = sessionId
|
||||
}
|
||||
|
||||
// Send Request
|
||||
resp, respBody := tests.RunRequest(t, http.MethodPost, "http://127.0.0.1:5000/mcp", bytes.NewBuffer(reqBytes), headers)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("MCP request failed with status %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
// Check Response
|
||||
respStr := string(respBody)
|
||||
if !strings.Contains(respStr, "SELECT * FROM table;") {
|
||||
t.Errorf("MCP response does not contain expected query result: %s", respStr)
|
||||
}
|
||||
}
|
||||
|
||||
func getCloudGdaToolsConfig() map[string]any {
|
||||
// Mocked responses and a dummy `projectId` are used in this integration
|
||||
// test due to limited project-specific allowlisting. API functionality is
|
||||
// verified via internal monitoring; this test specifically validates the
|
||||
// integration flow between the source and the tool.
|
||||
return map[string]any{
|
||||
"sources": map[string]any{
|
||||
"my-gda-source": map[string]any{
|
||||
"kind": "cloud-gemini-data-analytics",
|
||||
"projectId": "test-project",
|
||||
},
|
||||
},
|
||||
"tools": map[string]any{
|
||||
"cloud-gda-query": map[string]any{
|
||||
"kind": cloudGdaToolKind,
|
||||
"source": "my-gda-source",
|
||||
"description": "Test GDA Tool",
|
||||
"location": "us-central1",
|
||||
"context": map[string]any{
|
||||
"datasourceReferences": map[string]any{
|
||||
"spannerReference": map[string]any{
|
||||
"databaseReference": map[string]any{
|
||||
"projectId": "test-project",
|
||||
"instanceId": "test-instance",
|
||||
"databaseId": "test-db",
|
||||
"engine": "GOOGLE_SQL",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user