Compare commits

...

3 Commits

Author SHA1 Message Date
Juexin Wang
172dd6b19b refactor(tools/cloudgda): update to use google-cloud-go sdk types
Removes types.go and uses geminidataanalyticspb types with wrapper structs for YAML decoding.
2026-01-14 17:06:43 -08:00
Juexin Wang
9f5b04cf73 Refactor Cloud GDA source to support per-request client authorization 2026-01-14 16:10:50 -08:00
Juexin Wang
66d6b58c4f intorduce Go SDK for GDA tool and source 2026-01-14 16:10:50 -08:00
7 changed files with 207 additions and 382 deletions

2
go.mod
View File

@@ -12,7 +12,7 @@ require (
cloud.google.com/go/dataplex v1.28.0
cloud.google.com/go/dataproc/v2 v2.15.0
cloud.google.com/go/firestore v1.20.0
cloud.google.com/go/geminidataanalytics v0.3.0
cloud.google.com/go/geminidataanalytics v0.5.0
cloud.google.com/go/longrunning v0.7.0
cloud.google.com/go/spanner v1.86.1
github.com/ClickHouse/clickhouse-go/v2 v2.40.3

4
go.sum
View File

@@ -311,8 +311,8 @@ cloud.google.com/go/gaming v1.6.0/go.mod h1:YMU1GEvA39Qt3zWGyAVA9bpYz/yAhTvaQ1t2
cloud.google.com/go/gaming v1.7.0/go.mod h1:LrB8U7MHdGgFG851iHAfqUdLcKBdQ55hzXy9xBJz0+w=
cloud.google.com/go/gaming v1.8.0/go.mod h1:xAqjS8b7jAVW0KFYeRUxngo9My3f33kFmua++Pi+ggM=
cloud.google.com/go/gaming v1.9.0/go.mod h1:Fc7kEmCObylSWLO334NcO+O9QMDyz+TKC4v1D7X+Bc0=
cloud.google.com/go/geminidataanalytics v0.3.0 h1:2Wi/kqFb5OLuEGH7q+/miE19VTqK1MYHjBEHENap9HI=
cloud.google.com/go/geminidataanalytics v0.3.0/go.mod h1:QRc0b6ywyc3Z7S3etFgslz7hippkW/jRvtops5rKqIg=
cloud.google.com/go/geminidataanalytics v0.5.0 h1:+1usY81Cb+hE8BokpqCM7EgJtRCKzUKx7FvrHbT5hCA=
cloud.google.com/go/geminidataanalytics v0.5.0/go.mod h1:QRc0b6ywyc3Z7S3etFgslz7hippkW/jRvtops5rKqIg=
cloud.google.com/go/gkebackup v0.2.0/go.mod h1:XKvv/4LfG829/B8B7xRkk8zRrOEbKtEam6yNfuQNH60=
cloud.google.com/go/gkebackup v0.3.0/go.mod h1:n/E671i1aOQvUxT541aTkCwExO/bTer2HDlj4TsBRAo=
cloud.google.com/go/gkebackup v0.4.0/go.mod h1:byAyBGUwYGEEww7xsbnUTBHIYcOPy/PgUWUtOeRm9Vg=

View File

@@ -14,23 +14,20 @@
package cloudgda
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
geminidataanalytics "cloud.google.com/go/geminidataanalytics/apiv1beta"
"cloud.google.com/go/geminidataanalytics/apiv1beta/geminidataanalyticspb"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/util"
"go.opentelemetry.io/otel/trace"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
"google.golang.org/api/option"
)
const SourceKind string = "cloud-gemini-data-analytics"
const Endpoint string = "https://geminidataanalytics.googleapis.com"
// validate interface
var _ sources.SourceConfig = Config{}
@@ -67,29 +64,19 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
return nil, fmt.Errorf("error in User Agent retrieval: %s", err)
}
var client *http.Client
if r.UseClientOAuth {
client = &http.Client{
Transport: util.NewUserAgentRoundTripper(ua, http.DefaultTransport),
}
} else {
// Use Application Default Credentials
// Scope: "https://www.googleapis.com/auth/cloud-platform" is generally sufficient for GDA
creds, err := google.FindDefaultCredentials(ctx, "https://www.googleapis.com/auth/cloud-platform")
if err != nil {
return nil, fmt.Errorf("failed to find default credentials: %w", err)
}
baseClient := oauth2.NewClient(ctx, creds.TokenSource)
baseClient.Transport = util.NewUserAgentRoundTripper(ua, baseClient.Transport)
client = baseClient
}
s := &Source{
Config: r,
Client: client,
BaseURL: Endpoint,
userAgent: ua,
}
if !r.UseClientOAuth {
client, err := geminidataanalytics.NewDataChatClient(ctx, option.WithUserAgent(ua))
if err != nil {
return nil, fmt.Errorf("failed to create DataChatClient: %w", err)
}
s.Client = client
}
return s, nil
}
@@ -97,8 +84,7 @@ var _ sources.Source = &Source{}
type Source struct {
Config
Client *http.Client
BaseURL string
Client *geminidataanalytics.DataChatClient
userAgent string
}
@@ -114,63 +100,34 @@ func (s *Source) GetProjectID() string {
return s.ProjectID
}
func (s *Source) GetBaseURL() string {
return s.BaseURL
}
func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Client, error) {
if s.UseClientOAuth {
if accessToken == "" {
return nil, fmt.Errorf("client-side OAuth is enabled but no access token was provided")
}
token := &oauth2.Token{AccessToken: accessToken}
baseClient := oauth2.NewClient(ctx, oauth2.StaticTokenSource(token))
baseClient.Transport = util.NewUserAgentRoundTripper(s.userAgent, baseClient.Transport)
return baseClient, nil
}
return s.Client, nil
}
func (s *Source) UseClientAuthorization() bool {
return s.UseClientOAuth
}
func (s *Source) RunQuery(ctx context.Context, tokenStr string, bodyBytes []byte) (any, error) {
// The API endpoint itself always uses the "global" location.
apiLocation := "global"
apiParent := fmt.Sprintf("projects/%s/locations/%s", s.GetProjectID(), apiLocation)
apiURL := fmt.Sprintf("%s/v1beta/%s:queryData", s.GetBaseURL(), apiParent)
client, err := s.GetClient(ctx, tokenStr)
func (s *Source) RunQuery(ctx context.Context, tokenStr string, req *geminidataanalyticspb.QueryDataRequest) (*geminidataanalyticspb.QueryDataResponse, error) {
client, cleanup, err := s.GetClient(ctx, tokenStr)
if err != nil {
return nil, fmt.Errorf("failed to get HTTP client: %w", err)
return nil, err
}
defer cleanup()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(bodyBytes))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to execute request: %w", err)
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(respBody))
}
var result map[string]any
if err := json.Unmarshal(respBody, &result); err != nil {
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
}
return result, nil
return client.QueryData(ctx, req)
}
func (s *Source) GetClient(ctx context.Context, tokenStr string) (*geminidataanalytics.DataChatClient, func(), error) {
if s.UseClientOAuth {
if tokenStr == "" {
return nil, nil, fmt.Errorf("client-side OAuth is enabled but no access token was provided")
}
token := &oauth2.Token{AccessToken: tokenStr}
client, err := geminidataanalytics.NewDataChatClient(ctx,
option.WithUserAgent(s.userAgent),
option.WithTokenSource(oauth2.StaticTokenSource(token)),
)
if err != nil {
return nil, nil, fmt.Errorf("failed to create per-request DataChatClient: %w", err)
}
return client, func() { client.Close() }, nil
}
return s.Client, func() {}, nil
}

View File

@@ -181,11 +181,9 @@ func TestInitialize(t *testing.T) {
if gdaSrc.Client == nil && !tc.wantClientOAuth {
t.Fatal("expected non-nil HTTP client for ADC, got nil")
}
// When client OAuth is true, the source's client should be initialized with a base HTTP client
// that includes the user agent round tripper, but not the OAuth token. The token-aware
// client is created by GetClient.
if gdaSrc.Client == nil && tc.wantClientOAuth {
t.Fatal("expected non-nil HTTP client for client OAuth config, got nil")
// When client OAuth is true, the source's client should be nil.
if gdaSrc.Client != nil && tc.wantClientOAuth {
t.Fatal("expected nil HTTP client for client OAuth config, got non-nil")
}
// Test UseClientAuthorization method
@@ -195,15 +193,16 @@ func TestInitialize(t *testing.T) {
// Test GetClient with accessToken for client OAuth scenarios
if tc.wantClientOAuth {
client, err := gdaSrc.GetClient(ctx, "dummy-token")
client, cleanup, err := gdaSrc.GetClient(ctx, "dummy-token")
if err != nil {
t.Fatalf("GetClient with token failed: %v", err)
}
defer cleanup()
if client == nil {
t.Fatal("expected non-nil HTTP client from GetClient with token, got nil")
}
// Ensure passing empty token with UseClientOAuth enabled returns error
_, err = gdaSrc.GetClient(ctx, "")
_, _, err = gdaSrc.GetClient(ctx, "")
if err == nil || err.Error() != "client-side OAuth is enabled but no access token was provided" {
t.Errorf("expected 'client-side OAuth is enabled but no access token was provided' error, got: %v", err)
}

View File

@@ -19,11 +19,13 @@ import (
"encoding/json"
"fmt"
"cloud.google.com/go/geminidataanalytics/apiv1beta/geminidataanalyticspb"
"github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"google.golang.org/protobuf/encoding/protojson"
)
const kind string = "cloud-gemini-data-analytics-query"
@@ -60,7 +62,49 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
GetProjectID() string
UseClientAuthorization() bool
RunQuery(context.Context, string, []byte) (any, error)
RunQuery(context.Context, string, *geminidataanalyticspb.QueryDataRequest) (*geminidataanalyticspb.QueryDataResponse, error)
}
// QueryDataContext wraps geminidataanalyticspb.QueryDataContext to support YAML decoding via protojson.
type QueryDataContext struct {
*geminidataanalyticspb.QueryDataContext
}
func (q *QueryDataContext) UnmarshalYAML(b []byte) error {
var raw map[string]any
if err := yaml.Unmarshal(b, &raw); err != nil {
return err
}
jsonBytes, err := json.Marshal(raw)
if err != nil {
return fmt.Errorf("failed to marshal context map: %w", err)
}
q.QueryDataContext = &geminidataanalyticspb.QueryDataContext{}
if err := protojson.Unmarshal(jsonBytes, q.QueryDataContext); err != nil {
return fmt.Errorf("failed to unmarshal context to proto: %w", err)
}
return nil
}
// GenerationOptions wraps geminidataanalyticspb.GenerationOptions to support YAML decoding via protojson.
type GenerationOptions struct {
*geminidataanalyticspb.GenerationOptions
}
func (g *GenerationOptions) UnmarshalYAML(b []byte) error {
var raw map[string]any
if err := yaml.Unmarshal(b, &raw); err != nil {
return err
}
jsonBytes, err := json.Marshal(raw)
if err != nil {
return fmt.Errorf("failed to marshal generation options map: %w", err)
}
g.GenerationOptions = &geminidataanalyticspb.GenerationOptions{}
if err := protojson.Unmarshal(jsonBytes, g.GenerationOptions); err != nil {
return fmt.Errorf("failed to unmarshal generation options to proto: %w", err)
}
return nil
}
type Config struct {
@@ -97,12 +141,14 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
}
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil)
return Tool{
t := Tool{
Config: cfg,
AllParams: allParameters,
manifest: tools.Manifest{Description: cfg.Description, Parameters: allParameters.Manifest(), AuthRequired: cfg.AuthRequired},
mcpManifest: mcpManifest,
}, nil
}
return t, nil
}
// validate interface
@@ -145,18 +191,20 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
// The parent in the request payload uses the tool's configured location.
payloadParent := fmt.Sprintf("projects/%s/locations/%s", source.GetProjectID(), t.Location)
payload := &QueryDataRequest{
Parent: payloadParent,
Prompt: query,
Context: t.Context,
GenerationOptions: t.GenerationOptions,
req := &geminidataanalyticspb.QueryDataRequest{
Parent: payloadParent,
Prompt: query,
}
bodyBytes, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("failed to marshal request payload: %w", err)
if t.Context != nil {
req.Context = t.Context.QueryDataContext
}
return source.RunQuery(ctx, tokenStr, bodyBytes)
if t.GenerationOptions != nil {
req.GenerationOptions = t.GenerationOptions.GenerationOptions
}
return source.RunQuery(ctx, tokenStr, req)
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {

View File

@@ -16,19 +16,16 @@ package cloudgda_test
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"testing"
"cloud.google.com/go/geminidataanalytics/apiv1beta/geminidataanalyticspb"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/server/resources"
"github.com/googleapis/genai-toolbox/internal/sources"
cloudgdasrc "github.com/googleapis/genai-toolbox/internal/sources/cloudgda"
"github.com/googleapis/genai-toolbox/internal/testutils"
"github.com/googleapis/genai-toolbox/internal/tools"
cloudgdatool "github.com/googleapis/genai-toolbox/internal/tools/cloudgda"
@@ -74,23 +71,29 @@ func TestParseFromYaml(t *testing.T) {
Location: "us-central1",
AuthRequired: []string{},
Context: &cloudgdatool.QueryDataContext{
DatasourceReferences: &cloudgdatool.DatasourceReferences{
SpannerReference: &cloudgdatool.SpannerReference{
DatabaseReference: &cloudgdatool.SpannerDatabaseReference{
ProjectID: "cloud-db-nl2sql",
Region: "us-central1",
InstanceID: "evalbench",
DatabaseID: "financial",
Engine: cloudgdatool.SpannerEngineGoogleSQL,
},
AgentContextReference: &cloudgdatool.AgentContextReference{
ContextSetID: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates",
QueryDataContext: &geminidataanalyticspb.QueryDataContext{
DatasourceReferences: &geminidataanalyticspb.DatasourceReferences{
References: &geminidataanalyticspb.DatasourceReferences_SpannerReference{
SpannerReference: &geminidataanalyticspb.SpannerReference{
DatabaseReference: &geminidataanalyticspb.SpannerDatabaseReference{
ProjectId: "cloud-db-nl2sql",
Region: "us-central1",
InstanceId: "evalbench",
DatabaseId: "financial",
Engine: geminidataanalyticspb.SpannerDatabaseReference_GOOGLE_SQL,
},
AgentContextReference: &geminidataanalyticspb.AgentContextReference{
ContextSetId: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates",
},
},
},
},
},
},
GenerationOptions: &cloudgdatool.GenerationOptions{
GenerateQueryResult: true,
GenerationOptions: &geminidataanalyticspb.GenerationOptions{
GenerateQueryResult: true,
},
},
},
},
@@ -108,68 +111,63 @@ func TestParseFromYaml(t *testing.T) {
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got.Tools) {
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Tools)
if !cmp.Equal(tc.want, got.Tools, cmpopts.IgnoreUnexported(geminidataanalyticspb.QueryDataContext{}, geminidataanalyticspb.DatasourceReferences{}, geminidataanalyticspb.SpannerReference{}, geminidataanalyticspb.SpannerDatabaseReference{}, geminidataanalyticspb.AgentContextReference{}, geminidataanalyticspb.GenerationOptions{}, geminidataanalyticspb.DatasourceReferences_SpannerReference{})) {
t.Errorf("incorrect parse: want %v, got %v", tc.want, got.Tools)
}
})
}
}
// authRoundTripper is a mock http.RoundTripper that adds a dummy Authorization header.
type authRoundTripper struct {
Token string
Next http.RoundTripper
// fakeSource implements the compatibleSource interface for testing.
type fakeSource struct {
projectID string
useClientOAuth bool
expectedQuery string
expectedParent string
response *geminidataanalyticspb.QueryDataResponse
}
func (rt *authRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
newReq := *req
newReq.Header = make(http.Header)
for k, v := range req.Header {
newReq.Header[k] = v
}
newReq.Header.Set("Authorization", rt.Token)
if rt.Next == nil {
return http.DefaultTransport.RoundTrip(&newReq)
}
return rt.Next.RoundTrip(&newReq)
func (f *fakeSource) GetProjectID() string {
return f.projectID
}
type mockSource struct {
kind string
client *http.Client // Can be used to inject a specific client
baseURL string // BaseURL is needed to implement sources.Source.BaseURL
config cloudgdasrc.Config // to return from ToConfig
func (f *fakeSource) UseClientAuthorization() bool {
return f.useClientOAuth
}
func (m *mockSource) SourceKind() string { return m.kind }
func (m *mockSource) ToConfig() sources.SourceConfig { return m.config }
func (m *mockSource) GetClient(ctx context.Context, token string) (*http.Client, error) {
if m.client != nil {
return m.client, nil
}
// Default client for testing if not explicitly set
transport := &http.Transport{}
authTransport := &authRoundTripper{
Token: "Bearer test-access-token", // Dummy token
Next: transport,
}
return &http.Client{Transport: authTransport}, nil
func (f *fakeSource) SourceKind() string {
return "fake-gda-source"
}
func (m *mockSource) UseClientAuthorization() bool { return false }
func (m *mockSource) Initialize(ctx context.Context, tracer interface{}) (sources.Source, error) {
return m, nil
func (f *fakeSource) ToConfig() sources.SourceConfig {
return nil
}
func (f *fakeSource) Initialize(ctx context.Context, tracer interface{}) (sources.Source, error) {
return f, nil
}
func (f *fakeSource) RunQuery(ctx context.Context, token string, req *geminidataanalyticspb.QueryDataRequest) (*geminidataanalyticspb.QueryDataResponse, error) {
if req.Prompt != f.expectedQuery {
return nil, fmt.Errorf("unexpected query: got %q, want %q", req.Prompt, f.expectedQuery)
}
if req.Parent != f.expectedParent {
return nil, fmt.Errorf("unexpected parent: got %q, want %q", req.Parent, f.expectedParent)
}
// Basic validation of context/options could be added here if needed,
// but the test case mainly checks if they are passed correctly via successful invocation.
return f.response, nil
}
func (m *mockSource) BaseURL() string { return m.baseURL }
func TestInitialize(t *testing.T) {
t.Parallel()
// Minimal fake source
fake := &fakeSource{projectID: "test-project"}
srcs := map[string]sources.Source{
"gda-api-source": &cloudgdasrc.Source{
Config: cloudgdasrc.Config{Name: "gda-api-source", Kind: cloudgdasrc.SourceKind, ProjectID: "test-project"},
Client: &http.Client{},
BaseURL: cloudgdasrc.Endpoint,
},
"gda-api-source": fake,
}
tcs := []struct {
@@ -188,9 +186,6 @@ func TestInitialize(t *testing.T) {
},
}
// Add an incompatible source for testing
srcs["incompatible-source"] = &mockSource{kind: "another-kind"}
for _, tc := range tcs {
tc := tc
t.Run(tc.desc, func(t *testing.T) {
@@ -207,92 +202,27 @@ func TestInitialize(t *testing.T) {
func TestInvoke(t *testing.T) {
t.Parallel()
// Mock the HTTP client and server for Invoke testing
serverMux := http.NewServeMux()
// Update expected URL path to include the location "us-central1"
serverMux.HandleFunc("/v1beta/projects/test-project/locations/global:queryData", func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("expected POST method, got %s", r.Method)
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
if r.Header.Get("Content-Type") != "application/json" {
t.Errorf("expected Content-Type application/json, got %s", r.Header.Get("Content-Type"))
http.Error(w, "Bad request", http.StatusBadRequest)
return
}
// Read and unmarshal the request body
bodyBytes, err := io.ReadAll(r.Body)
if err != nil {
t.Errorf("failed to read request body: %v", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
var reqPayload cloudgdatool.QueryDataRequest
if err := json.Unmarshal(bodyBytes, &reqPayload); err != nil {
t.Errorf("failed to unmarshal request payload: %v", err)
http.Error(w, "Bad request", http.StatusBadRequest)
return
}
projectID := "test-project"
location := "us-central1"
query := "How many accounts who have region in Prague are eligible for loans?"
expectedParent := fmt.Sprintf("projects/%s/locations/%s", projectID, location)
// Verify expected fields
if r.Header.Get("Authorization") == "" {
t.Errorf("expected Authorization header, got empty")
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
if reqPayload.Prompt != "How many accounts who have region in Prague are eligible for loans?" {
t.Errorf("unexpected prompt: %s", reqPayload.Prompt)
}
// Verify payload's parent uses the tool's configured location
if reqPayload.Parent != fmt.Sprintf("projects/%s/locations/%s", "test-project", "us-central1") {
t.Errorf("unexpected payload parent: got %q, want %q", reqPayload.Parent, fmt.Sprintf("projects/%s/locations/%s", "test-project", "us-central1"))
}
// Verify context from config
if reqPayload.Context == nil ||
reqPayload.Context.DatasourceReferences == nil ||
reqPayload.Context.DatasourceReferences.SpannerReference == nil ||
reqPayload.Context.DatasourceReferences.SpannerReference.DatabaseReference == nil ||
reqPayload.Context.DatasourceReferences.SpannerReference.DatabaseReference.ProjectID != "cloud-db-nl2sql" {
t.Errorf("unexpected context: %v", reqPayload.Context)
}
// Verify generation options from config
if reqPayload.GenerationOptions == nil || !reqPayload.GenerationOptions.GenerateQueryResult {
t.Errorf("unexpected generation options: %v", reqPayload.GenerationOptions)
}
// Simulate a successful response
resp := map[string]any{
"queryResult": "SELECT count(*) FROM accounts WHERE region = 'Prague' AND eligible_for_loans = true;",
"naturalLanguageAnswer": "There are 5 accounts in Prague eligible for loans.",
}
_ = json.NewEncoder(w).Encode(resp)
})
mockServer := httptest.NewServer(serverMux)
defer mockServer.Close()
ctx := testutils.ContextWithUserAgent(context.Background(), "test-user-agent")
// Create an authenticated client that uses the mock server
authTransport := &authRoundTripper{
Token: "Bearer test-access-token",
Next: mockServer.Client().Transport,
// Prepare expected response
expectedResp := &geminidataanalyticspb.QueryDataResponse{
GeneratedQuery: "SELECT count(*) FROM accounts WHERE region = 'Prague' AND eligible_for_loans = true;",
NaturalLanguageAnswer: "There are 5 accounts in Prague eligible for loans.",
}
authClient := &http.Client{Transport: authTransport}
// Create a real cloudgdasrc.Source but inject the authenticated client
mockGdaSource := &cloudgdasrc.Source{
Config: cloudgdasrc.Config{Name: "mock-gda-source", Kind: cloudgdasrc.SourceKind, ProjectID: "test-project"},
Client: authClient,
BaseURL: mockServer.URL,
fake := &fakeSource{
projectID: projectID,
expectedQuery: query,
expectedParent: expectedParent,
response: expectedResp,
}
srcs := map[string]sources.Source{
"mock-gda-source": mockGdaSource,
"mock-gda-source": fake,
}
// Initialize the tool config with context
@@ -301,25 +231,31 @@ func TestInvoke(t *testing.T) {
Kind: "cloud-gemini-data-analytics-query",
Source: "mock-gda-source",
Description: "Query Gemini Data Analytics",
Location: "us-central1", // Set location for the test
Location: location,
Context: &cloudgdatool.QueryDataContext{
DatasourceReferences: &cloudgdatool.DatasourceReferences{
SpannerReference: &cloudgdatool.SpannerReference{
DatabaseReference: &cloudgdatool.SpannerDatabaseReference{
ProjectID: "cloud-db-nl2sql",
Region: "us-central1",
InstanceID: "evalbench",
DatabaseID: "financial",
Engine: cloudgdatool.SpannerEngineGoogleSQL,
},
AgentContextReference: &cloudgdatool.AgentContextReference{
ContextSetID: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates",
QueryDataContext: &geminidataanalyticspb.QueryDataContext{
DatasourceReferences: &geminidataanalyticspb.DatasourceReferences{
References: &geminidataanalyticspb.DatasourceReferences_SpannerReference{
SpannerReference: &geminidataanalyticspb.SpannerReference{
DatabaseReference: &geminidataanalyticspb.SpannerDatabaseReference{
ProjectId: "cloud-db-nl2sql",
Region: "us-central1",
InstanceId: "evalbench",
DatabaseId: "financial",
Engine: geminidataanalyticspb.SpannerDatabaseReference_GOOGLE_SQL,
},
AgentContextReference: &geminidataanalyticspb.AgentContextReference{
ContextSetId: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates",
},
},
},
},
},
},
GenerationOptions: &cloudgdatool.GenerationOptions{
GenerateQueryResult: true,
GenerationOptions: &geminidataanalyticspb.GenerationOptions{
GenerateQueryResult: true,
},
},
}
@@ -330,24 +266,25 @@ func TestInvoke(t *testing.T) {
// Prepare parameters for invocation - ONLY query
params := parameters.ParamValues{
{Name: "query", Value: "How many accounts who have region in Prague are eligible for loans?"},
{Name: "query", Value: query},
}
resourceMgr := resources.NewResourceManager(srcs, nil, nil, nil, nil, nil, nil)
ctx := testutils.ContextWithUserAgent(context.Background(), "test-user-agent")
// Invoke the tool
result, err := tool.Invoke(ctx, resourceMgr, params, "") // No accessToken needed for ADC client
result, err := tool.Invoke(ctx, resourceMgr, params, "")
if err != nil {
t.Fatalf("tool invocation failed: %v", err)
}
// Validate the result
expectedResult := map[string]any{
"queryResult": "SELECT count(*) FROM accounts WHERE region = 'Prague' AND eligible_for_loans = true;",
"naturalLanguageAnswer": "There are 5 accounts in Prague eligible for loans.",
gotResp, ok := result.(*geminidataanalyticspb.QueryDataResponse)
if !ok {
t.Fatalf("expected result type *geminidataanalyticspb.QueryDataResponse, got %T", result)
}
if !cmp.Equal(expectedResult, result) {
t.Errorf("unexpected result: got %v, want %v", result, expectedResult)
if diff := cmp.Diff(expectedResp, gotResp, cmpopts.IgnoreUnexported(geminidataanalyticspb.QueryDataResponse{})); diff != "" {
t.Errorf("unexpected result mismatch (-want +got):\n%s", diff)
}
}

View File

@@ -1,116 +0,0 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cloudgda
// See full service definition at: https://github.com/googleapis/googleapis/blob/master/google/cloud/geminidataanalytics/v1beta/data_chat_service.proto
// QueryDataRequest represents the JSON body for the queryData API
type QueryDataRequest struct {
Parent string `json:"parent"`
Prompt string `json:"prompt"`
Context *QueryDataContext `json:"context,omitempty"`
GenerationOptions *GenerationOptions `json:"generationOptions,omitempty"`
}
// QueryDataContext reflects the proto definition for the query context.
type QueryDataContext struct {
DatasourceReferences *DatasourceReferences `json:"datasourceReferences,omitempty" yaml:"datasourceReferences,omitempty"`
}
// DatasourceReferences reflects the proto definition for datasource references, using a oneof.
type DatasourceReferences struct {
SpannerReference *SpannerReference `json:"spannerReference,omitempty" yaml:"spannerReference,omitempty"`
AlloyDBReference *AlloyDBReference `json:"alloydb,omitempty" yaml:"alloydb,omitempty"`
CloudSQLReference *CloudSQLReference `json:"cloudSqlReference,omitempty" yaml:"cloudSqlReference,omitempty"`
}
// SpannerReference reflects the proto definition for Spanner database reference.
type SpannerReference struct {
DatabaseReference *SpannerDatabaseReference `json:"databaseReference,omitempty" yaml:"databaseReference,omitempty"`
AgentContextReference *AgentContextReference `json:"agentContextReference,omitempty" yaml:"agentContextReference,omitempty"`
}
// SpannerDatabaseReference reflects the proto definition for a Spanner database reference.
type SpannerDatabaseReference struct {
Engine SpannerEngine `json:"engine,omitempty" yaml:"engine,omitempty"`
ProjectID string `json:"projectId,omitempty" yaml:"projectId,omitempty"`
Region string `json:"region,omitempty" yaml:"region,omitempty"`
InstanceID string `json:"instanceId,omitempty" yaml:"instanceId,omitempty"`
DatabaseID string `json:"databaseId,omitempty" yaml:"databaseId,omitempty"`
TableIDs []string `json:"tableIds,omitempty" yaml:"tableIds,omitempty"`
}
// SpannerEngine represents the engine of the Spanner instance.
type SpannerEngine string
const (
SpannerEngineUnspecified SpannerEngine = "ENGINE_UNSPECIFIED"
SpannerEngineGoogleSQL SpannerEngine = "GOOGLE_SQL"
SpannerEnginePostgreSQL SpannerEngine = "POSTGRESQL"
)
// AlloyDBReference reflects the proto definition for an AlloyDB database reference.
type AlloyDBReference struct {
DatabaseReference *AlloyDBDatabaseReference `json:"databaseReference,omitempty" yaml:"databaseReference,omitempty"`
AgentContextReference *AgentContextReference `json:"agentContextReference,omitempty" yaml:"agentContextReference,omitempty"`
}
// AlloyDBDatabaseReference reflects the proto definition for an AlloyDB database reference.
type AlloyDBDatabaseReference struct {
ProjectID string `json:"projectId,omitempty" yaml:"projectId,omitempty"`
Region string `json:"region,omitempty" yaml:"region,omitempty"`
ClusterID string `json:"clusterId,omitempty" yaml:"clusterId,omitempty"`
InstanceID string `json:"instanceId,omitempty" yaml:"instanceId,omitempty"`
DatabaseID string `json:"databaseId,omitempty" yaml:"databaseId,omitempty"`
TableIDs []string `json:"tableIds,omitempty" yaml:"tableIds,omitempty"`
}
// CloudSQLReference reflects the proto definition for a Cloud SQL database reference.
type CloudSQLReference struct {
DatabaseReference *CloudSQLDatabaseReference `json:"databaseReference,omitempty" yaml:"databaseReference,omitempty"`
AgentContextReference *AgentContextReference `json:"agentContextReference,omitempty" yaml:"agentContextReference,omitempty"`
}
// CloudSQLDatabaseReference reflects the proto definition for a Cloud SQL database reference.
type CloudSQLDatabaseReference struct {
Engine CloudSQLEngine `json:"engine,omitempty" yaml:"engine,omitempty"`
ProjectID string `json:"projectId,omitempty" yaml:"projectId,omitempty"`
Region string `json:"region,omitempty" yaml:"region,omitempty"`
InstanceID string `json:"instanceId,omitempty" yaml:"instanceId,omitempty"`
DatabaseID string `json:"databaseId,omitempty" yaml:"databaseId,omitempty"`
TableIDs []string `json:"tableIds,omitempty" yaml:"tableIds,omitempty"`
}
// CloudSQLEngine represents the engine of the Cloud SQL instance.
type CloudSQLEngine string
const (
CloudSQLEngineUnspecified CloudSQLEngine = "ENGINE_UNSPECIFIED"
CloudSQLEnginePostgreSQL CloudSQLEngine = "POSTGRESQL"
CloudSQLEngineMySQL CloudSQLEngine = "MYSQL"
)
// AgentContextReference reflects the proto definition for agent context.
type AgentContextReference struct {
ContextSetID string `json:"contextSetId,omitempty" yaml:"contextSetId,omitempty"`
}
// GenerationOptions reflects the proto definition for generation options.
type GenerationOptions struct {
GenerateQueryResult bool `json:"generateQueryResult" yaml:"generateQueryResult"`
GenerateNaturalLanguageAnswer bool `json:"generateNaturalLanguageAnswer" yaml:"generateNaturalLanguageAnswer"`
GenerateExplanation bool `json:"generateExplanation" yaml:"generateExplanation"`
GenerateDisambiguationQuestion bool `json:"generateDisambiguationQuestion" yaml:"generateDisambiguationQuestion"`
}