mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-15 02:18:10 -05:00
Compare commits
4 Commits
registry-n
...
refactor/c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
172dd6b19b | ||
|
|
9f5b04cf73 | ||
|
|
66d6b58c4f | ||
|
|
6b02591703 |
@@ -41,13 +41,13 @@ tools:
|
||||
|
||||
### Usage Flow
|
||||
|
||||
When using this tool, a `prompt` parameter containing a natural language query is provided to the tool (typically by an agent). The tool then interacts with the Gemini Data Analytics API using the context defined in your configuration.
|
||||
When using this tool, a `query` parameter containing a natural language query is provided to the tool (typically by an agent). The tool then interacts with the Gemini Data Analytics API using the context defined in your configuration.
|
||||
|
||||
The structure of the response depends on the `generationOptions` configured in your tool definition (e.g., enabling `generateQueryResult` will include the SQL query results).
|
||||
|
||||
See [Data Analytics API REST documentation](https://clouddocs.devsite.corp.google.com/gemini/docs/conversational-analytics-api/reference/rest/v1alpha/projects.locations/queryData?rep_location=global) for details.
|
||||
|
||||
**Example Input Prompt:**
|
||||
**Example Input Query:**
|
||||
|
||||
```text
|
||||
How many accounts who have region in Prague are eligible for loans? A3 contains the data of region.
|
||||
|
||||
2
go.mod
2
go.mod
@@ -12,7 +12,7 @@ require (
|
||||
cloud.google.com/go/dataplex v1.28.0
|
||||
cloud.google.com/go/dataproc/v2 v2.15.0
|
||||
cloud.google.com/go/firestore v1.20.0
|
||||
cloud.google.com/go/geminidataanalytics v0.3.0
|
||||
cloud.google.com/go/geminidataanalytics v0.5.0
|
||||
cloud.google.com/go/longrunning v0.7.0
|
||||
cloud.google.com/go/spanner v1.86.1
|
||||
github.com/ClickHouse/clickhouse-go/v2 v2.40.3
|
||||
|
||||
4
go.sum
4
go.sum
@@ -311,8 +311,8 @@ cloud.google.com/go/gaming v1.6.0/go.mod h1:YMU1GEvA39Qt3zWGyAVA9bpYz/yAhTvaQ1t2
|
||||
cloud.google.com/go/gaming v1.7.0/go.mod h1:LrB8U7MHdGgFG851iHAfqUdLcKBdQ55hzXy9xBJz0+w=
|
||||
cloud.google.com/go/gaming v1.8.0/go.mod h1:xAqjS8b7jAVW0KFYeRUxngo9My3f33kFmua++Pi+ggM=
|
||||
cloud.google.com/go/gaming v1.9.0/go.mod h1:Fc7kEmCObylSWLO334NcO+O9QMDyz+TKC4v1D7X+Bc0=
|
||||
cloud.google.com/go/geminidataanalytics v0.3.0 h1:2Wi/kqFb5OLuEGH7q+/miE19VTqK1MYHjBEHENap9HI=
|
||||
cloud.google.com/go/geminidataanalytics v0.3.0/go.mod h1:QRc0b6ywyc3Z7S3etFgslz7hippkW/jRvtops5rKqIg=
|
||||
cloud.google.com/go/geminidataanalytics v0.5.0 h1:+1usY81Cb+hE8BokpqCM7EgJtRCKzUKx7FvrHbT5hCA=
|
||||
cloud.google.com/go/geminidataanalytics v0.5.0/go.mod h1:QRc0b6ywyc3Z7S3etFgslz7hippkW/jRvtops5rKqIg=
|
||||
cloud.google.com/go/gkebackup v0.2.0/go.mod h1:XKvv/4LfG829/B8B7xRkk8zRrOEbKtEam6yNfuQNH60=
|
||||
cloud.google.com/go/gkebackup v0.3.0/go.mod h1:n/E671i1aOQvUxT541aTkCwExO/bTer2HDlj4TsBRAo=
|
||||
cloud.google.com/go/gkebackup v0.4.0/go.mod h1:byAyBGUwYGEEww7xsbnUTBHIYcOPy/PgUWUtOeRm9Vg=
|
||||
|
||||
@@ -14,23 +14,20 @@
|
||||
package cloudgda
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
geminidataanalytics "cloud.google.com/go/geminidataanalytics/apiv1beta"
|
||||
"cloud.google.com/go/geminidataanalytics/apiv1beta/geminidataanalyticspb"
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/google"
|
||||
"google.golang.org/api/option"
|
||||
)
|
||||
|
||||
const SourceKind string = "cloud-gemini-data-analytics"
|
||||
const Endpoint string = "https://geminidataanalytics.googleapis.com"
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
@@ -67,29 +64,19 @@ func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.So
|
||||
return nil, fmt.Errorf("error in User Agent retrieval: %s", err)
|
||||
}
|
||||
|
||||
var client *http.Client
|
||||
if r.UseClientOAuth {
|
||||
client = &http.Client{
|
||||
Transport: util.NewUserAgentRoundTripper(ua, http.DefaultTransport),
|
||||
}
|
||||
} else {
|
||||
// Use Application Default Credentials
|
||||
// Scope: "https://www.googleapis.com/auth/cloud-platform" is generally sufficient for GDA
|
||||
creds, err := google.FindDefaultCredentials(ctx, "https://www.googleapis.com/auth/cloud-platform")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to find default credentials: %w", err)
|
||||
}
|
||||
baseClient := oauth2.NewClient(ctx, creds.TokenSource)
|
||||
baseClient.Transport = util.NewUserAgentRoundTripper(ua, baseClient.Transport)
|
||||
client = baseClient
|
||||
}
|
||||
|
||||
s := &Source{
|
||||
Config: r,
|
||||
Client: client,
|
||||
BaseURL: Endpoint,
|
||||
userAgent: ua,
|
||||
}
|
||||
|
||||
if !r.UseClientOAuth {
|
||||
client, err := geminidataanalytics.NewDataChatClient(ctx, option.WithUserAgent(ua))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create DataChatClient: %w", err)
|
||||
}
|
||||
s.Client = client
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
@@ -97,8 +84,7 @@ var _ sources.Source = &Source{}
|
||||
|
||||
type Source struct {
|
||||
Config
|
||||
Client *http.Client
|
||||
BaseURL string
|
||||
Client *geminidataanalytics.DataChatClient
|
||||
userAgent string
|
||||
}
|
||||
|
||||
@@ -114,63 +100,34 @@ func (s *Source) GetProjectID() string {
|
||||
return s.ProjectID
|
||||
}
|
||||
|
||||
func (s *Source) GetBaseURL() string {
|
||||
return s.BaseURL
|
||||
}
|
||||
|
||||
func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Client, error) {
|
||||
if s.UseClientOAuth {
|
||||
if accessToken == "" {
|
||||
return nil, fmt.Errorf("client-side OAuth is enabled but no access token was provided")
|
||||
}
|
||||
token := &oauth2.Token{AccessToken: accessToken}
|
||||
baseClient := oauth2.NewClient(ctx, oauth2.StaticTokenSource(token))
|
||||
baseClient.Transport = util.NewUserAgentRoundTripper(s.userAgent, baseClient.Transport)
|
||||
return baseClient, nil
|
||||
}
|
||||
return s.Client, nil
|
||||
}
|
||||
|
||||
func (s *Source) UseClientAuthorization() bool {
|
||||
return s.UseClientOAuth
|
||||
}
|
||||
|
||||
func (s *Source) RunQuery(ctx context.Context, tokenStr string, bodyBytes []byte) (any, error) {
|
||||
// The API endpoint itself always uses the "global" location.
|
||||
apiLocation := "global"
|
||||
apiParent := fmt.Sprintf("projects/%s/locations/%s", s.GetProjectID(), apiLocation)
|
||||
apiURL := fmt.Sprintf("%s/v1beta/%s:queryData", s.GetBaseURL(), apiParent)
|
||||
|
||||
client, err := s.GetClient(ctx, tokenStr)
|
||||
func (s *Source) RunQuery(ctx context.Context, tokenStr string, req *geminidataanalyticspb.QueryDataRequest) (*geminidataanalyticspb.QueryDataResponse, error) {
|
||||
client, cleanup, err := s.GetClient(ctx, tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get HTTP client: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(bodyBytes))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to execute request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
return client.QueryData(ctx, req)
|
||||
}
|
||||
|
||||
func (s *Source) GetClient(ctx context.Context, tokenStr string) (*geminidataanalytics.DataChatClient, func(), error) {
|
||||
if s.UseClientOAuth {
|
||||
if tokenStr == "" {
|
||||
return nil, nil, fmt.Errorf("client-side OAuth is enabled but no access token was provided")
|
||||
}
|
||||
token := &oauth2.Token{AccessToken: tokenStr}
|
||||
client, err := geminidataanalytics.NewDataChatClient(ctx,
|
||||
option.WithUserAgent(s.userAgent),
|
||||
option.WithTokenSource(oauth2.StaticTokenSource(token)),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to create per-request DataChatClient: %w", err)
|
||||
}
|
||||
return client, func() { client.Close() }, nil
|
||||
}
|
||||
return s.Client, func() {}, nil
|
||||
}
|
||||
|
||||
@@ -181,11 +181,9 @@ func TestInitialize(t *testing.T) {
|
||||
if gdaSrc.Client == nil && !tc.wantClientOAuth {
|
||||
t.Fatal("expected non-nil HTTP client for ADC, got nil")
|
||||
}
|
||||
// When client OAuth is true, the source's client should be initialized with a base HTTP client
|
||||
// that includes the user agent round tripper, but not the OAuth token. The token-aware
|
||||
// client is created by GetClient.
|
||||
if gdaSrc.Client == nil && tc.wantClientOAuth {
|
||||
t.Fatal("expected non-nil HTTP client for client OAuth config, got nil")
|
||||
// When client OAuth is true, the source's client should be nil.
|
||||
if gdaSrc.Client != nil && tc.wantClientOAuth {
|
||||
t.Fatal("expected nil HTTP client for client OAuth config, got non-nil")
|
||||
}
|
||||
|
||||
// Test UseClientAuthorization method
|
||||
@@ -195,15 +193,16 @@ func TestInitialize(t *testing.T) {
|
||||
|
||||
// Test GetClient with accessToken for client OAuth scenarios
|
||||
if tc.wantClientOAuth {
|
||||
client, err := gdaSrc.GetClient(ctx, "dummy-token")
|
||||
client, cleanup, err := gdaSrc.GetClient(ctx, "dummy-token")
|
||||
if err != nil {
|
||||
t.Fatalf("GetClient with token failed: %v", err)
|
||||
}
|
||||
defer cleanup()
|
||||
if client == nil {
|
||||
t.Fatal("expected non-nil HTTP client from GetClient with token, got nil")
|
||||
}
|
||||
// Ensure passing empty token with UseClientOAuth enabled returns error
|
||||
_, err = gdaSrc.GetClient(ctx, "")
|
||||
_, _, err = gdaSrc.GetClient(ctx, "")
|
||||
if err == nil || err.Error() != "client-side OAuth is enabled but no access token was provided" {
|
||||
t.Errorf("expected 'client-side OAuth is enabled but no access token was provided' error, got: %v", err)
|
||||
}
|
||||
|
||||
@@ -19,15 +19,32 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"cloud.google.com/go/geminidataanalytics/apiv1beta/geminidataanalyticspb"
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
)
|
||||
|
||||
const kind string = "cloud-gemini-data-analytics-query"
|
||||
|
||||
// Guidance is the tool guidance string.
|
||||
const Guidance = `Tool guidance:
|
||||
Inputs:
|
||||
1. query: A natural language formulation of a database query.
|
||||
Outputs: (all optional)
|
||||
1. disambiguation_question: Clarification questions or comments where the tool needs the users' input.
|
||||
2. generated_query: The generated query for the user query.
|
||||
3. intent_explanation: An explanation for why the tool produced ` + "`generated_query`" + `.
|
||||
4. query_result: The result of executing ` + "`generated_query`" + `.
|
||||
5. natural_language_answer: The natural language answer that summarizes the ` + "`query`" + ` and ` + "`query_result`" + `.
|
||||
|
||||
Usage guidance:
|
||||
1. If ` + "`disambiguation_question`" + ` is produced, then solicit the needed inputs from the user and try the tool with a new ` + "`query`" + ` that has the needed clarification.
|
||||
2. If ` + "`natural_language_answer`" + ` is produced, use ` + "`intent_explanation`" + ` and ` + "`generated_query`" + ` to see if you need to clarify any assumptions for the user.`
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
||||
@@ -45,7 +62,49 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
|
||||
type compatibleSource interface {
|
||||
GetProjectID() string
|
||||
UseClientAuthorization() bool
|
||||
RunQuery(context.Context, string, []byte) (any, error)
|
||||
RunQuery(context.Context, string, *geminidataanalyticspb.QueryDataRequest) (*geminidataanalyticspb.QueryDataResponse, error)
|
||||
}
|
||||
|
||||
// QueryDataContext wraps geminidataanalyticspb.QueryDataContext to support YAML decoding via protojson.
|
||||
type QueryDataContext struct {
|
||||
*geminidataanalyticspb.QueryDataContext
|
||||
}
|
||||
|
||||
func (q *QueryDataContext) UnmarshalYAML(b []byte) error {
|
||||
var raw map[string]any
|
||||
if err := yaml.Unmarshal(b, &raw); err != nil {
|
||||
return err
|
||||
}
|
||||
jsonBytes, err := json.Marshal(raw)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal context map: %w", err)
|
||||
}
|
||||
q.QueryDataContext = &geminidataanalyticspb.QueryDataContext{}
|
||||
if err := protojson.Unmarshal(jsonBytes, q.QueryDataContext); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal context to proto: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerationOptions wraps geminidataanalyticspb.GenerationOptions to support YAML decoding via protojson.
|
||||
type GenerationOptions struct {
|
||||
*geminidataanalyticspb.GenerationOptions
|
||||
}
|
||||
|
||||
func (g *GenerationOptions) UnmarshalYAML(b []byte) error {
|
||||
var raw map[string]any
|
||||
if err := yaml.Unmarshal(b, &raw); err != nil {
|
||||
return err
|
||||
}
|
||||
jsonBytes, err := json.Marshal(raw)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal generation options map: %w", err)
|
||||
}
|
||||
g.GenerationOptions = &geminidataanalyticspb.GenerationOptions{}
|
||||
if err := protojson.Unmarshal(jsonBytes, g.GenerationOptions); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal generation options to proto: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -68,19 +127,28 @@ func (cfg Config) ToolConfigKind() string {
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// Define the parameters for the Gemini Data Analytics Query API
|
||||
// The prompt is the only input parameter.
|
||||
// The query is the only input parameter.
|
||||
allParameters := parameters.Parameters{
|
||||
parameters.NewStringParameterWithRequired("prompt", "The natural language question to ask.", true),
|
||||
parameters.NewStringParameterWithRequired("query", "A natural language formulation of a database query.", true),
|
||||
}
|
||||
// The input and outputs are for tool guidance, usage guidance is for multi-turn interaction.
|
||||
guidance := Guidance
|
||||
|
||||
if cfg.Description != "" {
|
||||
cfg.Description += "\n\n" + guidance
|
||||
} else {
|
||||
cfg.Description = guidance
|
||||
}
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil)
|
||||
|
||||
return Tool{
|
||||
t := Tool{
|
||||
Config: cfg,
|
||||
AllParams: allParameters,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: allParameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// validate interface
|
||||
@@ -105,9 +173,9 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
}
|
||||
|
||||
paramsMap := params.AsMap()
|
||||
prompt, ok := paramsMap["prompt"].(string)
|
||||
query, ok := paramsMap["query"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("prompt parameter not found or not a string")
|
||||
return nil, fmt.Errorf("query parameter not found or not a string")
|
||||
}
|
||||
|
||||
// Parse the access token if provided
|
||||
@@ -123,18 +191,20 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
// The parent in the request payload uses the tool's configured location.
|
||||
payloadParent := fmt.Sprintf("projects/%s/locations/%s", source.GetProjectID(), t.Location)
|
||||
|
||||
payload := &QueryDataRequest{
|
||||
Parent: payloadParent,
|
||||
Prompt: prompt,
|
||||
Context: t.Context,
|
||||
GenerationOptions: t.GenerationOptions,
|
||||
req := &geminidataanalyticspb.QueryDataRequest{
|
||||
Parent: payloadParent,
|
||||
Prompt: query,
|
||||
}
|
||||
|
||||
bodyBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request payload: %w", err)
|
||||
if t.Context != nil {
|
||||
req.Context = t.Context.QueryDataContext
|
||||
}
|
||||
return source.RunQuery(ctx, tokenStr, bodyBytes)
|
||||
|
||||
if t.GenerationOptions != nil {
|
||||
req.GenerationOptions = t.GenerationOptions.GenerationOptions
|
||||
}
|
||||
|
||||
return source.RunQuery(ctx, tokenStr, req)
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
|
||||
@@ -16,19 +16,16 @@ package cloudgda_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"cloud.google.com/go/geminidataanalytics/apiv1beta/geminidataanalyticspb"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/server/resources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
cloudgdasrc "github.com/googleapis/genai-toolbox/internal/sources/cloudgda"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
cloudgdatool "github.com/googleapis/genai-toolbox/internal/tools/cloudgda"
|
||||
@@ -74,23 +71,29 @@ func TestParseFromYaml(t *testing.T) {
|
||||
Location: "us-central1",
|
||||
AuthRequired: []string{},
|
||||
Context: &cloudgdatool.QueryDataContext{
|
||||
DatasourceReferences: &cloudgdatool.DatasourceReferences{
|
||||
SpannerReference: &cloudgdatool.SpannerReference{
|
||||
DatabaseReference: &cloudgdatool.SpannerDatabaseReference{
|
||||
ProjectID: "cloud-db-nl2sql",
|
||||
Region: "us-central1",
|
||||
InstanceID: "evalbench",
|
||||
DatabaseID: "financial",
|
||||
Engine: cloudgdatool.SpannerEngineGoogleSQL,
|
||||
},
|
||||
AgentContextReference: &cloudgdatool.AgentContextReference{
|
||||
ContextSetID: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates",
|
||||
QueryDataContext: &geminidataanalyticspb.QueryDataContext{
|
||||
DatasourceReferences: &geminidataanalyticspb.DatasourceReferences{
|
||||
References: &geminidataanalyticspb.DatasourceReferences_SpannerReference{
|
||||
SpannerReference: &geminidataanalyticspb.SpannerReference{
|
||||
DatabaseReference: &geminidataanalyticspb.SpannerDatabaseReference{
|
||||
ProjectId: "cloud-db-nl2sql",
|
||||
Region: "us-central1",
|
||||
InstanceId: "evalbench",
|
||||
DatabaseId: "financial",
|
||||
Engine: geminidataanalyticspb.SpannerDatabaseReference_GOOGLE_SQL,
|
||||
},
|
||||
AgentContextReference: &geminidataanalyticspb.AgentContextReference{
|
||||
ContextSetId: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
GenerationOptions: &cloudgdatool.GenerationOptions{
|
||||
GenerateQueryResult: true,
|
||||
GenerationOptions: &geminidataanalyticspb.GenerationOptions{
|
||||
GenerateQueryResult: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -108,68 +111,63 @@ func TestParseFromYaml(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got.Tools) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Tools)
|
||||
if !cmp.Equal(tc.want, got.Tools, cmpopts.IgnoreUnexported(geminidataanalyticspb.QueryDataContext{}, geminidataanalyticspb.DatasourceReferences{}, geminidataanalyticspb.SpannerReference{}, geminidataanalyticspb.SpannerDatabaseReference{}, geminidataanalyticspb.AgentContextReference{}, geminidataanalyticspb.GenerationOptions{}, geminidataanalyticspb.DatasourceReferences_SpannerReference{})) {
|
||||
t.Errorf("incorrect parse: want %v, got %v", tc.want, got.Tools)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// authRoundTripper is a mock http.RoundTripper that adds a dummy Authorization header.
|
||||
type authRoundTripper struct {
|
||||
Token string
|
||||
Next http.RoundTripper
|
||||
// fakeSource implements the compatibleSource interface for testing.
|
||||
type fakeSource struct {
|
||||
projectID string
|
||||
useClientOAuth bool
|
||||
expectedQuery string
|
||||
expectedParent string
|
||||
response *geminidataanalyticspb.QueryDataResponse
|
||||
}
|
||||
|
||||
func (rt *authRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
newReq := *req
|
||||
newReq.Header = make(http.Header)
|
||||
for k, v := range req.Header {
|
||||
newReq.Header[k] = v
|
||||
}
|
||||
newReq.Header.Set("Authorization", rt.Token)
|
||||
if rt.Next == nil {
|
||||
return http.DefaultTransport.RoundTrip(&newReq)
|
||||
}
|
||||
return rt.Next.RoundTrip(&newReq)
|
||||
func (f *fakeSource) GetProjectID() string {
|
||||
return f.projectID
|
||||
}
|
||||
|
||||
type mockSource struct {
|
||||
kind string
|
||||
client *http.Client // Can be used to inject a specific client
|
||||
baseURL string // BaseURL is needed to implement sources.Source.BaseURL
|
||||
config cloudgdasrc.Config // to return from ToConfig
|
||||
func (f *fakeSource) UseClientAuthorization() bool {
|
||||
return f.useClientOAuth
|
||||
}
|
||||
|
||||
func (m *mockSource) SourceKind() string { return m.kind }
|
||||
func (m *mockSource) ToConfig() sources.SourceConfig { return m.config }
|
||||
func (m *mockSource) GetClient(ctx context.Context, token string) (*http.Client, error) {
|
||||
if m.client != nil {
|
||||
return m.client, nil
|
||||
}
|
||||
// Default client for testing if not explicitly set
|
||||
transport := &http.Transport{}
|
||||
authTransport := &authRoundTripper{
|
||||
Token: "Bearer test-access-token", // Dummy token
|
||||
Next: transport,
|
||||
}
|
||||
return &http.Client{Transport: authTransport}, nil
|
||||
func (f *fakeSource) SourceKind() string {
|
||||
return "fake-gda-source"
|
||||
}
|
||||
func (m *mockSource) UseClientAuthorization() bool { return false }
|
||||
func (m *mockSource) Initialize(ctx context.Context, tracer interface{}) (sources.Source, error) {
|
||||
return m, nil
|
||||
|
||||
func (f *fakeSource) ToConfig() sources.SourceConfig {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeSource) Initialize(ctx context.Context, tracer interface{}) (sources.Source, error) {
|
||||
return f, nil
|
||||
}
|
||||
|
||||
func (f *fakeSource) RunQuery(ctx context.Context, token string, req *geminidataanalyticspb.QueryDataRequest) (*geminidataanalyticspb.QueryDataResponse, error) {
|
||||
if req.Prompt != f.expectedQuery {
|
||||
return nil, fmt.Errorf("unexpected query: got %q, want %q", req.Prompt, f.expectedQuery)
|
||||
}
|
||||
if req.Parent != f.expectedParent {
|
||||
return nil, fmt.Errorf("unexpected parent: got %q, want %q", req.Parent, f.expectedParent)
|
||||
}
|
||||
// Basic validation of context/options could be added here if needed,
|
||||
// but the test case mainly checks if they are passed correctly via successful invocation.
|
||||
|
||||
return f.response, nil
|
||||
}
|
||||
func (m *mockSource) BaseURL() string { return m.baseURL }
|
||||
|
||||
func TestInitialize(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Minimal fake source
|
||||
fake := &fakeSource{projectID: "test-project"}
|
||||
|
||||
srcs := map[string]sources.Source{
|
||||
"gda-api-source": &cloudgdasrc.Source{
|
||||
Config: cloudgdasrc.Config{Name: "gda-api-source", Kind: cloudgdasrc.SourceKind, ProjectID: "test-project"},
|
||||
Client: &http.Client{},
|
||||
BaseURL: cloudgdasrc.Endpoint,
|
||||
},
|
||||
"gda-api-source": fake,
|
||||
}
|
||||
|
||||
tcs := []struct {
|
||||
@@ -188,9 +186,6 @@ func TestInitialize(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
// Add an incompatible source for testing
|
||||
srcs["incompatible-source"] = &mockSource{kind: "another-kind"}
|
||||
|
||||
for _, tc := range tcs {
|
||||
tc := tc
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
@@ -207,92 +202,27 @@ func TestInitialize(t *testing.T) {
|
||||
|
||||
func TestInvoke(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Mock the HTTP client and server for Invoke testing
|
||||
serverMux := http.NewServeMux()
|
||||
// Update expected URL path to include the location "us-central1"
|
||||
serverMux.HandleFunc("/v1beta/projects/test-project/locations/global:queryData", func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Errorf("expected POST method, got %s", r.Method)
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if r.Header.Get("Content-Type") != "application/json" {
|
||||
t.Errorf("expected Content-Type application/json, got %s", r.Header.Get("Content-Type"))
|
||||
http.Error(w, "Bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Read and unmarshal the request body
|
||||
bodyBytes, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Errorf("failed to read request body: %v", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
var reqPayload cloudgdatool.QueryDataRequest
|
||||
if err := json.Unmarshal(bodyBytes, &reqPayload); err != nil {
|
||||
t.Errorf("failed to unmarshal request payload: %v", err)
|
||||
http.Error(w, "Bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
projectID := "test-project"
|
||||
location := "us-central1"
|
||||
query := "How many accounts who have region in Prague are eligible for loans?"
|
||||
expectedParent := fmt.Sprintf("projects/%s/locations/%s", projectID, location)
|
||||
|
||||
// Verify expected fields
|
||||
if r.Header.Get("Authorization") == "" {
|
||||
t.Errorf("expected Authorization header, got empty")
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
if reqPayload.Prompt != "How many accounts who have region in Prague are eligible for loans?" {
|
||||
t.Errorf("unexpected prompt: %s", reqPayload.Prompt)
|
||||
}
|
||||
|
||||
// Verify payload's parent uses the tool's configured location
|
||||
if reqPayload.Parent != fmt.Sprintf("projects/%s/locations/%s", "test-project", "us-central1") {
|
||||
t.Errorf("unexpected payload parent: got %q, want %q", reqPayload.Parent, fmt.Sprintf("projects/%s/locations/%s", "test-project", "us-central1"))
|
||||
}
|
||||
|
||||
// Verify context from config
|
||||
if reqPayload.Context == nil ||
|
||||
reqPayload.Context.DatasourceReferences == nil ||
|
||||
reqPayload.Context.DatasourceReferences.SpannerReference == nil ||
|
||||
reqPayload.Context.DatasourceReferences.SpannerReference.DatabaseReference == nil ||
|
||||
reqPayload.Context.DatasourceReferences.SpannerReference.DatabaseReference.ProjectID != "cloud-db-nl2sql" {
|
||||
t.Errorf("unexpected context: %v", reqPayload.Context)
|
||||
}
|
||||
|
||||
// Verify generation options from config
|
||||
if reqPayload.GenerationOptions == nil || !reqPayload.GenerationOptions.GenerateQueryResult {
|
||||
t.Errorf("unexpected generation options: %v", reqPayload.GenerationOptions)
|
||||
}
|
||||
|
||||
// Simulate a successful response
|
||||
resp := map[string]any{
|
||||
"queryResult": "SELECT count(*) FROM accounts WHERE region = 'Prague' AND eligible_for_loans = true;",
|
||||
"naturalLanguageAnswer": "There are 5 accounts in Prague eligible for loans.",
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
})
|
||||
|
||||
mockServer := httptest.NewServer(serverMux)
|
||||
defer mockServer.Close()
|
||||
|
||||
ctx := testutils.ContextWithUserAgent(context.Background(), "test-user-agent")
|
||||
|
||||
// Create an authenticated client that uses the mock server
|
||||
authTransport := &authRoundTripper{
|
||||
Token: "Bearer test-access-token",
|
||||
Next: mockServer.Client().Transport,
|
||||
// Prepare expected response
|
||||
expectedResp := &geminidataanalyticspb.QueryDataResponse{
|
||||
GeneratedQuery: "SELECT count(*) FROM accounts WHERE region = 'Prague' AND eligible_for_loans = true;",
|
||||
NaturalLanguageAnswer: "There are 5 accounts in Prague eligible for loans.",
|
||||
}
|
||||
authClient := &http.Client{Transport: authTransport}
|
||||
|
||||
// Create a real cloudgdasrc.Source but inject the authenticated client
|
||||
mockGdaSource := &cloudgdasrc.Source{
|
||||
Config: cloudgdasrc.Config{Name: "mock-gda-source", Kind: cloudgdasrc.SourceKind, ProjectID: "test-project"},
|
||||
Client: authClient,
|
||||
BaseURL: mockServer.URL,
|
||||
fake := &fakeSource{
|
||||
projectID: projectID,
|
||||
expectedQuery: query,
|
||||
expectedParent: expectedParent,
|
||||
response: expectedResp,
|
||||
}
|
||||
|
||||
srcs := map[string]sources.Source{
|
||||
"mock-gda-source": mockGdaSource,
|
||||
"mock-gda-source": fake,
|
||||
}
|
||||
|
||||
// Initialize the tool config with context
|
||||
@@ -301,25 +231,31 @@ func TestInvoke(t *testing.T) {
|
||||
Kind: "cloud-gemini-data-analytics-query",
|
||||
Source: "mock-gda-source",
|
||||
Description: "Query Gemini Data Analytics",
|
||||
Location: "us-central1", // Set location for the test
|
||||
Location: location,
|
||||
Context: &cloudgdatool.QueryDataContext{
|
||||
DatasourceReferences: &cloudgdatool.DatasourceReferences{
|
||||
SpannerReference: &cloudgdatool.SpannerReference{
|
||||
DatabaseReference: &cloudgdatool.SpannerDatabaseReference{
|
||||
ProjectID: "cloud-db-nl2sql",
|
||||
Region: "us-central1",
|
||||
InstanceID: "evalbench",
|
||||
DatabaseID: "financial",
|
||||
Engine: cloudgdatool.SpannerEngineGoogleSQL,
|
||||
},
|
||||
AgentContextReference: &cloudgdatool.AgentContextReference{
|
||||
ContextSetID: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates",
|
||||
QueryDataContext: &geminidataanalyticspb.QueryDataContext{
|
||||
DatasourceReferences: &geminidataanalyticspb.DatasourceReferences{
|
||||
References: &geminidataanalyticspb.DatasourceReferences_SpannerReference{
|
||||
SpannerReference: &geminidataanalyticspb.SpannerReference{
|
||||
DatabaseReference: &geminidataanalyticspb.SpannerDatabaseReference{
|
||||
ProjectId: "cloud-db-nl2sql",
|
||||
Region: "us-central1",
|
||||
InstanceId: "evalbench",
|
||||
DatabaseId: "financial",
|
||||
Engine: geminidataanalyticspb.SpannerDatabaseReference_GOOGLE_SQL,
|
||||
},
|
||||
AgentContextReference: &geminidataanalyticspb.AgentContextReference{
|
||||
ContextSetId: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
GenerationOptions: &cloudgdatool.GenerationOptions{
|
||||
GenerateQueryResult: true,
|
||||
GenerationOptions: &geminidataanalyticspb.GenerationOptions{
|
||||
GenerateQueryResult: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -328,26 +264,27 @@ func TestInvoke(t *testing.T) {
|
||||
t.Fatalf("failed to initialize tool: %v", err)
|
||||
}
|
||||
|
||||
// Prepare parameters for invocation - ONLY prompt
|
||||
// Prepare parameters for invocation - ONLY query
|
||||
params := parameters.ParamValues{
|
||||
{Name: "prompt", Value: "How many accounts who have region in Prague are eligible for loans?"},
|
||||
{Name: "query", Value: query},
|
||||
}
|
||||
|
||||
resourceMgr := resources.NewResourceManager(srcs, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
ctx := testutils.ContextWithUserAgent(context.Background(), "test-user-agent")
|
||||
|
||||
// Invoke the tool
|
||||
result, err := tool.Invoke(ctx, resourceMgr, params, "") // No accessToken needed for ADC client
|
||||
result, err := tool.Invoke(ctx, resourceMgr, params, "")
|
||||
if err != nil {
|
||||
t.Fatalf("tool invocation failed: %v", err)
|
||||
}
|
||||
|
||||
// Validate the result
|
||||
expectedResult := map[string]any{
|
||||
"queryResult": "SELECT count(*) FROM accounts WHERE region = 'Prague' AND eligible_for_loans = true;",
|
||||
"naturalLanguageAnswer": "There are 5 accounts in Prague eligible for loans.",
|
||||
gotResp, ok := result.(*geminidataanalyticspb.QueryDataResponse)
|
||||
if !ok {
|
||||
t.Fatalf("expected result type *geminidataanalyticspb.QueryDataResponse, got %T", result)
|
||||
}
|
||||
|
||||
if !cmp.Equal(expectedResult, result) {
|
||||
t.Errorf("unexpected result: got %v, want %v", result, expectedResult)
|
||||
if diff := cmp.Diff(expectedResp, gotResp, cmpopts.IgnoreUnexported(geminidataanalyticspb.QueryDataResponse{})); diff != "" {
|
||||
t.Errorf("unexpected result mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,116 +0,0 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cloudgda
|
||||
|
||||
// See full service definition at: https://github.com/googleapis/googleapis/blob/master/google/cloud/geminidataanalytics/v1beta/data_chat_service.proto
|
||||
|
||||
// QueryDataRequest represents the JSON body for the queryData API
|
||||
type QueryDataRequest struct {
|
||||
Parent string `json:"parent"`
|
||||
Prompt string `json:"prompt"`
|
||||
Context *QueryDataContext `json:"context,omitempty"`
|
||||
GenerationOptions *GenerationOptions `json:"generationOptions,omitempty"`
|
||||
}
|
||||
|
||||
// QueryDataContext reflects the proto definition for the query context.
|
||||
type QueryDataContext struct {
|
||||
DatasourceReferences *DatasourceReferences `json:"datasourceReferences,omitempty" yaml:"datasourceReferences,omitempty"`
|
||||
}
|
||||
|
||||
// DatasourceReferences reflects the proto definition for datasource references, using a oneof.
|
||||
type DatasourceReferences struct {
|
||||
SpannerReference *SpannerReference `json:"spannerReference,omitempty" yaml:"spannerReference,omitempty"`
|
||||
AlloyDBReference *AlloyDBReference `json:"alloydb,omitempty" yaml:"alloydb,omitempty"`
|
||||
CloudSQLReference *CloudSQLReference `json:"cloudSqlReference,omitempty" yaml:"cloudSqlReference,omitempty"`
|
||||
}
|
||||
|
||||
// SpannerReference reflects the proto definition for Spanner database reference.
|
||||
type SpannerReference struct {
|
||||
DatabaseReference *SpannerDatabaseReference `json:"databaseReference,omitempty" yaml:"databaseReference,omitempty"`
|
||||
AgentContextReference *AgentContextReference `json:"agentContextReference,omitempty" yaml:"agentContextReference,omitempty"`
|
||||
}
|
||||
|
||||
// SpannerDatabaseReference reflects the proto definition for a Spanner database reference.
|
||||
type SpannerDatabaseReference struct {
|
||||
Engine SpannerEngine `json:"engine,omitempty" yaml:"engine,omitempty"`
|
||||
ProjectID string `json:"projectId,omitempty" yaml:"projectId,omitempty"`
|
||||
Region string `json:"region,omitempty" yaml:"region,omitempty"`
|
||||
InstanceID string `json:"instanceId,omitempty" yaml:"instanceId,omitempty"`
|
||||
DatabaseID string `json:"databaseId,omitempty" yaml:"databaseId,omitempty"`
|
||||
TableIDs []string `json:"tableIds,omitempty" yaml:"tableIds,omitempty"`
|
||||
}
|
||||
|
||||
// SpannerEngine represents the engine of the Spanner instance.
|
||||
type SpannerEngine string
|
||||
|
||||
const (
|
||||
SpannerEngineUnspecified SpannerEngine = "ENGINE_UNSPECIFIED"
|
||||
SpannerEngineGoogleSQL SpannerEngine = "GOOGLE_SQL"
|
||||
SpannerEnginePostgreSQL SpannerEngine = "POSTGRESQL"
|
||||
)
|
||||
|
||||
// AlloyDBReference reflects the proto definition for an AlloyDB database reference.
|
||||
type AlloyDBReference struct {
|
||||
DatabaseReference *AlloyDBDatabaseReference `json:"databaseReference,omitempty" yaml:"databaseReference,omitempty"`
|
||||
AgentContextReference *AgentContextReference `json:"agentContextReference,omitempty" yaml:"agentContextReference,omitempty"`
|
||||
}
|
||||
|
||||
// AlloyDBDatabaseReference reflects the proto definition for an AlloyDB database reference.
|
||||
type AlloyDBDatabaseReference struct {
|
||||
ProjectID string `json:"projectId,omitempty" yaml:"projectId,omitempty"`
|
||||
Region string `json:"region,omitempty" yaml:"region,omitempty"`
|
||||
ClusterID string `json:"clusterId,omitempty" yaml:"clusterId,omitempty"`
|
||||
InstanceID string `json:"instanceId,omitempty" yaml:"instanceId,omitempty"`
|
||||
DatabaseID string `json:"databaseId,omitempty" yaml:"databaseId,omitempty"`
|
||||
TableIDs []string `json:"tableIds,omitempty" yaml:"tableIds,omitempty"`
|
||||
}
|
||||
|
||||
// CloudSQLReference reflects the proto definition for a Cloud SQL database reference.
|
||||
type CloudSQLReference struct {
|
||||
DatabaseReference *CloudSQLDatabaseReference `json:"databaseReference,omitempty" yaml:"databaseReference,omitempty"`
|
||||
AgentContextReference *AgentContextReference `json:"agentContextReference,omitempty" yaml:"agentContextReference,omitempty"`
|
||||
}
|
||||
|
||||
// CloudSQLDatabaseReference reflects the proto definition for a Cloud SQL database reference.
|
||||
type CloudSQLDatabaseReference struct {
|
||||
Engine CloudSQLEngine `json:"engine,omitempty" yaml:"engine,omitempty"`
|
||||
ProjectID string `json:"projectId,omitempty" yaml:"projectId,omitempty"`
|
||||
Region string `json:"region,omitempty" yaml:"region,omitempty"`
|
||||
InstanceID string `json:"instanceId,omitempty" yaml:"instanceId,omitempty"`
|
||||
DatabaseID string `json:"databaseId,omitempty" yaml:"databaseId,omitempty"`
|
||||
TableIDs []string `json:"tableIds,omitempty" yaml:"tableIds,omitempty"`
|
||||
}
|
||||
|
||||
// CloudSQLEngine represents the engine of the Cloud SQL instance.
|
||||
type CloudSQLEngine string
|
||||
|
||||
const (
|
||||
CloudSQLEngineUnspecified CloudSQLEngine = "ENGINE_UNSPECIFIED"
|
||||
CloudSQLEnginePostgreSQL CloudSQLEngine = "POSTGRESQL"
|
||||
CloudSQLEngineMySQL CloudSQLEngine = "MYSQL"
|
||||
)
|
||||
|
||||
// AgentContextReference reflects the proto definition for agent context.
|
||||
type AgentContextReference struct {
|
||||
ContextSetID string `json:"contextSetId,omitempty" yaml:"contextSetId,omitempty"`
|
||||
}
|
||||
|
||||
// GenerationOptions reflects the proto definition for generation options.
|
||||
type GenerationOptions struct {
|
||||
GenerateQueryResult bool `json:"generateQueryResult" yaml:"generateQueryResult"`
|
||||
GenerateNaturalLanguageAnswer bool `json:"generateNaturalLanguageAnswer" yaml:"generateNaturalLanguageAnswer"`
|
||||
GenerateExplanation bool `json:"generateExplanation" yaml:"generateExplanation"`
|
||||
GenerateDisambiguationQuestion bool `json:"generateDisambiguationQuestion" yaml:"generateDisambiguationQuestion"`
|
||||
}
|
||||
@@ -139,12 +139,12 @@ func TestCloudGdaToolEndpoints(t *testing.T) {
|
||||
// 1. RunToolGetTestByName
|
||||
expectedManifest := map[string]any{
|
||||
toolName: map[string]any{
|
||||
"description": "Test GDA Tool",
|
||||
"description": "Test GDA Tool\n\n" + cloudgda.Guidance,
|
||||
"parameters": []any{
|
||||
map[string]any{
|
||||
"name": "prompt",
|
||||
"name": "query",
|
||||
"type": "string",
|
||||
"description": "The natural language question to ask.",
|
||||
"description": "A natural language formulation of a database query.",
|
||||
"required": true,
|
||||
"authSources": []any{},
|
||||
},
|
||||
@@ -155,7 +155,7 @@ func TestCloudGdaToolEndpoints(t *testing.T) {
|
||||
tests.RunToolGetTestByName(t, toolName, expectedManifest)
|
||||
|
||||
// 2. RunToolInvokeParametersTest
|
||||
params := []byte(`{"prompt": "test question"}`)
|
||||
params := []byte(`{"query": "test question"}`)
|
||||
tests.RunToolInvokeParametersTest(t, toolName, params, "\"queryResult\":\"SELECT * FROM table;\"")
|
||||
|
||||
// 3. Manual MCP Tool Call Test
|
||||
@@ -172,7 +172,7 @@ func TestCloudGdaToolEndpoints(t *testing.T) {
|
||||
Params: map[string]any{
|
||||
"name": toolName,
|
||||
"arguments": map[string]any{
|
||||
"prompt": "test question",
|
||||
"query": "test question",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user