Compare commits

..

3 Commits

Author SHA1 Message Date
Wenxin Du
65a5942833 Merge branch 'main' into fix-redis-doc 2026-01-14 18:49:23 -05:00
Wenxin Du
c65f26b21a Merge branch 'main' into fix-redis-doc 2026-01-13 13:20:14 -05:00
duwenxin99
ea16e25e41 docs: fix redis array sample 2026-01-13 11:29:24 -05:00
10 changed files with 400 additions and 243 deletions

View File

@@ -41,13 +41,13 @@ tools:
### Usage Flow
When using this tool, a `query` parameter containing a natural language query is provided to the tool (typically by an agent). The tool then interacts with the Gemini Data Analytics API using the context defined in your configuration.
When using this tool, a `prompt` parameter containing a natural language query is provided to the tool (typically by an agent). The tool then interacts with the Gemini Data Analytics API using the context defined in your configuration.
The structure of the response depends on the `generationOptions` configured in your tool definition (e.g., enabling `generateQueryResult` will include the SQL query results).
See [Data Analytics API REST documentation](https://clouddocs.devsite.corp.google.com/gemini/docs/conversational-analytics-api/reference/rest/v1alpha/projects.locations/queryData?rep_location=global) for details.
**Example Input Query:**
**Example Input Prompt:**
```text
How many accounts who have region in Prague are eligible for loans? A3 contains the data of region.

View File

@@ -30,6 +30,10 @@ following config for example:
- name: userNames
type: array
description: The user names to be set.
items:
name: userName # the item name doesn't matter but it has to exist
type: string
description: username
```
If the input is an array of strings `["Alice", "Sid", "Bob"]`, The final command

2
go.mod
View File

@@ -12,7 +12,7 @@ require (
cloud.google.com/go/dataplex v1.28.0
cloud.google.com/go/dataproc/v2 v2.15.0
cloud.google.com/go/firestore v1.20.0
cloud.google.com/go/geminidataanalytics v0.5.0
cloud.google.com/go/geminidataanalytics v0.3.0
cloud.google.com/go/longrunning v0.7.0
cloud.google.com/go/spanner v1.86.1
github.com/ClickHouse/clickhouse-go/v2 v2.40.3

4
go.sum
View File

@@ -311,8 +311,8 @@ cloud.google.com/go/gaming v1.6.0/go.mod h1:YMU1GEvA39Qt3zWGyAVA9bpYz/yAhTvaQ1t2
cloud.google.com/go/gaming v1.7.0/go.mod h1:LrB8U7MHdGgFG851iHAfqUdLcKBdQ55hzXy9xBJz0+w=
cloud.google.com/go/gaming v1.8.0/go.mod h1:xAqjS8b7jAVW0KFYeRUxngo9My3f33kFmua++Pi+ggM=
cloud.google.com/go/gaming v1.9.0/go.mod h1:Fc7kEmCObylSWLO334NcO+O9QMDyz+TKC4v1D7X+Bc0=
cloud.google.com/go/geminidataanalytics v0.5.0 h1:+1usY81Cb+hE8BokpqCM7EgJtRCKzUKx7FvrHbT5hCA=
cloud.google.com/go/geminidataanalytics v0.5.0/go.mod h1:QRc0b6ywyc3Z7S3etFgslz7hippkW/jRvtops5rKqIg=
cloud.google.com/go/geminidataanalytics v0.3.0 h1:2Wi/kqFb5OLuEGH7q+/miE19VTqK1MYHjBEHENap9HI=
cloud.google.com/go/geminidataanalytics v0.3.0/go.mod h1:QRc0b6ywyc3Z7S3etFgslz7hippkW/jRvtops5rKqIg=
cloud.google.com/go/gkebackup v0.2.0/go.mod h1:XKvv/4LfG829/B8B7xRkk8zRrOEbKtEam6yNfuQNH60=
cloud.google.com/go/gkebackup v0.3.0/go.mod h1:n/E671i1aOQvUxT541aTkCwExO/bTer2HDlj4TsBRAo=
cloud.google.com/go/gkebackup v0.4.0/go.mod h1:byAyBGUwYGEEww7xsbnUTBHIYcOPy/PgUWUtOeRm9Vg=

View File

@@ -14,20 +14,23 @@
package cloudgda
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
geminidataanalytics "cloud.google.com/go/geminidataanalytics/apiv1beta"
"cloud.google.com/go/geminidataanalytics/apiv1beta/geminidataanalyticspb"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/util"
"go.opentelemetry.io/otel/trace"
"golang.org/x/oauth2"
"google.golang.org/api/option"
"golang.org/x/oauth2/google"
)
const SourceKind string = "cloud-gemini-data-analytics"
const Endpoint string = "https://geminidataanalytics.googleapis.com"
// validate interface
var _ sources.SourceConfig = Config{}
@@ -64,19 +67,29 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
return nil, fmt.Errorf("error in User Agent retrieval: %s", err)
}
var client *http.Client
if r.UseClientOAuth {
client = &http.Client{
Transport: util.NewUserAgentRoundTripper(ua, http.DefaultTransport),
}
} else {
// Use Application Default Credentials
// Scope: "https://www.googleapis.com/auth/cloud-platform" is generally sufficient for GDA
creds, err := google.FindDefaultCredentials(ctx, "https://www.googleapis.com/auth/cloud-platform")
if err != nil {
return nil, fmt.Errorf("failed to find default credentials: %w", err)
}
baseClient := oauth2.NewClient(ctx, creds.TokenSource)
baseClient.Transport = util.NewUserAgentRoundTripper(ua, baseClient.Transport)
client = baseClient
}
s := &Source{
Config: r,
Client: client,
BaseURL: Endpoint,
userAgent: ua,
}
if !r.UseClientOAuth {
client, err := geminidataanalytics.NewDataChatClient(ctx, option.WithUserAgent(ua))
if err != nil {
return nil, fmt.Errorf("failed to create DataChatClient: %w", err)
}
s.Client = client
}
return s, nil
}
@@ -84,7 +97,8 @@ var _ sources.Source = &Source{}
type Source struct {
Config
Client *geminidataanalytics.DataChatClient
Client *http.Client
BaseURL string
userAgent string
}
@@ -100,34 +114,63 @@ func (s *Source) GetProjectID() string {
return s.ProjectID
}
func (s *Source) GetBaseURL() string {
return s.BaseURL
}
func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Client, error) {
if s.UseClientOAuth {
if accessToken == "" {
return nil, fmt.Errorf("client-side OAuth is enabled but no access token was provided")
}
token := &oauth2.Token{AccessToken: accessToken}
baseClient := oauth2.NewClient(ctx, oauth2.StaticTokenSource(token))
baseClient.Transport = util.NewUserAgentRoundTripper(s.userAgent, baseClient.Transport)
return baseClient, nil
}
return s.Client, nil
}
func (s *Source) UseClientAuthorization() bool {
return s.UseClientOAuth
}
func (s *Source) RunQuery(ctx context.Context, tokenStr string, req *geminidataanalyticspb.QueryDataRequest) (*geminidataanalyticspb.QueryDataResponse, error) {
client, cleanup, err := s.GetClient(ctx, tokenStr)
func (s *Source) RunQuery(ctx context.Context, tokenStr string, bodyBytes []byte) (any, error) {
// The API endpoint itself always uses the "global" location.
apiLocation := "global"
apiParent := fmt.Sprintf("projects/%s/locations/%s", s.GetProjectID(), apiLocation)
apiURL := fmt.Sprintf("%s/v1beta/%s:queryData", s.GetBaseURL(), apiParent)
client, err := s.GetClient(ctx, tokenStr)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to get HTTP client: %w", err)
}
defer cleanup()
return client.QueryData(ctx, req)
}
func (s *Source) GetClient(ctx context.Context, tokenStr string) (*geminidataanalytics.DataChatClient, func(), error) {
if s.UseClientOAuth {
if tokenStr == "" {
return nil, nil, fmt.Errorf("client-side OAuth is enabled but no access token was provided")
}
token := &oauth2.Token{AccessToken: tokenStr}
client, err := geminidataanalytics.NewDataChatClient(ctx,
option.WithUserAgent(s.userAgent),
option.WithTokenSource(oauth2.StaticTokenSource(token)),
)
if err != nil {
return nil, nil, fmt.Errorf("failed to create per-request DataChatClient: %w", err)
}
return client, func() { client.Close() }, nil
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(bodyBytes))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
return s.Client, func() {}, nil
req.Header.Set("Content-Type", "application/json")
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to execute request: %w", err)
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(respBody))
}
var result map[string]any
if err := json.Unmarshal(respBody, &result); err != nil {
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
}
return result, nil
}

View File

@@ -181,9 +181,11 @@ func TestInitialize(t *testing.T) {
if gdaSrc.Client == nil && !tc.wantClientOAuth {
t.Fatal("expected non-nil HTTP client for ADC, got nil")
}
// When client OAuth is true, the source's client should be nil.
if gdaSrc.Client != nil && tc.wantClientOAuth {
t.Fatal("expected nil HTTP client for client OAuth config, got non-nil")
// When client OAuth is true, the source's client should be initialized with a base HTTP client
// that includes the user agent round tripper, but not the OAuth token. The token-aware
// client is created by GetClient.
if gdaSrc.Client == nil && tc.wantClientOAuth {
t.Fatal("expected non-nil HTTP client for client OAuth config, got nil")
}
// Test UseClientAuthorization method
@@ -193,16 +195,15 @@ func TestInitialize(t *testing.T) {
// Test GetClient with accessToken for client OAuth scenarios
if tc.wantClientOAuth {
client, cleanup, err := gdaSrc.GetClient(ctx, "dummy-token")
client, err := gdaSrc.GetClient(ctx, "dummy-token")
if err != nil {
t.Fatalf("GetClient with token failed: %v", err)
}
defer cleanup()
if client == nil {
t.Fatal("expected non-nil HTTP client from GetClient with token, got nil")
}
// Ensure passing empty token with UseClientOAuth enabled returns error
_, _, err = gdaSrc.GetClient(ctx, "")
_, err = gdaSrc.GetClient(ctx, "")
if err == nil || err.Error() != "client-side OAuth is enabled but no access token was provided" {
t.Errorf("expected 'client-side OAuth is enabled but no access token was provided' error, got: %v", err)
}

View File

@@ -19,32 +19,15 @@ import (
"encoding/json"
"fmt"
"cloud.google.com/go/geminidataanalytics/apiv1beta/geminidataanalyticspb"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/protobuf/encoding/protojson"
)
const kind string = "cloud-gemini-data-analytics-query"
// Guidance is the tool guidance string.
const Guidance = `Tool guidance:
Inputs:
1. query: A natural language formulation of a database query.
Outputs: (all optional)
1. disambiguation_question: Clarification questions or comments where the tool needs the users' input.
2. generated_query: The generated query for the user query.
3. intent_explanation: An explanation for why the tool produced ` + "`generated_query`" + `.
4. query_result: The result of executing ` + "`generated_query`" + `.
5. natural_language_answer: The natural language answer that summarizes the ` + "`query`" + ` and ` + "`query_result`" + `.
Usage guidance:
1. If ` + "`disambiguation_question`" + ` is produced, then solicit the needed inputs from the user and try the tool with a new ` + "`query`" + ` that has the needed clarification.
2. If ` + "`natural_language_answer`" + ` is produced, use ` + "`intent_explanation`" + ` and ` + "`generated_query`" + ` to see if you need to clarify any assumptions for the user.`
func init() {
if !tools.Register(kind, newConfig) {
panic(fmt.Sprintf("tool kind %q already registered", kind))
@@ -62,49 +45,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
GetProjectID() string
UseClientAuthorization() bool
RunQuery(context.Context, string, *geminidataanalyticspb.QueryDataRequest) (*geminidataanalyticspb.QueryDataResponse, error)
}
// QueryDataContext wraps geminidataanalyticspb.QueryDataContext to support YAML decoding via protojson.
type QueryDataContext struct {
*geminidataanalyticspb.QueryDataContext
}
func (q *QueryDataContext) UnmarshalYAML(b []byte) error {
var raw map[string]any
if err := yaml.Unmarshal(b, &raw); err != nil {
return err
}
jsonBytes, err := json.Marshal(raw)
if err != nil {
return fmt.Errorf("failed to marshal context map: %w", err)
}
q.QueryDataContext = &geminidataanalyticspb.QueryDataContext{}
if err := protojson.Unmarshal(jsonBytes, q.QueryDataContext); err != nil {
return fmt.Errorf("failed to unmarshal context to proto: %w", err)
}
return nil
}
// GenerationOptions wraps geminidataanalyticspb.GenerationOptions to support YAML decoding via protojson.
type GenerationOptions struct {
*geminidataanalyticspb.GenerationOptions
}
func (g *GenerationOptions) UnmarshalYAML(b []byte) error {
var raw map[string]any
if err := yaml.Unmarshal(b, &raw); err != nil {
return err
}
jsonBytes, err := json.Marshal(raw)
if err != nil {
return fmt.Errorf("failed to marshal generation options map: %w", err)
}
g.GenerationOptions = &geminidataanalyticspb.GenerationOptions{}
if err := protojson.Unmarshal(jsonBytes, g.GenerationOptions); err != nil {
return fmt.Errorf("failed to unmarshal generation options to proto: %w", err)
}
return nil
RunQuery(context.Context, string, []byte) (any, error)
}
type Config struct {
@@ -127,28 +68,19 @@ func (cfg Config) ToolConfigKind() string {
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// Define the parameters for the Gemini Data Analytics Query API
// The query is the only input parameter.
// The prompt is the only input parameter.
allParameters := parameters.Parameters{
parameters.NewStringParameterWithRequired("query", "A natural language formulation of a database query.", true),
parameters.NewStringParameterWithRequired("prompt", "The natural language question to ask.", true),
}
// The input and outputs are for tool guidance, usage guidance is for multi-turn interaction.
guidance := Guidance
if cfg.Description != "" {
cfg.Description += "\n\n" + guidance
} else {
cfg.Description = guidance
}
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil)
t := Tool{
return Tool{
Config: cfg,
AllParams: allParameters,
manifest: tools.Manifest{Description: cfg.Description, Parameters: allParameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}
return t, nil
}, nil
}
// validate interface
@@ -173,9 +105,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
}
paramsMap := params.AsMap()
query, ok := paramsMap["query"].(string)
prompt, ok := paramsMap["prompt"].(string)
if !ok {
return nil, fmt.Errorf("query parameter not found or not a string")
return nil, fmt.Errorf("prompt parameter not found or not a string")
}
// Parse the access token if provided
@@ -191,20 +123,18 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
// The parent in the request payload uses the tool's configured location.
payloadParent := fmt.Sprintf("projects/%s/locations/%s", source.GetProjectID(), t.Location)
req := &geminidataanalyticspb.QueryDataRequest{
Parent: payloadParent,
Prompt: query,
payload := &QueryDataRequest{
Parent: payloadParent,
Prompt: prompt,
Context: t.Context,
GenerationOptions: t.GenerationOptions,
}
if t.Context != nil {
req.Context = t.Context.QueryDataContext
bodyBytes, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("failed to marshal request payload: %w", err)
}
if t.GenerationOptions != nil {
req.GenerationOptions = t.GenerationOptions.GenerationOptions
}
return source.RunQuery(ctx, tokenStr, req)
return source.RunQuery(ctx, tokenStr, bodyBytes)
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -16,16 +16,19 @@ package cloudgda_test
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"testing"
"cloud.google.com/go/geminidataanalytics/apiv1beta/geminidataanalyticspb"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/server/resources"
"github.com/googleapis/genai-toolbox/internal/sources"
cloudgdasrc "github.com/googleapis/genai-toolbox/internal/sources/cloudgda"
"github.com/googleapis/genai-toolbox/internal/testutils"
"github.com/googleapis/genai-toolbox/internal/tools"
cloudgdatool "github.com/googleapis/genai-toolbox/internal/tools/cloudgda"
@@ -71,29 +74,23 @@ func TestParseFromYaml(t *testing.T) {
Location: "us-central1",
AuthRequired: []string{},
Context: &cloudgdatool.QueryDataContext{
QueryDataContext: &geminidataanalyticspb.QueryDataContext{
DatasourceReferences: &geminidataanalyticspb.DatasourceReferences{
References: &geminidataanalyticspb.DatasourceReferences_SpannerReference{
SpannerReference: &geminidataanalyticspb.SpannerReference{
DatabaseReference: &geminidataanalyticspb.SpannerDatabaseReference{
ProjectId: "cloud-db-nl2sql",
Region: "us-central1",
InstanceId: "evalbench",
DatabaseId: "financial",
Engine: geminidataanalyticspb.SpannerDatabaseReference_GOOGLE_SQL,
},
AgentContextReference: &geminidataanalyticspb.AgentContextReference{
ContextSetId: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates",
},
},
DatasourceReferences: &cloudgdatool.DatasourceReferences{
SpannerReference: &cloudgdatool.SpannerReference{
DatabaseReference: &cloudgdatool.SpannerDatabaseReference{
ProjectID: "cloud-db-nl2sql",
Region: "us-central1",
InstanceID: "evalbench",
DatabaseID: "financial",
Engine: cloudgdatool.SpannerEngineGoogleSQL,
},
AgentContextReference: &cloudgdatool.AgentContextReference{
ContextSetID: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates",
},
},
},
},
GenerationOptions: &cloudgdatool.GenerationOptions{
GenerationOptions: &geminidataanalyticspb.GenerationOptions{
GenerateQueryResult: true,
},
GenerateQueryResult: true,
},
},
},
@@ -111,63 +108,68 @@ func TestParseFromYaml(t *testing.T) {
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got.Tools, cmpopts.IgnoreUnexported(geminidataanalyticspb.QueryDataContext{}, geminidataanalyticspb.DatasourceReferences{}, geminidataanalyticspb.SpannerReference{}, geminidataanalyticspb.SpannerDatabaseReference{}, geminidataanalyticspb.AgentContextReference{}, geminidataanalyticspb.GenerationOptions{}, geminidataanalyticspb.DatasourceReferences_SpannerReference{})) {
t.Errorf("incorrect parse: want %v, got %v", tc.want, got.Tools)
if !cmp.Equal(tc.want, got.Tools) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Tools)
}
})
}
}
// fakeSource implements the compatibleSource interface for testing.
type fakeSource struct {
projectID string
useClientOAuth bool
expectedQuery string
expectedParent string
response *geminidataanalyticspb.QueryDataResponse
// authRoundTripper is a mock http.RoundTripper that adds a dummy Authorization header.
type authRoundTripper struct {
Token string
Next http.RoundTripper
}
func (f *fakeSource) GetProjectID() string {
return f.projectID
}
func (f *fakeSource) UseClientAuthorization() bool {
return f.useClientOAuth
}
func (f *fakeSource) SourceKind() string {
return "fake-gda-source"
}
func (f *fakeSource) ToConfig() sources.SourceConfig {
return nil
}
func (f *fakeSource) Initialize(ctx context.Context, tracer interface{}) (sources.Source, error) {
return f, nil
}
func (f *fakeSource) RunQuery(ctx context.Context, token string, req *geminidataanalyticspb.QueryDataRequest) (*geminidataanalyticspb.QueryDataResponse, error) {
if req.Prompt != f.expectedQuery {
return nil, fmt.Errorf("unexpected query: got %q, want %q", req.Prompt, f.expectedQuery)
func (rt *authRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
newReq := *req
newReq.Header = make(http.Header)
for k, v := range req.Header {
newReq.Header[k] = v
}
if req.Parent != f.expectedParent {
return nil, fmt.Errorf("unexpected parent: got %q, want %q", req.Parent, f.expectedParent)
newReq.Header.Set("Authorization", rt.Token)
if rt.Next == nil {
return http.DefaultTransport.RoundTrip(&newReq)
}
// Basic validation of context/options could be added here if needed,
// but the test case mainly checks if they are passed correctly via successful invocation.
return f.response, nil
return rt.Next.RoundTrip(&newReq)
}
type mockSource struct {
kind string
client *http.Client // Can be used to inject a specific client
baseURL string // BaseURL is needed to implement sources.Source.BaseURL
config cloudgdasrc.Config // to return from ToConfig
}
func (m *mockSource) SourceKind() string { return m.kind }
func (m *mockSource) ToConfig() sources.SourceConfig { return m.config }
func (m *mockSource) GetClient(ctx context.Context, token string) (*http.Client, error) {
if m.client != nil {
return m.client, nil
}
// Default client for testing if not explicitly set
transport := &http.Transport{}
authTransport := &authRoundTripper{
Token: "Bearer test-access-token", // Dummy token
Next: transport,
}
return &http.Client{Transport: authTransport}, nil
}
func (m *mockSource) UseClientAuthorization() bool { return false }
func (m *mockSource) Initialize(ctx context.Context, tracer interface{}) (sources.Source, error) {
return m, nil
}
func (m *mockSource) BaseURL() string { return m.baseURL }
func TestInitialize(t *testing.T) {
t.Parallel()
// Minimal fake source
fake := &fakeSource{projectID: "test-project"}
srcs := map[string]sources.Source{
"gda-api-source": fake,
"gda-api-source": &cloudgdasrc.Source{
Config: cloudgdasrc.Config{Name: "gda-api-source", Kind: cloudgdasrc.SourceKind, ProjectID: "test-project"},
Client: &http.Client{},
BaseURL: cloudgdasrc.Endpoint,
},
}
tcs := []struct {
@@ -186,6 +188,9 @@ func TestInitialize(t *testing.T) {
},
}
// Add an incompatible source for testing
srcs["incompatible-source"] = &mockSource{kind: "another-kind"}
for _, tc := range tcs {
tc := tc
t.Run(tc.desc, func(t *testing.T) {
@@ -202,27 +207,92 @@ func TestInitialize(t *testing.T) {
func TestInvoke(t *testing.T) {
t.Parallel()
// Mock the HTTP client and server for Invoke testing
serverMux := http.NewServeMux()
// Update expected URL path to include the location "us-central1"
serverMux.HandleFunc("/v1beta/projects/test-project/locations/global:queryData", func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("expected POST method, got %s", r.Method)
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
if r.Header.Get("Content-Type") != "application/json" {
t.Errorf("expected Content-Type application/json, got %s", r.Header.Get("Content-Type"))
http.Error(w, "Bad request", http.StatusBadRequest)
return
}
projectID := "test-project"
location := "us-central1"
query := "How many accounts who have region in Prague are eligible for loans?"
expectedParent := fmt.Sprintf("projects/%s/locations/%s", projectID, location)
// Read and unmarshal the request body
bodyBytes, err := io.ReadAll(r.Body)
if err != nil {
t.Errorf("failed to read request body: %v", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
var reqPayload cloudgdatool.QueryDataRequest
if err := json.Unmarshal(bodyBytes, &reqPayload); err != nil {
t.Errorf("failed to unmarshal request payload: %v", err)
http.Error(w, "Bad request", http.StatusBadRequest)
return
}
// Prepare expected response
expectedResp := &geminidataanalyticspb.QueryDataResponse{
GeneratedQuery: "SELECT count(*) FROM accounts WHERE region = 'Prague' AND eligible_for_loans = true;",
NaturalLanguageAnswer: "There are 5 accounts in Prague eligible for loans.",
// Verify expected fields
if r.Header.Get("Authorization") == "" {
t.Errorf("expected Authorization header, got empty")
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
if reqPayload.Prompt != "How many accounts who have region in Prague are eligible for loans?" {
t.Errorf("unexpected prompt: %s", reqPayload.Prompt)
}
// Verify payload's parent uses the tool's configured location
if reqPayload.Parent != fmt.Sprintf("projects/%s/locations/%s", "test-project", "us-central1") {
t.Errorf("unexpected payload parent: got %q, want %q", reqPayload.Parent, fmt.Sprintf("projects/%s/locations/%s", "test-project", "us-central1"))
}
// Verify context from config
if reqPayload.Context == nil ||
reqPayload.Context.DatasourceReferences == nil ||
reqPayload.Context.DatasourceReferences.SpannerReference == nil ||
reqPayload.Context.DatasourceReferences.SpannerReference.DatabaseReference == nil ||
reqPayload.Context.DatasourceReferences.SpannerReference.DatabaseReference.ProjectID != "cloud-db-nl2sql" {
t.Errorf("unexpected context: %v", reqPayload.Context)
}
// Verify generation options from config
if reqPayload.GenerationOptions == nil || !reqPayload.GenerationOptions.GenerateQueryResult {
t.Errorf("unexpected generation options: %v", reqPayload.GenerationOptions)
}
// Simulate a successful response
resp := map[string]any{
"queryResult": "SELECT count(*) FROM accounts WHERE region = 'Prague' AND eligible_for_loans = true;",
"naturalLanguageAnswer": "There are 5 accounts in Prague eligible for loans.",
}
_ = json.NewEncoder(w).Encode(resp)
})
mockServer := httptest.NewServer(serverMux)
defer mockServer.Close()
ctx := testutils.ContextWithUserAgent(context.Background(), "test-user-agent")
// Create an authenticated client that uses the mock server
authTransport := &authRoundTripper{
Token: "Bearer test-access-token",
Next: mockServer.Client().Transport,
}
authClient := &http.Client{Transport: authTransport}
fake := &fakeSource{
projectID: projectID,
expectedQuery: query,
expectedParent: expectedParent,
response: expectedResp,
// Create a real cloudgdasrc.Source but inject the authenticated client
mockGdaSource := &cloudgdasrc.Source{
Config: cloudgdasrc.Config{Name: "mock-gda-source", Kind: cloudgdasrc.SourceKind, ProjectID: "test-project"},
Client: authClient,
BaseURL: mockServer.URL,
}
srcs := map[string]sources.Source{
"mock-gda-source": fake,
"mock-gda-source": mockGdaSource,
}
// Initialize the tool config with context
@@ -231,31 +301,25 @@ func TestInvoke(t *testing.T) {
Kind: "cloud-gemini-data-analytics-query",
Source: "mock-gda-source",
Description: "Query Gemini Data Analytics",
Location: location,
Location: "us-central1", // Set location for the test
Context: &cloudgdatool.QueryDataContext{
QueryDataContext: &geminidataanalyticspb.QueryDataContext{
DatasourceReferences: &geminidataanalyticspb.DatasourceReferences{
References: &geminidataanalyticspb.DatasourceReferences_SpannerReference{
SpannerReference: &geminidataanalyticspb.SpannerReference{
DatabaseReference: &geminidataanalyticspb.SpannerDatabaseReference{
ProjectId: "cloud-db-nl2sql",
Region: "us-central1",
InstanceId: "evalbench",
DatabaseId: "financial",
Engine: geminidataanalyticspb.SpannerDatabaseReference_GOOGLE_SQL,
},
AgentContextReference: &geminidataanalyticspb.AgentContextReference{
ContextSetId: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates",
},
},
DatasourceReferences: &cloudgdatool.DatasourceReferences{
SpannerReference: &cloudgdatool.SpannerReference{
DatabaseReference: &cloudgdatool.SpannerDatabaseReference{
ProjectID: "cloud-db-nl2sql",
Region: "us-central1",
InstanceID: "evalbench",
DatabaseID: "financial",
Engine: cloudgdatool.SpannerEngineGoogleSQL,
},
AgentContextReference: &cloudgdatool.AgentContextReference{
ContextSetID: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates",
},
},
},
},
GenerationOptions: &cloudgdatool.GenerationOptions{
GenerationOptions: &geminidataanalyticspb.GenerationOptions{
GenerateQueryResult: true,
},
GenerateQueryResult: true,
},
}
@@ -264,27 +328,26 @@ func TestInvoke(t *testing.T) {
t.Fatalf("failed to initialize tool: %v", err)
}
// Prepare parameters for invocation - ONLY query
// Prepare parameters for invocation - ONLY prompt
params := parameters.ParamValues{
{Name: "query", Value: query},
{Name: "prompt", Value: "How many accounts who have region in Prague are eligible for loans?"},
}
resourceMgr := resources.NewResourceManager(srcs, nil, nil, nil, nil, nil, nil)
ctx := testutils.ContextWithUserAgent(context.Background(), "test-user-agent")
// Invoke the tool
result, err := tool.Invoke(ctx, resourceMgr, params, "")
result, err := tool.Invoke(ctx, resourceMgr, params, "") // No accessToken needed for ADC client
if err != nil {
t.Fatalf("tool invocation failed: %v", err)
}
gotResp, ok := result.(*geminidataanalyticspb.QueryDataResponse)
if !ok {
t.Fatalf("expected result type *geminidataanalyticspb.QueryDataResponse, got %T", result)
// Validate the result
expectedResult := map[string]any{
"queryResult": "SELECT count(*) FROM accounts WHERE region = 'Prague' AND eligible_for_loans = true;",
"naturalLanguageAnswer": "There are 5 accounts in Prague eligible for loans.",
}
if diff := cmp.Diff(expectedResp, gotResp, cmpopts.IgnoreUnexported(geminidataanalyticspb.QueryDataResponse{})); diff != "" {
t.Errorf("unexpected result mismatch (-want +got):\n%s", diff)
if !cmp.Equal(expectedResult, result) {
t.Errorf("unexpected result: got %v, want %v", result, expectedResult)
}
}

View File

@@ -0,0 +1,116 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cloudgda
// See full service definition at: https://github.com/googleapis/googleapis/blob/master/google/cloud/geminidataanalytics/v1beta/data_chat_service.proto
// QueryDataRequest represents the JSON body for the queryData API
type QueryDataRequest struct {
Parent string `json:"parent"`
Prompt string `json:"prompt"`
Context *QueryDataContext `json:"context,omitempty"`
GenerationOptions *GenerationOptions `json:"generationOptions,omitempty"`
}
// QueryDataContext reflects the proto definition for the query context.
type QueryDataContext struct {
DatasourceReferences *DatasourceReferences `json:"datasourceReferences,omitempty" yaml:"datasourceReferences,omitempty"`
}
// DatasourceReferences reflects the proto definition for datasource references, using a oneof.
type DatasourceReferences struct {
SpannerReference *SpannerReference `json:"spannerReference,omitempty" yaml:"spannerReference,omitempty"`
AlloyDBReference *AlloyDBReference `json:"alloydb,omitempty" yaml:"alloydb,omitempty"`
CloudSQLReference *CloudSQLReference `json:"cloudSqlReference,omitempty" yaml:"cloudSqlReference,omitempty"`
}
// SpannerReference reflects the proto definition for Spanner database reference.
type SpannerReference struct {
DatabaseReference *SpannerDatabaseReference `json:"databaseReference,omitempty" yaml:"databaseReference,omitempty"`
AgentContextReference *AgentContextReference `json:"agentContextReference,omitempty" yaml:"agentContextReference,omitempty"`
}
// SpannerDatabaseReference reflects the proto definition for a Spanner database reference.
type SpannerDatabaseReference struct {
Engine SpannerEngine `json:"engine,omitempty" yaml:"engine,omitempty"`
ProjectID string `json:"projectId,omitempty" yaml:"projectId,omitempty"`
Region string `json:"region,omitempty" yaml:"region,omitempty"`
InstanceID string `json:"instanceId,omitempty" yaml:"instanceId,omitempty"`
DatabaseID string `json:"databaseId,omitempty" yaml:"databaseId,omitempty"`
TableIDs []string `json:"tableIds,omitempty" yaml:"tableIds,omitempty"`
}
// SpannerEngine represents the engine of the Spanner instance.
type SpannerEngine string
const (
SpannerEngineUnspecified SpannerEngine = "ENGINE_UNSPECIFIED"
SpannerEngineGoogleSQL SpannerEngine = "GOOGLE_SQL"
SpannerEnginePostgreSQL SpannerEngine = "POSTGRESQL"
)
// AlloyDBReference reflects the proto definition for an AlloyDB database reference.
type AlloyDBReference struct {
DatabaseReference *AlloyDBDatabaseReference `json:"databaseReference,omitempty" yaml:"databaseReference,omitempty"`
AgentContextReference *AgentContextReference `json:"agentContextReference,omitempty" yaml:"agentContextReference,omitempty"`
}
// AlloyDBDatabaseReference reflects the proto definition for an AlloyDB database reference.
type AlloyDBDatabaseReference struct {
ProjectID string `json:"projectId,omitempty" yaml:"projectId,omitempty"`
Region string `json:"region,omitempty" yaml:"region,omitempty"`
ClusterID string `json:"clusterId,omitempty" yaml:"clusterId,omitempty"`
InstanceID string `json:"instanceId,omitempty" yaml:"instanceId,omitempty"`
DatabaseID string `json:"databaseId,omitempty" yaml:"databaseId,omitempty"`
TableIDs []string `json:"tableIds,omitempty" yaml:"tableIds,omitempty"`
}
// CloudSQLReference reflects the proto definition for a Cloud SQL database reference.
type CloudSQLReference struct {
DatabaseReference *CloudSQLDatabaseReference `json:"databaseReference,omitempty" yaml:"databaseReference,omitempty"`
AgentContextReference *AgentContextReference `json:"agentContextReference,omitempty" yaml:"agentContextReference,omitempty"`
}
// CloudSQLDatabaseReference reflects the proto definition for a Cloud SQL database reference.
type CloudSQLDatabaseReference struct {
Engine CloudSQLEngine `json:"engine,omitempty" yaml:"engine,omitempty"`
ProjectID string `json:"projectId,omitempty" yaml:"projectId,omitempty"`
Region string `json:"region,omitempty" yaml:"region,omitempty"`
InstanceID string `json:"instanceId,omitempty" yaml:"instanceId,omitempty"`
DatabaseID string `json:"databaseId,omitempty" yaml:"databaseId,omitempty"`
TableIDs []string `json:"tableIds,omitempty" yaml:"tableIds,omitempty"`
}
// CloudSQLEngine represents the engine of the Cloud SQL instance.
type CloudSQLEngine string
const (
CloudSQLEngineUnspecified CloudSQLEngine = "ENGINE_UNSPECIFIED"
CloudSQLEnginePostgreSQL CloudSQLEngine = "POSTGRESQL"
CloudSQLEngineMySQL CloudSQLEngine = "MYSQL"
)
// AgentContextReference reflects the proto definition for agent context.
type AgentContextReference struct {
ContextSetID string `json:"contextSetId,omitempty" yaml:"contextSetId,omitempty"`
}
// GenerationOptions reflects the proto definition for generation options.
type GenerationOptions struct {
GenerateQueryResult bool `json:"generateQueryResult" yaml:"generateQueryResult"`
GenerateNaturalLanguageAnswer bool `json:"generateNaturalLanguageAnswer" yaml:"generateNaturalLanguageAnswer"`
GenerateExplanation bool `json:"generateExplanation" yaml:"generateExplanation"`
GenerateDisambiguationQuestion bool `json:"generateDisambiguationQuestion" yaml:"generateDisambiguationQuestion"`
}

View File

@@ -139,12 +139,12 @@ func TestCloudGdaToolEndpoints(t *testing.T) {
// 1. RunToolGetTestByName
expectedManifest := map[string]any{
toolName: map[string]any{
"description": "Test GDA Tool\n\n" + cloudgda.Guidance,
"description": "Test GDA Tool",
"parameters": []any{
map[string]any{
"name": "query",
"name": "prompt",
"type": "string",
"description": "A natural language formulation of a database query.",
"description": "The natural language question to ask.",
"required": true,
"authSources": []any{},
},
@@ -155,7 +155,7 @@ func TestCloudGdaToolEndpoints(t *testing.T) {
tests.RunToolGetTestByName(t, toolName, expectedManifest)
// 2. RunToolInvokeParametersTest
params := []byte(`{"query": "test question"}`)
params := []byte(`{"prompt": "test question"}`)
tests.RunToolInvokeParametersTest(t, toolName, params, "\"queryResult\":\"SELECT * FROM table;\"")
// 3. Manual MCP Tool Call Test
@@ -172,7 +172,7 @@ func TestCloudGdaToolEndpoints(t *testing.T) {
Params: map[string]any{
"name": toolName,
"arguments": map[string]any{
"query": "test question",
"prompt": "test question",
},
},
}