mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-28 08:48:09 -05:00
Compare commits
7 Commits
remove-par
...
refactor/c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e716efea6e | ||
|
|
dcfd056a30 | ||
|
|
59aa21729d | ||
|
|
e25ee6f165 | ||
|
|
9bb4eee494 | ||
|
|
9f5b04cf73 | ||
|
|
66d6b58c4f |
4
go.mod
4
go.mod
@@ -12,7 +12,7 @@ require (
|
|||||||
cloud.google.com/go/dataplex v1.28.0
|
cloud.google.com/go/dataplex v1.28.0
|
||||||
cloud.google.com/go/dataproc/v2 v2.15.0
|
cloud.google.com/go/dataproc/v2 v2.15.0
|
||||||
cloud.google.com/go/firestore v1.20.0
|
cloud.google.com/go/firestore v1.20.0
|
||||||
cloud.google.com/go/geminidataanalytics v0.3.0
|
cloud.google.com/go/geminidataanalytics v0.5.0
|
||||||
cloud.google.com/go/longrunning v0.7.0
|
cloud.google.com/go/longrunning v0.7.0
|
||||||
cloud.google.com/go/spanner v1.86.1
|
cloud.google.com/go/spanner v1.86.1
|
||||||
github.com/ClickHouse/clickhouse-go/v2 v2.40.3
|
github.com/ClickHouse/clickhouse-go/v2 v2.40.3
|
||||||
@@ -63,6 +63,7 @@ require (
|
|||||||
google.golang.org/api v0.256.0
|
google.golang.org/api v0.256.0
|
||||||
google.golang.org/genai v1.37.0
|
google.golang.org/genai v1.37.0
|
||||||
google.golang.org/genproto v0.0.0-20251022142026-3a174f9686a8
|
google.golang.org/genproto v0.0.0-20251022142026-3a174f9686a8
|
||||||
|
google.golang.org/grpc v1.76.0
|
||||||
google.golang.org/protobuf v1.36.10
|
google.golang.org/protobuf v1.36.10
|
||||||
modernc.org/sqlite v1.40.0
|
modernc.org/sqlite v1.40.0
|
||||||
)
|
)
|
||||||
@@ -229,7 +230,6 @@ require (
|
|||||||
golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect
|
golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect
|
||||||
google.golang.org/genproto/googleapis/api v0.0.0-20251111163417-95abcf5c77ba // indirect
|
google.golang.org/genproto/googleapis/api v0.0.0-20251111163417-95abcf5c77ba // indirect
|
||||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20251103181224-f26f9409b101 // indirect
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20251103181224-f26f9409b101 // indirect
|
||||||
google.golang.org/grpc v1.76.0 // indirect
|
|
||||||
gopkg.in/inf.v0 v0.9.1 // indirect
|
gopkg.in/inf.v0 v0.9.1 // indirect
|
||||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||||
modernc.org/libc v1.66.10 // indirect
|
modernc.org/libc v1.66.10 // indirect
|
||||||
|
|||||||
4
go.sum
4
go.sum
@@ -311,8 +311,8 @@ cloud.google.com/go/gaming v1.6.0/go.mod h1:YMU1GEvA39Qt3zWGyAVA9bpYz/yAhTvaQ1t2
|
|||||||
cloud.google.com/go/gaming v1.7.0/go.mod h1:LrB8U7MHdGgFG851iHAfqUdLcKBdQ55hzXy9xBJz0+w=
|
cloud.google.com/go/gaming v1.7.0/go.mod h1:LrB8U7MHdGgFG851iHAfqUdLcKBdQ55hzXy9xBJz0+w=
|
||||||
cloud.google.com/go/gaming v1.8.0/go.mod h1:xAqjS8b7jAVW0KFYeRUxngo9My3f33kFmua++Pi+ggM=
|
cloud.google.com/go/gaming v1.8.0/go.mod h1:xAqjS8b7jAVW0KFYeRUxngo9My3f33kFmua++Pi+ggM=
|
||||||
cloud.google.com/go/gaming v1.9.0/go.mod h1:Fc7kEmCObylSWLO334NcO+O9QMDyz+TKC4v1D7X+Bc0=
|
cloud.google.com/go/gaming v1.9.0/go.mod h1:Fc7kEmCObylSWLO334NcO+O9QMDyz+TKC4v1D7X+Bc0=
|
||||||
cloud.google.com/go/geminidataanalytics v0.3.0 h1:2Wi/kqFb5OLuEGH7q+/miE19VTqK1MYHjBEHENap9HI=
|
cloud.google.com/go/geminidataanalytics v0.5.0 h1:+1usY81Cb+hE8BokpqCM7EgJtRCKzUKx7FvrHbT5hCA=
|
||||||
cloud.google.com/go/geminidataanalytics v0.3.0/go.mod h1:QRc0b6ywyc3Z7S3etFgslz7hippkW/jRvtops5rKqIg=
|
cloud.google.com/go/geminidataanalytics v0.5.0/go.mod h1:QRc0b6ywyc3Z7S3etFgslz7hippkW/jRvtops5rKqIg=
|
||||||
cloud.google.com/go/gkebackup v0.2.0/go.mod h1:XKvv/4LfG829/B8B7xRkk8zRrOEbKtEam6yNfuQNH60=
|
cloud.google.com/go/gkebackup v0.2.0/go.mod h1:XKvv/4LfG829/B8B7xRkk8zRrOEbKtEam6yNfuQNH60=
|
||||||
cloud.google.com/go/gkebackup v0.3.0/go.mod h1:n/E671i1aOQvUxT541aTkCwExO/bTer2HDlj4TsBRAo=
|
cloud.google.com/go/gkebackup v0.3.0/go.mod h1:n/E671i1aOQvUxT541aTkCwExO/bTer2HDlj4TsBRAo=
|
||||||
cloud.google.com/go/gkebackup v0.4.0/go.mod h1:byAyBGUwYGEEww7xsbnUTBHIYcOPy/PgUWUtOeRm9Vg=
|
cloud.google.com/go/gkebackup v0.4.0/go.mod h1:byAyBGUwYGEEww7xsbnUTBHIYcOPy/PgUWUtOeRm9Vg=
|
||||||
|
|||||||
@@ -14,23 +14,23 @@
|
|||||||
package cloudgda
|
package cloudgda
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
|
geminidataanalytics "cloud.google.com/go/geminidataanalytics/apiv1beta"
|
||||||
|
"cloud.google.com/go/geminidataanalytics/apiv1beta/geminidataanalyticspb"
|
||||||
"github.com/goccy/go-yaml"
|
"github.com/goccy/go-yaml"
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util"
|
"github.com/googleapis/genai-toolbox/internal/util"
|
||||||
"go.opentelemetry.io/otel/trace"
|
"go.opentelemetry.io/otel/trace"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"golang.org/x/oauth2/google"
|
"google.golang.org/api/option"
|
||||||
)
|
)
|
||||||
|
|
||||||
const SourceKind string = "cloud-gemini-data-analytics"
|
const SourceKind string = "cloud-gemini-data-analytics"
|
||||||
const Endpoint string = "https://geminidataanalytics.googleapis.com"
|
|
||||||
|
// NewDataChatClient can be overridden for testing.
|
||||||
|
var NewDataChatClient = geminidataanalytics.NewDataChatClient
|
||||||
|
|
||||||
// validate interface
|
// validate interface
|
||||||
var _ sources.SourceConfig = Config{}
|
var _ sources.SourceConfig = Config{}
|
||||||
@@ -67,29 +67,19 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
|
|||||||
return nil, fmt.Errorf("error in User Agent retrieval: %s", err)
|
return nil, fmt.Errorf("error in User Agent retrieval: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var client *http.Client
|
|
||||||
if r.UseClientOAuth {
|
|
||||||
client = &http.Client{
|
|
||||||
Transport: util.NewUserAgentRoundTripper(ua, http.DefaultTransport),
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Use Application Default Credentials
|
|
||||||
// Scope: "https://www.googleapis.com/auth/cloud-platform" is generally sufficient for GDA
|
|
||||||
creds, err := google.FindDefaultCredentials(ctx, "https://www.googleapis.com/auth/cloud-platform")
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to find default credentials: %w", err)
|
|
||||||
}
|
|
||||||
baseClient := oauth2.NewClient(ctx, creds.TokenSource)
|
|
||||||
baseClient.Transport = util.NewUserAgentRoundTripper(ua, baseClient.Transport)
|
|
||||||
client = baseClient
|
|
||||||
}
|
|
||||||
|
|
||||||
s := &Source{
|
s := &Source{
|
||||||
Config: r,
|
Config: r,
|
||||||
Client: client,
|
|
||||||
BaseURL: Endpoint,
|
|
||||||
userAgent: ua,
|
userAgent: ua,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !r.UseClientOAuth {
|
||||||
|
client, err := NewDataChatClient(ctx, option.WithUserAgent(ua))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create DataChatClient: %w", err)
|
||||||
|
}
|
||||||
|
s.Client = client
|
||||||
|
}
|
||||||
|
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -97,8 +87,7 @@ var _ sources.Source = &Source{}
|
|||||||
|
|
||||||
type Source struct {
|
type Source struct {
|
||||||
Config
|
Config
|
||||||
Client *http.Client
|
Client *geminidataanalytics.DataChatClient
|
||||||
BaseURL string
|
|
||||||
userAgent string
|
userAgent string
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -114,63 +103,36 @@ func (s *Source) GetProjectID() string {
|
|||||||
return s.ProjectID
|
return s.ProjectID
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Source) GetBaseURL() string {
|
|
||||||
return s.BaseURL
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Client, error) {
|
|
||||||
if s.UseClientOAuth {
|
|
||||||
if accessToken == "" {
|
|
||||||
return nil, fmt.Errorf("client-side OAuth is enabled but no access token was provided")
|
|
||||||
}
|
|
||||||
token := &oauth2.Token{AccessToken: accessToken}
|
|
||||||
baseClient := oauth2.NewClient(ctx, oauth2.StaticTokenSource(token))
|
|
||||||
baseClient.Transport = util.NewUserAgentRoundTripper(s.userAgent, baseClient.Transport)
|
|
||||||
return baseClient, nil
|
|
||||||
}
|
|
||||||
return s.Client, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Source) UseClientAuthorization() bool {
|
func (s *Source) UseClientAuthorization() bool {
|
||||||
return s.UseClientOAuth
|
return s.UseClientOAuth
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Source) RunQuery(ctx context.Context, tokenStr string, bodyBytes []byte) (any, error) {
|
func (s *Source) GetClient(ctx context.Context, tokenStr string) (*geminidataanalytics.DataChatClient, func(), error) {
|
||||||
// The API endpoint itself always uses the "global" location.
|
if s.UseClientOAuth {
|
||||||
apiLocation := "global"
|
if tokenStr == "" {
|
||||||
apiParent := fmt.Sprintf("projects/%s/locations/%s", s.GetProjectID(), apiLocation)
|
return nil, nil, fmt.Errorf("client-side OAuth is enabled but no access token was provided")
|
||||||
apiURL := fmt.Sprintf("%s/v1beta/%s:queryData", s.GetBaseURL(), apiParent)
|
}
|
||||||
|
token := &oauth2.Token{AccessToken: tokenStr}
|
||||||
|
opts := []option.ClientOption{
|
||||||
|
option.WithUserAgent(s.userAgent),
|
||||||
|
option.WithTokenSource(oauth2.StaticTokenSource(token)),
|
||||||
|
}
|
||||||
|
|
||||||
client, err := s.GetClient(ctx, tokenStr)
|
client, err := NewDataChatClient(ctx, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get HTTP client: %w", err)
|
return nil, nil, fmt.Errorf("failed to create per-request DataChatClient: %w", err)
|
||||||
|
}
|
||||||
|
return client, func() { client.Close() }, nil
|
||||||
}
|
}
|
||||||
|
return s.Client, func() {}, nil
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(bodyBytes))
|
}
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
func (s *Source) RunQuery(ctx context.Context, tokenStr string, req *geminidataanalyticspb.QueryDataRequest) (*geminidataanalyticspb.QueryDataResponse, error) {
|
||||||
}
|
client, cleanup, err := s.GetClient(ctx, tokenStr)
|
||||||
req.Header.Set("Content-Type", "application/json")
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
resp, err := client.Do(req)
|
}
|
||||||
if err != nil {
|
defer cleanup()
|
||||||
return nil, fmt.Errorf("failed to execute request: %w", err)
|
|
||||||
}
|
return client.QueryData(ctx, req)
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
respBody, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(respBody))
|
|
||||||
}
|
|
||||||
|
|
||||||
var result map[string]any
|
|
||||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return result, nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -181,11 +181,9 @@ func TestInitialize(t *testing.T) {
|
|||||||
if gdaSrc.Client == nil && !tc.wantClientOAuth {
|
if gdaSrc.Client == nil && !tc.wantClientOAuth {
|
||||||
t.Fatal("expected non-nil HTTP client for ADC, got nil")
|
t.Fatal("expected non-nil HTTP client for ADC, got nil")
|
||||||
}
|
}
|
||||||
// When client OAuth is true, the source's client should be initialized with a base HTTP client
|
// When client OAuth is true, the source's client should be nil.
|
||||||
// that includes the user agent round tripper, but not the OAuth token. The token-aware
|
if gdaSrc.Client != nil && tc.wantClientOAuth {
|
||||||
// client is created by GetClient.
|
t.Fatal("expected nil HTTP client for client OAuth config, got non-nil")
|
||||||
if gdaSrc.Client == nil && tc.wantClientOAuth {
|
|
||||||
t.Fatal("expected non-nil HTTP client for client OAuth config, got nil")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test UseClientAuthorization method
|
// Test UseClientAuthorization method
|
||||||
@@ -195,15 +193,16 @@ func TestInitialize(t *testing.T) {
|
|||||||
|
|
||||||
// Test GetClient with accessToken for client OAuth scenarios
|
// Test GetClient with accessToken for client OAuth scenarios
|
||||||
if tc.wantClientOAuth {
|
if tc.wantClientOAuth {
|
||||||
client, err := gdaSrc.GetClient(ctx, "dummy-token")
|
client, cleanup, err := gdaSrc.GetClient(ctx, "dummy-token")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("GetClient with token failed: %v", err)
|
t.Fatalf("GetClient with token failed: %v", err)
|
||||||
}
|
}
|
||||||
|
defer cleanup()
|
||||||
if client == nil {
|
if client == nil {
|
||||||
t.Fatal("expected non-nil HTTP client from GetClient with token, got nil")
|
t.Fatal("expected non-nil HTTP client from GetClient with token, got nil")
|
||||||
}
|
}
|
||||||
// Ensure passing empty token with UseClientOAuth enabled returns error
|
// Ensure passing empty token with UseClientOAuth enabled returns error
|
||||||
_, err = gdaSrc.GetClient(ctx, "")
|
_, _, err = gdaSrc.GetClient(ctx, "")
|
||||||
if err == nil || err.Error() != "client-side OAuth is enabled but no access token was provided" {
|
if err == nil || err.Error() != "client-side OAuth is enabled but no access token was provided" {
|
||||||
t.Errorf("expected 'client-side OAuth is enabled but no access token was provided' error, got: %v", err)
|
t.Errorf("expected 'client-side OAuth is enabled but no access token was provided' error, got: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,11 +19,13 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
"cloud.google.com/go/geminidataanalytics/apiv1beta/geminidataanalyticspb"
|
||||||
"github.com/goccy/go-yaml"
|
"github.com/goccy/go-yaml"
|
||||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||||
|
"google.golang.org/protobuf/encoding/protojson"
|
||||||
)
|
)
|
||||||
|
|
||||||
const kind string = "cloud-gemini-data-analytics-query"
|
const kind string = "cloud-gemini-data-analytics-query"
|
||||||
@@ -60,7 +62,49 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
|||||||
type compatibleSource interface {
|
type compatibleSource interface {
|
||||||
GetProjectID() string
|
GetProjectID() string
|
||||||
UseClientAuthorization() bool
|
UseClientAuthorization() bool
|
||||||
RunQuery(context.Context, string, []byte) (any, error)
|
RunQuery(context.Context, string, *geminidataanalyticspb.QueryDataRequest) (*geminidataanalyticspb.QueryDataResponse, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryDataContext wraps geminidataanalyticspb.QueryDataContext to support YAML decoding via protojson.
|
||||||
|
type QueryDataContext struct {
|
||||||
|
*geminidataanalyticspb.QueryDataContext
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *QueryDataContext) UnmarshalYAML(b []byte) error {
|
||||||
|
var raw map[string]any
|
||||||
|
if err := yaml.Unmarshal(b, &raw); err != nil {
|
||||||
|
return fmt.Errorf("failed to unmarshal context from yaml: %w", err)
|
||||||
|
}
|
||||||
|
jsonBytes, err := json.Marshal(raw)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal context map: %w", err)
|
||||||
|
}
|
||||||
|
q.QueryDataContext = &geminidataanalyticspb.QueryDataContext{}
|
||||||
|
if err := protojson.Unmarshal(jsonBytes, q.QueryDataContext); err != nil {
|
||||||
|
return fmt.Errorf("failed to unmarshal context to proto: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerationOptions wraps geminidataanalyticspb.GenerationOptions to support YAML decoding via protojson.
|
||||||
|
type GenerationOptions struct {
|
||||||
|
*geminidataanalyticspb.GenerationOptions
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *GenerationOptions) UnmarshalYAML(b []byte) error {
|
||||||
|
var raw map[string]any
|
||||||
|
if err := yaml.Unmarshal(b, &raw); err != nil {
|
||||||
|
return fmt.Errorf("failed to unmarshal generation options from yaml: %w", err)
|
||||||
|
}
|
||||||
|
jsonBytes, err := json.Marshal(raw)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal generation options map: %w", err)
|
||||||
|
}
|
||||||
|
g.GenerationOptions = &geminidataanalyticspb.GenerationOptions{}
|
||||||
|
if err := protojson.Unmarshal(jsonBytes, g.GenerationOptions); err != nil {
|
||||||
|
return fmt.Errorf("failed to unmarshal generation options to proto: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
@@ -97,12 +141,14 @@ func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error)
|
|||||||
}
|
}
|
||||||
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil)
|
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil)
|
||||||
|
|
||||||
return Tool{
|
t := Tool{
|
||||||
Config: cfg,
|
Config: cfg,
|
||||||
AllParams: allParameters,
|
AllParams: allParameters,
|
||||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: allParameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
manifest: tools.Manifest{Description: cfg.Description, Parameters: allParameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||||
mcpManifest: mcpManifest,
|
mcpManifest: mcpManifest,
|
||||||
}, nil
|
}
|
||||||
|
|
||||||
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// validate interface
|
// validate interface
|
||||||
@@ -145,18 +191,20 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
|||||||
// The parent in the request payload uses the tool's configured location.
|
// The parent in the request payload uses the tool's configured location.
|
||||||
payloadParent := fmt.Sprintf("projects/%s/locations/%s", source.GetProjectID(), t.Location)
|
payloadParent := fmt.Sprintf("projects/%s/locations/%s", source.GetProjectID(), t.Location)
|
||||||
|
|
||||||
payload := &QueryDataRequest{
|
req := &geminidataanalyticspb.QueryDataRequest{
|
||||||
Parent: payloadParent,
|
Parent: payloadParent,
|
||||||
Prompt: query,
|
Prompt: query,
|
||||||
Context: t.Context,
|
|
||||||
GenerationOptions: t.GenerationOptions,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bodyBytes, err := json.Marshal(payload)
|
if t.Context != nil {
|
||||||
if err != nil {
|
req.Context = t.Context.QueryDataContext
|
||||||
return nil, fmt.Errorf("failed to marshal request payload: %w", err)
|
|
||||||
}
|
}
|
||||||
return source.RunQuery(ctx, tokenStr, bodyBytes)
|
|
||||||
|
if t.GenerationOptions != nil {
|
||||||
|
req.GenerationOptions = t.GenerationOptions.GenerationOptions
|
||||||
|
}
|
||||||
|
|
||||||
|
return source.RunQuery(ctx, tokenStr, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||||
|
|||||||
@@ -16,19 +16,16 @@ package cloudgda_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"cloud.google.com/go/geminidataanalytics/apiv1beta/geminidataanalyticspb"
|
||||||
yaml "github.com/goccy/go-yaml"
|
yaml "github.com/goccy/go-yaml"
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"github.com/google/go-cmp/cmp/cmpopts"
|
||||||
"github.com/googleapis/genai-toolbox/internal/server"
|
"github.com/googleapis/genai-toolbox/internal/server"
|
||||||
"github.com/googleapis/genai-toolbox/internal/server/resources"
|
"github.com/googleapis/genai-toolbox/internal/server/resources"
|
||||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||||
cloudgdasrc "github.com/googleapis/genai-toolbox/internal/sources/cloudgda"
|
|
||||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||||
cloudgdatool "github.com/googleapis/genai-toolbox/internal/tools/cloudgda"
|
cloudgdatool "github.com/googleapis/genai-toolbox/internal/tools/cloudgda"
|
||||||
@@ -74,23 +71,29 @@ func TestParseFromYaml(t *testing.T) {
|
|||||||
Location: "us-central1",
|
Location: "us-central1",
|
||||||
AuthRequired: []string{},
|
AuthRequired: []string{},
|
||||||
Context: &cloudgdatool.QueryDataContext{
|
Context: &cloudgdatool.QueryDataContext{
|
||||||
DatasourceReferences: &cloudgdatool.DatasourceReferences{
|
QueryDataContext: &geminidataanalyticspb.QueryDataContext{
|
||||||
SpannerReference: &cloudgdatool.SpannerReference{
|
DatasourceReferences: &geminidataanalyticspb.DatasourceReferences{
|
||||||
DatabaseReference: &cloudgdatool.SpannerDatabaseReference{
|
References: &geminidataanalyticspb.DatasourceReferences_SpannerReference{
|
||||||
ProjectID: "cloud-db-nl2sql",
|
SpannerReference: &geminidataanalyticspb.SpannerReference{
|
||||||
Region: "us-central1",
|
DatabaseReference: &geminidataanalyticspb.SpannerDatabaseReference{
|
||||||
InstanceID: "evalbench",
|
ProjectId: "cloud-db-nl2sql",
|
||||||
DatabaseID: "financial",
|
Region: "us-central1",
|
||||||
Engine: cloudgdatool.SpannerEngineGoogleSQL,
|
InstanceId: "evalbench",
|
||||||
},
|
DatabaseId: "financial",
|
||||||
AgentContextReference: &cloudgdatool.AgentContextReference{
|
Engine: geminidataanalyticspb.SpannerDatabaseReference_GOOGLE_SQL,
|
||||||
ContextSetID: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates",
|
},
|
||||||
|
AgentContextReference: &geminidataanalyticspb.AgentContextReference{
|
||||||
|
ContextSetId: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates",
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
GenerationOptions: &cloudgdatool.GenerationOptions{
|
GenerationOptions: &cloudgdatool.GenerationOptions{
|
||||||
GenerateQueryResult: true,
|
GenerationOptions: &geminidataanalyticspb.GenerationOptions{
|
||||||
|
GenerateQueryResult: true,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -108,68 +111,63 @@ func TestParseFromYaml(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to unmarshal: %s", err)
|
t.Fatalf("unable to unmarshal: %s", err)
|
||||||
}
|
}
|
||||||
if !cmp.Equal(tc.want, got.Tools) {
|
if !cmp.Equal(tc.want, got.Tools, cmpopts.IgnoreUnexported(geminidataanalyticspb.QueryDataContext{}, geminidataanalyticspb.DatasourceReferences{}, geminidataanalyticspb.SpannerReference{}, geminidataanalyticspb.SpannerDatabaseReference{}, geminidataanalyticspb.AgentContextReference{}, geminidataanalyticspb.GenerationOptions{}, geminidataanalyticspb.DatasourceReferences_SpannerReference{})) {
|
||||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Tools)
|
t.Errorf("incorrect parse: want %v, got %v", tc.want, got.Tools)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// authRoundTripper is a mock http.RoundTripper that adds a dummy Authorization header.
|
// fakeSource implements the compatibleSource interface for testing.
|
||||||
type authRoundTripper struct {
|
type fakeSource struct {
|
||||||
Token string
|
projectID string
|
||||||
Next http.RoundTripper
|
useClientOAuth bool
|
||||||
|
expectedQuery string
|
||||||
|
expectedParent string
|
||||||
|
response *geminidataanalyticspb.QueryDataResponse
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rt *authRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
func (f *fakeSource) GetProjectID() string {
|
||||||
newReq := *req
|
return f.projectID
|
||||||
newReq.Header = make(http.Header)
|
|
||||||
for k, v := range req.Header {
|
|
||||||
newReq.Header[k] = v
|
|
||||||
}
|
|
||||||
newReq.Header.Set("Authorization", rt.Token)
|
|
||||||
if rt.Next == nil {
|
|
||||||
return http.DefaultTransport.RoundTrip(&newReq)
|
|
||||||
}
|
|
||||||
return rt.Next.RoundTrip(&newReq)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockSource struct {
|
func (f *fakeSource) UseClientAuthorization() bool {
|
||||||
kind string
|
return f.useClientOAuth
|
||||||
client *http.Client // Can be used to inject a specific client
|
|
||||||
baseURL string // BaseURL is needed to implement sources.Source.BaseURL
|
|
||||||
config cloudgdasrc.Config // to return from ToConfig
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockSource) SourceKind() string { return m.kind }
|
func (f *fakeSource) SourceKind() string {
|
||||||
func (m *mockSource) ToConfig() sources.SourceConfig { return m.config }
|
return "fake-gda-source"
|
||||||
func (m *mockSource) GetClient(ctx context.Context, token string) (*http.Client, error) {
|
|
||||||
if m.client != nil {
|
|
||||||
return m.client, nil
|
|
||||||
}
|
|
||||||
// Default client for testing if not explicitly set
|
|
||||||
transport := &http.Transport{}
|
|
||||||
authTransport := &authRoundTripper{
|
|
||||||
Token: "Bearer test-access-token", // Dummy token
|
|
||||||
Next: transport,
|
|
||||||
}
|
|
||||||
return &http.Client{Transport: authTransport}, nil
|
|
||||||
}
|
}
|
||||||
func (m *mockSource) UseClientAuthorization() bool { return false }
|
|
||||||
func (m *mockSource) Initialize(ctx context.Context, tracer interface{}) (sources.Source, error) {
|
func (f *fakeSource) ToConfig() sources.SourceConfig {
|
||||||
return m, nil
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeSource) Initialize(ctx context.Context, tracer interface{}) (sources.Source, error) {
|
||||||
|
return f, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeSource) RunQuery(ctx context.Context, token string, req *geminidataanalyticspb.QueryDataRequest) (*geminidataanalyticspb.QueryDataResponse, error) {
|
||||||
|
if req.Prompt != f.expectedQuery {
|
||||||
|
return nil, fmt.Errorf("unexpected query: got %q, want %q", req.Prompt, f.expectedQuery)
|
||||||
|
}
|
||||||
|
if req.Parent != f.expectedParent {
|
||||||
|
return nil, fmt.Errorf("unexpected parent: got %q, want %q", req.Parent, f.expectedParent)
|
||||||
|
}
|
||||||
|
// Basic validation of context/options could be added here if needed,
|
||||||
|
// but the test case mainly checks if they are passed correctly via successful invocation.
|
||||||
|
|
||||||
|
return f.response, nil
|
||||||
}
|
}
|
||||||
func (m *mockSource) BaseURL() string { return m.baseURL }
|
|
||||||
|
|
||||||
func TestInitialize(t *testing.T) {
|
func TestInitialize(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
// Minimal fake source
|
||||||
|
fake := &fakeSource{projectID: "test-project"}
|
||||||
|
|
||||||
srcs := map[string]sources.Source{
|
srcs := map[string]sources.Source{
|
||||||
"gda-api-source": &cloudgdasrc.Source{
|
"gda-api-source": fake,
|
||||||
Config: cloudgdasrc.Config{Name: "gda-api-source", Kind: cloudgdasrc.SourceKind, ProjectID: "test-project"},
|
|
||||||
Client: &http.Client{},
|
|
||||||
BaseURL: cloudgdasrc.Endpoint,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
tcs := []struct {
|
tcs := []struct {
|
||||||
@@ -188,9 +186,6 @@ func TestInitialize(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add an incompatible source for testing
|
|
||||||
srcs["incompatible-source"] = &mockSource{kind: "another-kind"}
|
|
||||||
|
|
||||||
for _, tc := range tcs {
|
for _, tc := range tcs {
|
||||||
tc := tc
|
tc := tc
|
||||||
t.Run(tc.desc, func(t *testing.T) {
|
t.Run(tc.desc, func(t *testing.T) {
|
||||||
@@ -207,92 +202,27 @@ func TestInitialize(t *testing.T) {
|
|||||||
|
|
||||||
func TestInvoke(t *testing.T) {
|
func TestInvoke(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
// Mock the HTTP client and server for Invoke testing
|
|
||||||
serverMux := http.NewServeMux()
|
|
||||||
// Update expected URL path to include the location "us-central1"
|
|
||||||
serverMux.HandleFunc("/v1beta/projects/test-project/locations/global:queryData", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if r.Method != http.MethodPost {
|
|
||||||
t.Errorf("expected POST method, got %s", r.Method)
|
|
||||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if r.Header.Get("Content-Type") != "application/json" {
|
|
||||||
t.Errorf("expected Content-Type application/json, got %s", r.Header.Get("Content-Type"))
|
|
||||||
http.Error(w, "Bad request", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read and unmarshal the request body
|
projectID := "test-project"
|
||||||
bodyBytes, err := io.ReadAll(r.Body)
|
location := "us-central1"
|
||||||
if err != nil {
|
query := "How many accounts who have region in Prague are eligible for loans?"
|
||||||
t.Errorf("failed to read request body: %v", err)
|
expectedParent := fmt.Sprintf("projects/%s/locations/%s", projectID, location)
|
||||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var reqPayload cloudgdatool.QueryDataRequest
|
|
||||||
if err := json.Unmarshal(bodyBytes, &reqPayload); err != nil {
|
|
||||||
t.Errorf("failed to unmarshal request payload: %v", err)
|
|
||||||
http.Error(w, "Bad request", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify expected fields
|
// Prepare expected response
|
||||||
if r.Header.Get("Authorization") == "" {
|
expectedResp := &geminidataanalyticspb.QueryDataResponse{
|
||||||
t.Errorf("expected Authorization header, got empty")
|
GeneratedQuery: "SELECT count(*) FROM accounts WHERE region = 'Prague' AND eligible_for_loans = true;",
|
||||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
NaturalLanguageAnswer: "There are 5 accounts in Prague eligible for loans.",
|
||||||
return
|
|
||||||
}
|
|
||||||
if reqPayload.Prompt != "How many accounts who have region in Prague are eligible for loans?" {
|
|
||||||
t.Errorf("unexpected prompt: %s", reqPayload.Prompt)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify payload's parent uses the tool's configured location
|
|
||||||
if reqPayload.Parent != fmt.Sprintf("projects/%s/locations/%s", "test-project", "us-central1") {
|
|
||||||
t.Errorf("unexpected payload parent: got %q, want %q", reqPayload.Parent, fmt.Sprintf("projects/%s/locations/%s", "test-project", "us-central1"))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify context from config
|
|
||||||
if reqPayload.Context == nil ||
|
|
||||||
reqPayload.Context.DatasourceReferences == nil ||
|
|
||||||
reqPayload.Context.DatasourceReferences.SpannerReference == nil ||
|
|
||||||
reqPayload.Context.DatasourceReferences.SpannerReference.DatabaseReference == nil ||
|
|
||||||
reqPayload.Context.DatasourceReferences.SpannerReference.DatabaseReference.ProjectID != "cloud-db-nl2sql" {
|
|
||||||
t.Errorf("unexpected context: %v", reqPayload.Context)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify generation options from config
|
|
||||||
if reqPayload.GenerationOptions == nil || !reqPayload.GenerationOptions.GenerateQueryResult {
|
|
||||||
t.Errorf("unexpected generation options: %v", reqPayload.GenerationOptions)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Simulate a successful response
|
|
||||||
resp := map[string]any{
|
|
||||||
"queryResult": "SELECT count(*) FROM accounts WHERE region = 'Prague' AND eligible_for_loans = true;",
|
|
||||||
"naturalLanguageAnswer": "There are 5 accounts in Prague eligible for loans.",
|
|
||||||
}
|
|
||||||
_ = json.NewEncoder(w).Encode(resp)
|
|
||||||
})
|
|
||||||
|
|
||||||
mockServer := httptest.NewServer(serverMux)
|
|
||||||
defer mockServer.Close()
|
|
||||||
|
|
||||||
ctx := testutils.ContextWithUserAgent(context.Background(), "test-user-agent")
|
|
||||||
|
|
||||||
// Create an authenticated client that uses the mock server
|
|
||||||
authTransport := &authRoundTripper{
|
|
||||||
Token: "Bearer test-access-token",
|
|
||||||
Next: mockServer.Client().Transport,
|
|
||||||
}
|
}
|
||||||
authClient := &http.Client{Transport: authTransport}
|
|
||||||
|
|
||||||
// Create a real cloudgdasrc.Source but inject the authenticated client
|
fake := &fakeSource{
|
||||||
mockGdaSource := &cloudgdasrc.Source{
|
projectID: projectID,
|
||||||
Config: cloudgdasrc.Config{Name: "mock-gda-source", Kind: cloudgdasrc.SourceKind, ProjectID: "test-project"},
|
expectedQuery: query,
|
||||||
Client: authClient,
|
expectedParent: expectedParent,
|
||||||
BaseURL: mockServer.URL,
|
response: expectedResp,
|
||||||
}
|
}
|
||||||
|
|
||||||
srcs := map[string]sources.Source{
|
srcs := map[string]sources.Source{
|
||||||
"mock-gda-source": mockGdaSource,
|
"mock-gda-source": fake,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize the tool config with context
|
// Initialize the tool config with context
|
||||||
@@ -301,25 +231,31 @@ func TestInvoke(t *testing.T) {
|
|||||||
Kind: "cloud-gemini-data-analytics-query",
|
Kind: "cloud-gemini-data-analytics-query",
|
||||||
Source: "mock-gda-source",
|
Source: "mock-gda-source",
|
||||||
Description: "Query Gemini Data Analytics",
|
Description: "Query Gemini Data Analytics",
|
||||||
Location: "us-central1", // Set location for the test
|
Location: location,
|
||||||
Context: &cloudgdatool.QueryDataContext{
|
Context: &cloudgdatool.QueryDataContext{
|
||||||
DatasourceReferences: &cloudgdatool.DatasourceReferences{
|
QueryDataContext: &geminidataanalyticspb.QueryDataContext{
|
||||||
SpannerReference: &cloudgdatool.SpannerReference{
|
DatasourceReferences: &geminidataanalyticspb.DatasourceReferences{
|
||||||
DatabaseReference: &cloudgdatool.SpannerDatabaseReference{
|
References: &geminidataanalyticspb.DatasourceReferences_SpannerReference{
|
||||||
ProjectID: "cloud-db-nl2sql",
|
SpannerReference: &geminidataanalyticspb.SpannerReference{
|
||||||
Region: "us-central1",
|
DatabaseReference: &geminidataanalyticspb.SpannerDatabaseReference{
|
||||||
InstanceID: "evalbench",
|
ProjectId: "cloud-db-nl2sql",
|
||||||
DatabaseID: "financial",
|
Region: "us-central1",
|
||||||
Engine: cloudgdatool.SpannerEngineGoogleSQL,
|
InstanceId: "evalbench",
|
||||||
},
|
DatabaseId: "financial",
|
||||||
AgentContextReference: &cloudgdatool.AgentContextReference{
|
Engine: geminidataanalyticspb.SpannerDatabaseReference_GOOGLE_SQL,
|
||||||
ContextSetID: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates",
|
},
|
||||||
|
AgentContextReference: &geminidataanalyticspb.AgentContextReference{
|
||||||
|
ContextSetId: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates",
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
GenerationOptions: &cloudgdatool.GenerationOptions{
|
GenerationOptions: &cloudgdatool.GenerationOptions{
|
||||||
GenerateQueryResult: true,
|
GenerationOptions: &geminidataanalyticspb.GenerationOptions{
|
||||||
|
GenerateQueryResult: true,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -330,24 +266,25 @@ func TestInvoke(t *testing.T) {
|
|||||||
|
|
||||||
// Prepare parameters for invocation - ONLY query
|
// Prepare parameters for invocation - ONLY query
|
||||||
params := parameters.ParamValues{
|
params := parameters.ParamValues{
|
||||||
{Name: "query", Value: "How many accounts who have region in Prague are eligible for loans?"},
|
{Name: "query", Value: query},
|
||||||
}
|
}
|
||||||
|
|
||||||
resourceMgr := resources.NewResourceManager(srcs, nil, nil, nil, nil, nil, nil)
|
resourceMgr := resources.NewResourceManager(srcs, nil, nil, nil, nil, nil, nil)
|
||||||
|
|
||||||
|
ctx := testutils.ContextWithUserAgent(context.Background(), "test-user-agent")
|
||||||
|
|
||||||
// Invoke the tool
|
// Invoke the tool
|
||||||
result, err := tool.Invoke(ctx, resourceMgr, params, "") // No accessToken needed for ADC client
|
result, err := tool.Invoke(ctx, resourceMgr, params, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("tool invocation failed: %v", err)
|
t.Fatalf("tool invocation failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate the result
|
gotResp, ok := result.(*geminidataanalyticspb.QueryDataResponse)
|
||||||
expectedResult := map[string]any{
|
if !ok {
|
||||||
"queryResult": "SELECT count(*) FROM accounts WHERE region = 'Prague' AND eligible_for_loans = true;",
|
t.Fatalf("expected result type *geminidataanalyticspb.QueryDataResponse, got %T", result)
|
||||||
"naturalLanguageAnswer": "There are 5 accounts in Prague eligible for loans.",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !cmp.Equal(expectedResult, result) {
|
if diff := cmp.Diff(expectedResp, gotResp, cmpopts.IgnoreUnexported(geminidataanalyticspb.QueryDataResponse{})); diff != "" {
|
||||||
t.Errorf("unexpected result: got %v, want %v", result, expectedResult)
|
t.Errorf("unexpected result mismatch (-want +got):\n%s", diff)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,116 +0,0 @@
|
|||||||
// Copyright 2025 Google LLC
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package cloudgda
|
|
||||||
|
|
||||||
// See full service definition at: https://github.com/googleapis/googleapis/blob/master/google/cloud/geminidataanalytics/v1beta/data_chat_service.proto
|
|
||||||
|
|
||||||
// QueryDataRequest represents the JSON body for the queryData API
|
|
||||||
type QueryDataRequest struct {
|
|
||||||
Parent string `json:"parent"`
|
|
||||||
Prompt string `json:"prompt"`
|
|
||||||
Context *QueryDataContext `json:"context,omitempty"`
|
|
||||||
GenerationOptions *GenerationOptions `json:"generationOptions,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// QueryDataContext reflects the proto definition for the query context.
|
|
||||||
type QueryDataContext struct {
|
|
||||||
DatasourceReferences *DatasourceReferences `json:"datasourceReferences,omitempty" yaml:"datasourceReferences,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// DatasourceReferences reflects the proto definition for datasource references, using a oneof.
|
|
||||||
type DatasourceReferences struct {
|
|
||||||
SpannerReference *SpannerReference `json:"spannerReference,omitempty" yaml:"spannerReference,omitempty"`
|
|
||||||
AlloyDBReference *AlloyDBReference `json:"alloydb,omitempty" yaml:"alloydb,omitempty"`
|
|
||||||
CloudSQLReference *CloudSQLReference `json:"cloudSqlReference,omitempty" yaml:"cloudSqlReference,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// SpannerReference reflects the proto definition for Spanner database reference.
|
|
||||||
type SpannerReference struct {
|
|
||||||
DatabaseReference *SpannerDatabaseReference `json:"databaseReference,omitempty" yaml:"databaseReference,omitempty"`
|
|
||||||
AgentContextReference *AgentContextReference `json:"agentContextReference,omitempty" yaml:"agentContextReference,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// SpannerDatabaseReference reflects the proto definition for a Spanner database reference.
|
|
||||||
type SpannerDatabaseReference struct {
|
|
||||||
Engine SpannerEngine `json:"engine,omitempty" yaml:"engine,omitempty"`
|
|
||||||
ProjectID string `json:"projectId,omitempty" yaml:"projectId,omitempty"`
|
|
||||||
Region string `json:"region,omitempty" yaml:"region,omitempty"`
|
|
||||||
InstanceID string `json:"instanceId,omitempty" yaml:"instanceId,omitempty"`
|
|
||||||
DatabaseID string `json:"databaseId,omitempty" yaml:"databaseId,omitempty"`
|
|
||||||
TableIDs []string `json:"tableIds,omitempty" yaml:"tableIds,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// SpannerEngine represents the engine of the Spanner instance.
|
|
||||||
type SpannerEngine string
|
|
||||||
|
|
||||||
const (
|
|
||||||
SpannerEngineUnspecified SpannerEngine = "ENGINE_UNSPECIFIED"
|
|
||||||
SpannerEngineGoogleSQL SpannerEngine = "GOOGLE_SQL"
|
|
||||||
SpannerEnginePostgreSQL SpannerEngine = "POSTGRESQL"
|
|
||||||
)
|
|
||||||
|
|
||||||
// AlloyDBReference reflects the proto definition for an AlloyDB database reference.
|
|
||||||
type AlloyDBReference struct {
|
|
||||||
DatabaseReference *AlloyDBDatabaseReference `json:"databaseReference,omitempty" yaml:"databaseReference,omitempty"`
|
|
||||||
AgentContextReference *AgentContextReference `json:"agentContextReference,omitempty" yaml:"agentContextReference,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// AlloyDBDatabaseReference reflects the proto definition for an AlloyDB database reference.
|
|
||||||
type AlloyDBDatabaseReference struct {
|
|
||||||
ProjectID string `json:"projectId,omitempty" yaml:"projectId,omitempty"`
|
|
||||||
Region string `json:"region,omitempty" yaml:"region,omitempty"`
|
|
||||||
ClusterID string `json:"clusterId,omitempty" yaml:"clusterId,omitempty"`
|
|
||||||
InstanceID string `json:"instanceId,omitempty" yaml:"instanceId,omitempty"`
|
|
||||||
DatabaseID string `json:"databaseId,omitempty" yaml:"databaseId,omitempty"`
|
|
||||||
TableIDs []string `json:"tableIds,omitempty" yaml:"tableIds,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// CloudSQLReference reflects the proto definition for a Cloud SQL database reference.
|
|
||||||
type CloudSQLReference struct {
|
|
||||||
DatabaseReference *CloudSQLDatabaseReference `json:"databaseReference,omitempty" yaml:"databaseReference,omitempty"`
|
|
||||||
AgentContextReference *AgentContextReference `json:"agentContextReference,omitempty" yaml:"agentContextReference,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// CloudSQLDatabaseReference reflects the proto definition for a Cloud SQL database reference.
|
|
||||||
type CloudSQLDatabaseReference struct {
|
|
||||||
Engine CloudSQLEngine `json:"engine,omitempty" yaml:"engine,omitempty"`
|
|
||||||
ProjectID string `json:"projectId,omitempty" yaml:"projectId,omitempty"`
|
|
||||||
Region string `json:"region,omitempty" yaml:"region,omitempty"`
|
|
||||||
InstanceID string `json:"instanceId,omitempty" yaml:"instanceId,omitempty"`
|
|
||||||
DatabaseID string `json:"databaseId,omitempty" yaml:"databaseId,omitempty"`
|
|
||||||
TableIDs []string `json:"tableIds,omitempty" yaml:"tableIds,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// CloudSQLEngine represents the engine of the Cloud SQL instance.
|
|
||||||
type CloudSQLEngine string
|
|
||||||
|
|
||||||
const (
|
|
||||||
CloudSQLEngineUnspecified CloudSQLEngine = "ENGINE_UNSPECIFIED"
|
|
||||||
CloudSQLEnginePostgreSQL CloudSQLEngine = "POSTGRESQL"
|
|
||||||
CloudSQLEngineMySQL CloudSQLEngine = "MYSQL"
|
|
||||||
)
|
|
||||||
|
|
||||||
// AgentContextReference reflects the proto definition for agent context.
|
|
||||||
type AgentContextReference struct {
|
|
||||||
ContextSetID string `json:"contextSetId,omitempty" yaml:"contextSetId,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// GenerationOptions reflects the proto definition for generation options.
|
|
||||||
type GenerationOptions struct {
|
|
||||||
GenerateQueryResult bool `json:"generateQueryResult" yaml:"generateQueryResult"`
|
|
||||||
GenerateNaturalLanguageAnswer bool `json:"generateNaturalLanguageAnswer" yaml:"generateNaturalLanguageAnswer"`
|
|
||||||
GenerateExplanation bool `json:"generateExplanation" yaml:"generateExplanation"`
|
|
||||||
GenerateDisambiguationQuestion bool `json:"generateDisambiguationQuestion" yaml:"generateDisambiguationQuestion"`
|
|
||||||
}
|
|
||||||
@@ -18,78 +18,75 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
|
||||||
"net/url"
|
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
geminidataanalytics "cloud.google.com/go/geminidataanalytics/apiv1beta"
|
||||||
|
"cloud.google.com/go/geminidataanalytics/apiv1beta/geminidataanalyticspb"
|
||||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||||
|
source "github.com/googleapis/genai-toolbox/internal/sources/cloudgda"
|
||||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||||
"github.com/googleapis/genai-toolbox/internal/tools/cloudgda"
|
"github.com/googleapis/genai-toolbox/internal/tools/cloudgda"
|
||||||
"github.com/googleapis/genai-toolbox/tests"
|
"github.com/googleapis/genai-toolbox/tests"
|
||||||
|
"google.golang.org/api/option"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
cloudGdaToolKind = "cloud-gemini-data-analytics-query"
|
cloudGdaToolKind = "cloud-gemini-data-analytics-query"
|
||||||
)
|
)
|
||||||
|
|
||||||
type cloudGdaTransport struct {
|
type mockDataChatServer struct {
|
||||||
transport http.RoundTripper
|
geminidataanalyticspb.UnimplementedDataChatServiceServer
|
||||||
url *url.URL
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *cloudGdaTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
||||||
if strings.HasPrefix(req.URL.String(), "https://geminidataanalytics.googleapis.com") {
|
|
||||||
req.URL.Scheme = t.url.Scheme
|
|
||||||
req.URL.Host = t.url.Host
|
|
||||||
}
|
|
||||||
return t.transport.RoundTrip(req)
|
|
||||||
}
|
|
||||||
|
|
||||||
type masterHandler struct {
|
|
||||||
t *testing.T
|
t *testing.T
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *masterHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (s *mockDataChatServer) QueryData(ctx context.Context, req *geminidataanalyticspb.QueryDataRequest) (*geminidataanalyticspb.QueryDataResponse, error) {
|
||||||
if !strings.Contains(r.UserAgent(), "genai-toolbox/") {
|
if req.Prompt == "" {
|
||||||
h.t.Errorf("User-Agent header not found")
|
s.t.Errorf("missing prompt")
|
||||||
|
return nil, fmt.Errorf("missing prompt")
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.Method != http.MethodPost {
|
return &geminidataanalyticspb.QueryDataResponse{
|
||||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
GeneratedQuery: "SELECT * FROM table;",
|
||||||
return
|
NaturalLanguageAnswer: "Here is the answer.",
|
||||||
}
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Verify URL structure
|
func getCloudGdaToolsConfig() map[string]any {
|
||||||
// Expected: /v1beta/projects/{project}/locations/global:queryData
|
return map[string]any{
|
||||||
if !strings.Contains(r.URL.Path, ":queryData") || !strings.Contains(r.URL.Path, "locations/global") {
|
"sources": map[string]any{
|
||||||
h.t.Errorf("unexpected URL path: %s", r.URL.Path)
|
"my-gda-source": map[string]any{
|
||||||
http.Error(w, "Not found", http.StatusNotFound)
|
"kind": "cloud-gemini-data-analytics",
|
||||||
return
|
"projectId": "test-project",
|
||||||
}
|
},
|
||||||
|
},
|
||||||
var reqBody cloudgda.QueryDataRequest
|
"tools": map[string]any{
|
||||||
if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
|
"cloud-gda-query": map[string]any{
|
||||||
h.t.Fatalf("failed to decode request body: %v", err)
|
"kind": cloudGdaToolKind,
|
||||||
}
|
"source": "my-gda-source",
|
||||||
|
"description": "Test GDA Tool",
|
||||||
if reqBody.Prompt == "" {
|
"location": "us-central1",
|
||||||
http.Error(w, "missing prompt", http.StatusBadRequest)
|
"context": map[string]any{
|
||||||
return
|
"datasourceReferences": map[string]any{
|
||||||
}
|
"spannerReference": map[string]any{
|
||||||
|
"databaseReference": map[string]any{
|
||||||
response := map[string]any{
|
"projectId": "test-project",
|
||||||
"queryResult": "SELECT * FROM table;",
|
"instanceId": "test-instance",
|
||||||
"naturalLanguageAnswer": "Here is the answer.",
|
"databaseId": "test-db",
|
||||||
}
|
"engine": "GOOGLE_SQL",
|
||||||
|
},
|
||||||
w.Header().Set("Content-Type", "application/json")
|
},
|
||||||
w.WriteHeader(http.StatusOK)
|
},
|
||||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
},
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -97,26 +94,37 @@ func TestCloudGdaToolEndpoints(t *testing.T) {
|
|||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
handler := &masterHandler{t: t}
|
// Start a gRPC server
|
||||||
server := httptest.NewServer(handler)
|
lis, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
defer server.Close()
|
|
||||||
|
|
||||||
serverURL, err := url.Parse(server.URL)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to parse server URL: %v", err)
|
t.Fatalf("failed to listen: %v", err)
|
||||||
}
|
}
|
||||||
|
s := grpc.NewServer()
|
||||||
|
geminidataanalyticspb.RegisterDataChatServiceServer(s, &mockDataChatServer{t: t})
|
||||||
|
go func() {
|
||||||
|
if err := s.Serve(lis); err != nil {
|
||||||
|
// This might happen on strict shutdown, log if unexpected
|
||||||
|
t.Logf("server executed: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
defer s.Stop()
|
||||||
|
|
||||||
originalTransport := http.DefaultClient.Transport
|
// Configure toolbox to use the gRPC server
|
||||||
if originalTransport == nil {
|
endpoint := lis.Addr().String()
|
||||||
originalTransport = http.DefaultTransport
|
|
||||||
|
// Override client creation
|
||||||
|
origFunc := source.NewDataChatClient
|
||||||
|
defer func() {
|
||||||
|
source.NewDataChatClient = origFunc
|
||||||
|
}()
|
||||||
|
|
||||||
|
source.NewDataChatClient = func(ctx context.Context, opts ...option.ClientOption) (*geminidataanalytics.DataChatClient, error) {
|
||||||
|
opts = append(opts,
|
||||||
|
option.WithEndpoint(endpoint),
|
||||||
|
option.WithoutAuthentication(),
|
||||||
|
option.WithGRPCDialOption(grpc.WithTransportCredentials(insecure.NewCredentials())))
|
||||||
|
return origFunc(ctx, opts...)
|
||||||
}
|
}
|
||||||
http.DefaultClient.Transport = &cloudGdaTransport{
|
|
||||||
transport: originalTransport,
|
|
||||||
url: serverURL,
|
|
||||||
}
|
|
||||||
t.Cleanup(func() {
|
|
||||||
http.DefaultClient.Transport = originalTransport
|
|
||||||
})
|
|
||||||
|
|
||||||
var args []string
|
var args []string
|
||||||
toolsFile := getCloudGdaToolsConfig()
|
toolsFile := getCloudGdaToolsConfig()
|
||||||
@@ -156,7 +164,7 @@ func TestCloudGdaToolEndpoints(t *testing.T) {
|
|||||||
|
|
||||||
// 2. RunToolInvokeParametersTest
|
// 2. RunToolInvokeParametersTest
|
||||||
params := []byte(`{"query": "test question"}`)
|
params := []byte(`{"query": "test question"}`)
|
||||||
tests.RunToolInvokeParametersTest(t, toolName, params, "\"queryResult\":\"SELECT * FROM table;\"")
|
tests.RunToolInvokeParametersTest(t, toolName, params, "\"generated_query\":\"SELECT * FROM table;\"")
|
||||||
|
|
||||||
// 3. Manual MCP Tool Call Test
|
// 3. Manual MCP Tool Call Test
|
||||||
// Initialize MCP session
|
// Initialize MCP session
|
||||||
@@ -196,38 +204,3 @@ func TestCloudGdaToolEndpoints(t *testing.T) {
|
|||||||
t.Errorf("MCP response does not contain expected query result: %s", respStr)
|
t.Errorf("MCP response does not contain expected query result: %s", respStr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getCloudGdaToolsConfig() map[string]any {
|
|
||||||
// Mocked responses and a dummy `projectId` are used in this integration
|
|
||||||
// test due to limited project-specific allowlisting. API functionality is
|
|
||||||
// verified via internal monitoring; this test specifically validates the
|
|
||||||
// integration flow between the source and the tool.
|
|
||||||
return map[string]any{
|
|
||||||
"sources": map[string]any{
|
|
||||||
"my-gda-source": map[string]any{
|
|
||||||
"kind": "cloud-gemini-data-analytics",
|
|
||||||
"projectId": "test-project",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"tools": map[string]any{
|
|
||||||
"cloud-gda-query": map[string]any{
|
|
||||||
"kind": cloudGdaToolKind,
|
|
||||||
"source": "my-gda-source",
|
|
||||||
"description": "Test GDA Tool",
|
|
||||||
"location": "us-central1",
|
|
||||||
"context": map[string]any{
|
|
||||||
"datasourceReferences": map[string]any{
|
|
||||||
"spannerReference": map[string]any{
|
|
||||||
"databaseReference": map[string]any{
|
|
||||||
"projectId": "test-project",
|
|
||||||
"instanceId": "test-instance",
|
|
||||||
"databaseId": "test-db",
|
|
||||||
"engine": "GOOGLE_SQL",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
Reference in New Issue
Block a user